From 6a35549d8948e5d607d65d637fdf50a2b287536b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 12:59:46 +0200 Subject: [PATCH 01/24] refactor[next]: otf split build and load --- src/gt4py/next/otf/compilation/compiler.py | 29 +++++--- src/gt4py/next/otf/definitions.py | 6 +- src/gt4py/next/otf/recipes.py | 27 +++++++- src/gt4py/next/otf/stages.py | 10 +++ .../program_processors/formatters/gtfn.py | 2 +- .../runners/dace/program.py | 18 ++--- .../runners/dace/workflow/backend.py | 22 +++--- .../runners/dace/workflow/compilation.py | 69 ++++++++++++++++--- .../runners/dace/workflow/factory.py | 47 ++++++++++++- .../next/program_processors/runners/gtfn.py | 47 +++++++++++-- .../test_temporaries_with_sizes.py | 14 ++-- .../iterator_tests/test_builtins.py | 6 +- .../gtfn_tests/test_gtfn_module.py | 4 +- .../runners_tests/test_gtfn.py | 21 +++--- 14 files changed, 254 insertions(+), 68 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 3748d95192..8fa999bb3c 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -48,15 +48,15 @@ def __call__( class Compiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.ExecutableProgram, + stages.BuildArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.ExecutableProgram, + stages.BuildArtifact, ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" + """Use any build system (via configured factory) to compile a GT4Py program into an on-disk ``BuildArtifact``.""" cache_lifetime: config.BuildCacheLifetime builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] @@ -65,7 +65,7 
@@ class Compiler( def __call__( self, inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - ) -> stages.ExecutableProgram: + ) -> stages.BuildArtifact: src_dir = cache.get_cache_folder(inp, self.cache_lifetime) # If we are compiling the same program at the same time (e.g. multiple MPI ranks), @@ -83,12 +83,25 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) - m = importer.import_from_path( - src_dir / new_data.module, sys_modules_prefix="gt4py.__compiled_programs__." + return stages.BuildArtifact( + src_dir=src_dir, + module=new_data.module, + entry_point_name=new_data.entry_point_name, ) - func = getattr(m, new_data.entry_point_name) - return func + +def load_artifact(artifact: stages.BuildArtifact) -> stages.ExecutableProgram: + """Dynamically import a previously-built module and return its entry point. + + Must run in the process that will ultimately call the returned program, since + the module is registered in that process's ``sys.modules`` under the + ``gt4py.__compiled_programs__.`` prefix. 
+ """ + m = importer.import_from_path( + artifact.src_dir / artifact.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return getattr(m, artifact.entry_point_name) class CompilerFactory(factory.Factory): diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index 11b42dc6ce..9e4f7dc586 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -57,12 +57,12 @@ def __call__( class CompilationStep( workflow.Workflow[ - stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.ExecutableProgram + stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact ], Protocol[CodeSpecT, TargetCodeSpecT], ): - """Compile program source code and bindings into a python callable (CompilableSource -> CompiledProgram).""" + """Run the build system and produce an on-disk artifact (CompilableSource -> BuildArtifact).""" def __call__( self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT] - ) -> stages.ExecutableProgram: ... + ) -> stages.BuildArtifact: ... 
diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 79cd17162b..573c3581fe 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -14,10 +14,31 @@ @dataclasses.dataclass(frozen=True) -class OTFCompileWorkflow(workflow.NamedStepSequence): - """The typical compiled backend steps composed into a workflow.""" +class OTFBuildWorkflow( + workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.BuildArtifact] +): + """Translation + bindings + build system; ends at an on-disk :class:`stages.BuildArtifact`.""" translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] - compilation: workflow.Workflow[stages.CompilableProject, stages.ExecutableProgram] + compilation: workflow.Workflow[stages.CompilableProject, stages.BuildArtifact] + + +@dataclasses.dataclass(frozen=True) +class OTFFinalizeWorkflow( + workflow.NamedStepSequence[stages.BuildArtifact, stages.ExecutableProgram] +): + """Import the built module and apply decoration to get a live callable.""" + + load: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] decoration: workflow.Workflow[stages.ExecutableProgram, stages.ExecutableProgram] + + +@dataclasses.dataclass(frozen=True) +class OTFCompileWorkflow( + workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.ExecutableProgram] +): + """Full OTF pipeline: the ``build`` phase ends at a picklable artifact, ``finalize`` rehydrates it.""" + + build: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] + finalize: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index b6816b1cc3..a0a6c6216e 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -9,6 +9,7 @@ from __future__ import annotations import dataclasses +import pathlib from collections.abc import Callable from typing import 
Generic, Optional, Protocol, TypeAlias, TypeVar @@ -129,6 +130,15 @@ def build(self) -> None: ... ExecutableProgram: TypeAlias = Callable +@dataclasses.dataclass(frozen=True) +class BuildArtifact: + """On-disk result of a compilation: everything a later step needs to import it.""" + + src_dir: pathlib.Path + module: pathlib.Path + entry_point_name: str + + def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: """ Filter out multiple occurrences of the same ``interface.LibraryDependency``. diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index 1d65b8d8d0..c20f7a8555 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -17,7 +17,7 @@ @program_formatter.program_formatter def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str: # TODO(tehrengruber): This is a little ugly. Revisit. - gtfn_translation = gtfn.GTFNBackendFactory().executor.translation # type: ignore[attr-defined] + gtfn_translation = gtfn.GTFNBackendFactory().executor.build.translation # type: ignore[attr-defined] assert isinstance(gtfn_translation, GTFNTranslationStep) return gtfn_translation.generate_stencil_source( program, diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index f8c8fd84a3..c13daa249f 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -76,16 +76,16 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: gt4py_program_args=[p.type for p in program.params], ) - compile_workflow = typing.cast( - recipes.OTFCompileWorkflow, - self.backend.executor - if not hasattr(self.backend.executor, "step") - else self.backend.executor.step, - ) # We know which backend we are using, but we don't know if the compile 
workflow is cached. + compile_workflow = typing.cast(recipes.OTFCompileWorkflow, self.backend.executor) + build_workflow = ( + compile_workflow.build.step + if hasattr(compile_workflow.build, "step") + else compile_workflow.build + ) # the `build` phase may be wrapped in a `CachedStep` depending on backend configuration. compile_workflow_translation = ( - compile_workflow.translation - if not hasattr(compile_workflow.translation, "step") - else compile_workflow.translation.step + build_workflow.translation.step + if hasattr(build_workflow.translation, "step") + else build_workflow.translation ) # Same for the translation stage, which could be a `CachedStep` depending on backend configuration. # TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index de6778a750..935655a422 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -8,6 +8,7 @@ from __future__ import annotations +import dataclasses import warnings from typing import Any, Final @@ -44,7 +45,12 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + lambda o: dataclasses.replace( + o.otf_workflow, + build=workflow.CachedStep( + o.otf_workflow.build, hash_function=o.hash_function + ), + ) ), name_cached="_cached", ) @@ -127,13 +133,13 @@ def make_dace_backend( gpu=gpu, cached=cached, auto_optimize=auto_optimize, - otf_workflow__cached_translation=cached, - otf_workflow__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False), - otf_workflow__bare_translation__auto_optimize_args=optimization_args, - 
otf_workflow__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, - otf_workflow__bare_translation__use_metrics=use_metrics, - otf_workflow__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin, - otf_workflow__bare_translation__use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, + otf_workflow__build__cached_translation=cached, + otf_workflow__build__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False), + otf_workflow__build__bare_translation__auto_optimize_args=optimization_args, + otf_workflow__build__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, + otf_workflow__build__bare_translation__use_metrics=use_metrics, + otf_workflow__build__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin, + otf_workflow__build__bare_translation__use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index e1747b7ac3..dcbe73454f 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -10,6 +10,8 @@ import dataclasses import os +import pathlib +import types import warnings from collections.abc import Callable, MutableSequence, Sequence from typing import Any @@ -114,19 +116,28 @@ def __call__(self, **kwargs: Any) -> None: assert result is None +@dataclasses.dataclass(frozen=True) +class DaCeBuildArtifact: + """On-disk result of a DaCe compilation.""" + + build_folder: pathlib.Path + binding_source_code: str + bind_func_name: str + + @dataclasses.dataclass(frozen=True) class DaCeCompiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, 
code_specs.PythonCodeSpec], - CompiledDaceProgram, + DaCeBuildArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - CompiledDaceProgram, + DaCeBuildArtifact, ], definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], ): - """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" + """Run the DaCe build system and produce an on-disk :class:`DaCeBuildArtifact`.""" bind_func_name: str cache_lifetime: config.BuildCacheLifetime @@ -136,7 +147,7 @@ class DaCeCompiler( def __call__( self, inp: stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - ) -> CompiledDaceProgram: + ) -> DaCeBuildArtifact: with gtx_wfdcommon.dace_context( device_type=self.device_type, cmake_build_type=self.cmake_build_type, @@ -147,16 +158,56 @@ def __call__( sdfg = dace.SDFG.from_json(inp.program_source.source_code) sdfg.build_folder = sdfg_build_folder with locking.lock(sdfg_build_folder): - sdfg_program = sdfg.compile(validate=False) + sdfg.compile(validate=False, return_program_handle=False) assert inp.binding_source is not None - return CompiledDaceProgram( - sdfg_program, - self.bind_func_name, - inp.binding_source, + return DaCeBuildArtifact( + build_folder=pathlib.Path(sdfg_build_folder), + binding_source_code=inp.binding_source.source_code, + bind_func_name=self.bind_func_name, ) +@dataclasses.dataclass(frozen=True) +class DaCeLoader( + workflow.ChainableWorkflowMixin[DaCeBuildArtifact, CompiledDaceProgram], + workflow.ReplaceEnabledWorkflowMixin[DaCeBuildArtifact, CompiledDaceProgram], +): + """Rehydrate a :class:`DaCeBuildArtifact` into a live :class:`CompiledDaceProgram`.""" + + device_type: core_defs.DeviceType + cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG + + def __call__(self, artifact: DaCeBuildArtifact) -> CompiledDaceProgram: + for dump_name in ("program.sdfgz", "program.sdfg"): + 
sdfg_dump = artifact.build_folder / dump_name + if sdfg_dump.exists(): + break + else: + raise RuntimeError( + f"No SDFG dump (program.sdfgz / program.sdfg) found in '{artifact.build_folder}'." + ) + + sdfg = dace.SDFG.from_file(str(sdfg_dump)) + sdfg.build_folder = str(artifact.build_folder) + + with gtx_wfdcommon.dace_context( + device_type=self.device_type, + cmake_build_type=self.cmake_build_type, + ): + # use_cache=True forces DaCe to load the existing .so without re-codegen. + with dace.config.set_temporary("compiler", "use_cache", value=True): + sdfg_program = sdfg.compile(validate=False) + + binding_source_shim = types.SimpleNamespace(source_code=artifact.binding_source_code) + return CompiledDaceProgram(sdfg_program, artifact.bind_func_name, binding_source_shim) # type: ignore[arg-type] + + class DaCeCompilationStepFactory(factory.Factory): class Meta: model = DaCeCompiler + + +class DaCeLoaderFactory(factory.Factory): + class Meta: + model = DaCeLoader diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 62febd0965..12441587c5 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -22,6 +22,7 @@ ) from gt4py.next.program_processors.runners.dace.workflow.compilation import ( DaCeCompilationStepFactory, + DaCeLoaderFactory, ) from gt4py.next.program_processors.runners.dace.workflow.translation import ( DaCeTranslationStepFactory, @@ -31,9 +32,9 @@ _GT_DACE_BINDING_FUNCTION_NAME: Final[str] = "update_sdfg_args" -class DaCeWorkflowFactory(factory.Factory): +class DaCeBuildWorkflowFactory(factory.Factory): class Meta: - model = recipes.OTFCompileWorkflow + model = recipes.OTFBuildWorkflow class Params: auto_optimize: bool = False @@ -72,9 +73,51 @@ class Params: device_type=factory.SelfAttribute("..device_type"), 
cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) + + +class DaCeFinalizeWorkflowFactory(factory.Factory): + class Meta: + model = recipes.OTFFinalizeWorkflow + + class Params: + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough + lambda: config.CMAKE_BUILD_TYPE + ) + + load = factory.SubFactory( + DaCeLoaderFactory, + device_type=factory.SelfAttribute("..device_type"), + cmake_build_type=factory.SelfAttribute("..cmake_build_type"), + ) decoration = factory.LazyAttribute( lambda o: functools.partial( decoration_step.convert_args, device=o.device_type, ) ) + + +class DaCeWorkflowFactory(factory.Factory): + class Meta: + model = recipes.OTFCompileWorkflow + + class Params: + auto_optimize: bool = False + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough + lambda: config.CMAKE_BUILD_TYPE + ) + cached_translation = factory.Trait(build__cached_translation=True) + + build = factory.SubFactory( + DaCeBuildWorkflowFactory, + device_type=factory.SelfAttribute("..device_type"), + auto_optimize=factory.SelfAttribute("..auto_optimize"), + cmake_build_type=factory.SelfAttribute("..cmake_build_type"), + ) + finalize = factory.SubFactory( + DaCeFinalizeWorkflowFactory, + device_type=factory.SelfAttribute("..device_type"), + cmake_build_type=factory.SelfAttribute("..cmake_build_type"), + ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c1743dea6a..6a600d5b5f 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +import dataclasses import functools from typing import Any @@ -106,9 +107,9 @@ def extract_connectivity_args( return args -class GTFNCompileWorkflowFactory(factory.Factory): +class GTFNBuildWorkflowFactory(factory.Factory): class Meta: - model = recipes.OTFCompileWorkflow + model = recipes.OTFBuildWorkflow class Params: device_type: core_defs.DeviceType = core_defs.DeviceType.CPU @@ -144,11 +145,37 @@ class Params: cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), ) + + +class GTFNFinalizeWorkflowFactory(factory.Factory): + class Meta: + model = recipes.OTFFinalizeWorkflow + + class Params: + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + + load = factory.LazyFunction(lambda: compiler.load_artifact) decoration = factory.LazyAttribute( lambda o: functools.partial(convert_args, device=o.device_type) ) +class GTFNCompileWorkflowFactory(factory.Factory): + class Meta: + model = recipes.OTFCompileWorkflow + + class Params: + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + cached_translation = factory.Trait(build__cached_translation=True) + + build = factory.SubFactory( + GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") + ) + finalize = factory.SubFactory( + GTFNFinalizeWorkflowFactory, device_type=factory.SelfAttribute("..device_type") + ) + + class GTFNBackendFactory(factory.Factory): class Meta: model = backend.Backend @@ -165,7 +192,12 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + lambda o: dataclasses.replace( + o.otf_workflow, + build=workflow.CachedStep( + o.otf_workflow.build, hash_function=o.hash_function + ), + ) ), name_cached="_cached", ) @@ -187,17 +219,18 @@ class Params: run_gtfn = GTFNBackendFactory() run_gtfn_imperative = GTFNBackendFactory( - 
name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True + name_postfix="_imperative", + otf_workflow__build__translation__use_imperative_backend=True, ) -run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) +run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__build__cached_translation=True) run_gtfn_gpu = GTFNBackendFactory(gpu=True) run_gtfn_gpu_cached = GTFNBackendFactory( - gpu=True, cached=True, otf_workflow__cached_translation=True + gpu=True, cached=True, otf_workflow__build__cached_translation=True ) run_gtfn_no_transforms = GTFNBackendFactory( - otf_workflow__bare_translation__enable_itir_transforms=False + otf_workflow__build__bare_translation__enable_itir_transforms=False ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 90c0d775f2..c3058f33a8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -34,12 +34,14 @@ def exec_alloc_descriptor(): name="run_gtfn_with_temporaries_and_sizes", transforms=backend.DEFAULT_TRANSFORMS, executor=run_gtfn.executor.replace( - translation=run_gtfn.executor.translation.replace( - symbolic_domain_sizes={ - "Cell": "num_cells", - "Edge": "num_edges", - "Vertex": "num_vertices", - } + build=run_gtfn.executor.build.replace( + translation=run_gtfn.executor.build.translation.replace( + symbolic_domain_sizes={ + "Cell": "num_cells", + "Edge": "num_edges", + "Vertex": "num_vertices", + } + ) ) ), allocator=run_gtfn.allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 4fff5192aa..9d39b5d63a 100644 --- 
a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -192,7 +192,11 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): gtfn_without_transforms = dataclasses.replace( run_gtfn, executor=run_gtfn.executor.replace( - translation=run_gtfn.executor.translation.replace(enable_itir_transforms=False), + build=run_gtfn.executor.build.replace( + translation=run_gtfn.executor.build.translation.replace( + enable_itir_transforms=False + ), + ), ), # avoid inlining the function ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index d027c9dcb1..3611abad89 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -135,11 +135,11 @@ def test_gtfn_file_cache(program_example): ) cached_gtfn_translation_step = gtfn.GTFNBackendFactory( gpu=False, cached=True, otf_workflow__cached_translation=True - ).executor.step.translation + ).executor.build.step.translation bare_gtfn_translation_step = gtfn.GTFNBackendFactory( gpu=False, cached=True, otf_workflow__cached_translation=False - ).executor.step.translation + ).executor.build.step.translation cache_key = stages.fingerprint_compilable_program(compilable_program) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index 96d8c6e27c..f088761ffd 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -34,11 +34,11 @@ def 
test_backend_factory_trait_device(): assert cpu_version.name == "run_gtfn_cpu" assert gpu_version.name == "run_gtfn_gpu" - assert cpu_version.executor.translation.device_type is core_defs.DeviceType.CPU - assert gpu_version.executor.translation.device_type is core_defs.DeviceType.CUDA + assert cpu_version.executor.build.translation.device_type is core_defs.DeviceType.CPU + assert gpu_version.executor.build.translation.device_type is core_defs.DeviceType.CUDA - assert cpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CPU - assert gpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CUDA + assert cpu_version.executor.finalize.decoration.keywords["device"] is core_defs.DeviceType.CPU + assert gpu_version.executor.finalize.decoration.keywords["device"] is core_defs.DeviceType.CUDA assert custom_layout_allocators.is_field_allocator_for( cpu_version.allocator, core_defs.DeviceType.CPU @@ -50,7 +50,7 @@ def test_backend_factory_trait_device(): def test_backend_factory_trait_cached(): cached_version = gtfn.GTFNBackendFactory(gpu=False, cached=True) - assert isinstance(cached_version.executor, workflow.CachedStep) + assert isinstance(cached_version.executor.build, workflow.CachedStep) assert cached_version.name == "run_gtfn_cpu_cached" @@ -60,9 +60,12 @@ def test_backend_factory_build_cache_config(monkeypatch): monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.PERSISTENT) persistent_version = gtfn.GTFNBackendFactory() - assert session_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.SESSION assert ( - persistent_version.executor.compilation.cache_lifetime + session_version.executor.build.compilation.cache_lifetime + is config.BuildCacheLifetime.SESSION + ) + assert ( + persistent_version.executor.build.compilation.cache_lifetime is config.BuildCacheLifetime.PERSISTENT ) @@ -74,10 +77,10 @@ def test_backend_factory_build_type_config(monkeypatch): min_size_version = 
gtfn.GTFNBackendFactory() assert ( - release_version.executor.compilation.builder_factory.cmake_build_type + release_version.executor.build.compilation.builder_factory.cmake_build_type is config.CMakeBuildType.RELEASE ) assert ( - min_size_version.executor.compilation.builder_factory.cmake_build_type + min_size_version.executor.build.compilation.builder_factory.cmake_build_type is config.CMakeBuildType.MIN_SIZE_REL ) From d2510573866e50dba2597b22d72379e7b1653824 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 15:04:56 +0200 Subject: [PATCH 02/24] refactorings to combine load+decoration --- src/gt4py/next/otf/compilation/compiler.py | 45 +++++----- src/gt4py/next/otf/definitions.py | 16 ++-- src/gt4py/next/otf/recipes.py | 51 +++++++---- src/gt4py/next/otf/stages.py | 10 --- .../runners/dace/workflow/compilation.py | 89 ++++++++++--------- .../runners/dace/workflow/factory.py | 32 +------ .../next/program_processors/runners/gtfn.py | 39 ++++---- .../runners_tests/test_gtfn.py | 5 +- 8 files changed, 142 insertions(+), 145 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 8fa999bb3c..9d77d50b07 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -14,10 +14,10 @@ import factory -from gt4py._core import locking +from gt4py._core import definitions as core_defs, locking from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow -from gt4py.next.otf.compilation import build_data, cache, importer +from gt4py.next.otf.compilation import build_data, cache T = TypeVar("T") @@ -44,28 +44,44 @@ def __call__( ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... +@dataclasses.dataclass(frozen=True) +class GTFNBuildArtifact: + """On-disk result of a GTFN compilation: a Python extension module. 
+ + Bindings are baked into the .so via nanobind, so the load step is just an + ``importlib`` import + entry-point symbol lookup. ``device_type`` is + intrinsic to the artifact: a CPU-built .so cannot be loaded as GPU. + """ + + src_dir: pathlib.Path + module: pathlib.Path + entry_point_name: str + device_type: core_defs.DeviceType + + @dataclasses.dataclass(frozen=True) class Compiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.BuildArtifact, + GTFNBuildArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.BuildArtifact, + GTFNBuildArtifact, ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Use any build system (via configured factory) to compile a GT4Py program into an on-disk ``BuildArtifact``.""" + """Use any build system (via configured factory) to compile a GT4Py program into a :class:`GTFNBuildArtifact`.""" cache_lifetime: config.BuildCacheLifetime builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + device_type: core_defs.DeviceType force_recompile: bool = False def __call__( self, inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - ) -> stages.BuildArtifact: + ) -> GTFNBuildArtifact: src_dir = cache.get_cache_folder(inp, self.cache_lifetime) # If we are compiling the same program at the same time (e.g. multiple MPI ranks), @@ -83,27 +99,14 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) - return stages.BuildArtifact( + return GTFNBuildArtifact( src_dir=src_dir, module=new_data.module, entry_point_name=new_data.entry_point_name, + device_type=self.device_type, ) -def load_artifact(artifact: stages.BuildArtifact) -> stages.ExecutableProgram: - """Dynamically import a previously-built module and return its entry point. 
- - Must run in the process that will ultimately call the returned program, since - the module is registered in that process's ``sys.modules`` under the - ``gt4py.__compiled_programs__.`` prefix. - """ - m = importer.import_from_path( - artifact.src_dir / artifact.module, - sys_modules_prefix="gt4py.__compiled_programs__.", - ) - return getattr(m, artifact.entry_point_name) - - class CompilerFactory(factory.Factory): class Meta: model = Compiler diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index 9e4f7dc586..c242e02aa2 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Protocol, TypeAlias, TypeVar +from typing import Any, Protocol, TypeAlias, TypeVar from gt4py.next.ffront import stages as ffront_stages from gt4py.next.iterator import ir as itir @@ -56,13 +56,17 @@ def __call__( class CompilationStep( - workflow.Workflow[ - stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact - ], + workflow.Workflow[stages.CompilableProject[CodeSpecT, TargetCodeSpecT], Any], Protocol[CodeSpecT, TargetCodeSpecT], ): - """Run the build system and produce an on-disk artifact (CompilableSource -> BuildArtifact).""" + """Run the build system and produce an on-disk, backend-specific build artifact. + + The artifact type is intentionally :class:`Any` here — each backend defines + its own concrete dataclass (frozen, picklable). The build/finalize boundary + in :class:`recipes.OTFCompileWorkflow` only requires that whatever + ``CompilationStep`` produces is what the backend's ``finalize`` consumes. + """ def __call__( self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT] - ) -> stages.BuildArtifact: ... + ) -> Any: ... 
diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 573c3581fe..0ded8426aa 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -9,36 +9,53 @@ from __future__ import annotations import dataclasses +from typing import Any from gt4py.next.otf import definitions, stages, workflow @dataclasses.dataclass(frozen=True) class OTFBuildWorkflow( - workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.BuildArtifact] + workflow.NamedStepSequence[definitions.CompilableProgramDef, Any] ): - """Translation + bindings + build system; ends at an on-disk :class:`stages.BuildArtifact`.""" + """Translation + bindings + build system; ends at an on-disk artifact. - translation: definitions.TranslationStep - bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] - compilation: workflow.Workflow[stages.CompilableProject, stages.BuildArtifact] + The artifact type is backend-specific (e.g. a ``GTFNBuildArtifact`` or a + ``DaCeBuildArtifact``); a workflow only ever pairs a backend's build with + that same backend's finalize, so no cross-backend artifact protocol is + needed. + Grouped as a sub-workflow so the ``cached=True`` backend trait can wrap + just this sub-workflow in a :class:`workflow.CachedStep` — caching keys + on :class:`definitions.CompilableProgramDef` and values on a picklable, + backend-specific artifact dataclass. 
+ """ -@dataclasses.dataclass(frozen=True) -class OTFFinalizeWorkflow( - workflow.NamedStepSequence[stages.BuildArtifact, stages.ExecutableProgram] -): - """Import the built module and apply decoration to get a live callable.""" - - load: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] - decoration: workflow.Workflow[stages.ExecutableProgram, stages.ExecutableProgram] + translation: definitions.TranslationStep + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] + compilation: workflow.Workflow[stages.CompilableProject, Any] @dataclasses.dataclass(frozen=True) class OTFCompileWorkflow( workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.ExecutableProgram] ): - """Full OTF pipeline: the ``build`` phase ends at a picklable artifact, ``finalize`` rehydrates it.""" - - build: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] - finalize: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] + """Full OTF pipeline: two phases separated by an on-disk artifact boundary. + + 1. ``build`` — produces a picklable, backend-specific build artifact. + Heavy, idempotent, parallelizable across processes; the natural cache + target. + 2. ``finalize`` — rehydrates the artifact into a directly-callable + :class:`stages.ExecutableProgram`. Backend-internal; whatever + sequence of "load the .so / wrap with gt4py calling convention / + attach metrics" the backend needs. + + The artifact dataclass is the contract between these two phases. By + convention, artifacts are frozen dataclasses, picklable across process + boundaries, and self-describing (carry every property finalize needs, + e.g. ``device_type``). Each backend defines its own; nothing about that + contract is enforced by this module — it is per-backend convention. 
+ """ + + build: workflow.Workflow[definitions.CompilableProgramDef, Any] + finalize: workflow.Workflow[Any, stages.ExecutableProgram] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index a0a6c6216e..b6816b1cc3 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -9,7 +9,6 @@ from __future__ import annotations import dataclasses -import pathlib from collections.abc import Callable from typing import Generic, Optional, Protocol, TypeAlias, TypeVar @@ -130,15 +129,6 @@ def build(self) -> None: ... ExecutableProgram: TypeAlias = Callable -@dataclasses.dataclass(frozen=True) -class BuildArtifact: - """On-disk result of a compilation: everything a later step needs to import it.""" - - src_dir: pathlib.Path - module: pathlib.Path - entry_point_name: str - - def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: """ Filter out multiple occurrences of the same ``interface.LibraryDependency``. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index dcbe73454f..3215d44c10 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -11,7 +11,6 @@ import dataclasses import os import pathlib -import types import warnings from collections.abc import Callable, MutableSequence, Sequence from typing import Any @@ -57,7 +56,7 @@ def __init__( self, program: dace.CompiledSDFG, bind_func_name: str, - binding_source: stages.BindingSource[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], + binding_source_code: str, ): self.sdfg_program = program @@ -66,9 +65,10 @@ def __init__( # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. 
self.sdfg_argtypes = list(program.sdfg.arglist().values()) - # Note that `binding_source` contains Python code tailored to this specific SDFG. - # Here we dinamically compile this function and add it to the compiled program. - exec(binding_source.source_code, global_namespace := {}) # type: ignore[var-annotated] + # The binding source code is Python tailored to this specific SDFG. + # We dynamically compile that function and add it to the compiled program. + global_namespace: dict[str, Any] = {} + exec(binding_source_code, global_namespace) self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] # For debug purpose, we set a unique module name on the compiled function. self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) @@ -118,11 +118,18 @@ def __call__(self, **kwargs: Any) -> None: @dataclasses.dataclass(frozen=True) class DaCeBuildArtifact: - """On-disk result of a DaCe compilation.""" + """On-disk result of a DaCe compilation. + + Carries the ``device_type`` the artifact was built for; a CPU-built .so + cannot be loaded as GPU. Also carries the bindings (Python source code + that the loader ``exec``\\ s to materialize the SDFG argument-marshalling + function). 
+ """ build_folder: pathlib.Path binding_source_code: str bind_func_name: str + device_type: core_defs.DeviceType @dataclasses.dataclass(frozen=True) @@ -165,49 +172,49 @@ def __call__( build_folder=pathlib.Path(sdfg_build_folder), binding_source_code=inp.binding_source.source_code, bind_func_name=self.bind_func_name, + device_type=self.device_type, ) -@dataclasses.dataclass(frozen=True) -class DaCeLoader( - workflow.ChainableWorkflowMixin[DaCeBuildArtifact, CompiledDaceProgram], - workflow.ReplaceEnabledWorkflowMixin[DaCeBuildArtifact, CompiledDaceProgram], -): - """Rehydrate a :class:`DaCeBuildArtifact` into a live :class:`CompiledDaceProgram`.""" - - device_type: core_defs.DeviceType - cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG - - def __call__(self, artifact: DaCeBuildArtifact) -> CompiledDaceProgram: - for dump_name in ("program.sdfgz", "program.sdfg"): - sdfg_dump = artifact.build_folder / dump_name - if sdfg_dump.exists(): - break - else: - raise RuntimeError( - f"No SDFG dump (program.sdfgz / program.sdfg) found in '{artifact.build_folder}'." - ) +def dace_finalize(artifact: DaCeBuildArtifact) -> stages.ExecutableProgram: + """Turn a :class:`DaCeBuildArtifact` into a directly-callable program. + + Re-deserializes the SDFG dump from the build folder, links against the + pre-built .so via ``compiler.use_cache=True`` (no re-codegen), wraps it + in a :class:`CompiledDaceProgram`, and applies gt4py's calling convention + via :func:`decoration.convert_args`. Reads the target device from the + artifact. + + Must run in the process that will ultimately call the returned program. + """ + # Local import to avoid a circular reference (decoration imports compilation). 
+ from gt4py.next.program_processors.runners.dace.workflow import ( + decoration as gtx_wfddecoration, + ) + + for dump_name in ("program.sdfgz", "program.sdfg"): + sdfg_dump = artifact.build_folder / dump_name + if sdfg_dump.exists(): + break + else: + raise RuntimeError( + f"No SDFG dump (program.sdfgz / program.sdfg) found in '{artifact.build_folder}'." + ) - sdfg = dace.SDFG.from_file(str(sdfg_dump)) - sdfg.build_folder = str(artifact.build_folder) + sdfg = dace.SDFG.from_file(str(sdfg_dump)) + sdfg.build_folder = str(artifact.build_folder) - with gtx_wfdcommon.dace_context( - device_type=self.device_type, - cmake_build_type=self.cmake_build_type, - ): - # use_cache=True forces DaCe to load the existing .so without re-codegen. - with dace.config.set_temporary("compiler", "use_cache", value=True): - sdfg_program = sdfg.compile(validate=False) + with gtx_wfdcommon.dace_context(device_type=artifact.device_type): + # use_cache=True forces DaCe to load the existing .so without re-codegen. 
+ with dace.config.set_temporary("compiler", "use_cache", value=True): + sdfg_program = sdfg.compile(validate=False) - binding_source_shim = types.SimpleNamespace(source_code=artifact.binding_source_code) - return CompiledDaceProgram(sdfg_program, artifact.bind_func_name, binding_source_shim) # type: ignore[arg-type] + program = CompiledDaceProgram( + sdfg_program, artifact.bind_func_name, artifact.binding_source_code + ) + return gtx_wfddecoration.convert_args(program, device=artifact.device_type) class DaCeCompilationStepFactory(factory.Factory): class Meta: model = DaCeCompiler - - -class DaCeLoaderFactory(factory.Factory): - class Meta: - model = DaCeLoader diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 12441587c5..5855ef5cc4 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -18,11 +18,10 @@ from gt4py.next.otf import recipes, stages, workflow from gt4py.next.program_processors.runners.dace.workflow import ( bindings as bindings_step, - decoration as decoration_step, + compilation as compilation_step, ) from gt4py.next.program_processors.runners.dace.workflow.compilation import ( DaCeCompilationStepFactory, - DaCeLoaderFactory, ) from gt4py.next.program_processors.runners.dace.workflow.translation import ( DaCeTranslationStepFactory, @@ -75,29 +74,6 @@ class Params: ) -class DaCeFinalizeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFFinalizeWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough - lambda: config.CMAKE_BUILD_TYPE - ) - - load = factory.SubFactory( - DaCeLoaderFactory, - device_type=factory.SelfAttribute("..device_type"), - 
cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - decoration_step.convert_args, - device=o.device_type, - ) - ) - - class DaCeWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFCompileWorkflow @@ -116,8 +92,4 @@ class Params: auto_optimize=factory.SelfAttribute("..auto_optimize"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) - finalize = factory.SubFactory( - DaCeFinalizeWorkflowFactory, - device_type=factory.SelfAttribute("..device_type"), - cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) + finalize = factory.LazyFunction(lambda: compilation_step.dace_finalize) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 6a600d5b5f..01610ec0c4 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -import functools from typing import Any import factory @@ -21,7 +20,7 @@ from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import compiler +from gt4py.next.otf.compilation import compiler, importer from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module @@ -107,6 +106,24 @@ def extract_connectivity_args( return args +def gtfn_finalize(artifact: compiler.GTFNBuildArtifact) -> stages.ExecutableProgram: + """Turn a :class:`compiler.GTFNBuildArtifact` into a directly-callable program. + + Imports the .so as a Python extension module and wraps the entry point in + gt4py's calling convention (argument conversion, device-aware connectivity + handling, metric collection). Reads the target device from the artifact. 
+ + Must run in the process that will ultimately call the returned program; + the module is registered in that process's ``sys.modules`` under the + ``gt4py.__compiled_programs__.`` prefix. + """ + m = importer.import_from_path( + artifact.src_dir / artifact.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return convert_args(getattr(m, artifact.entry_point_name), device=artifact.device_type) + + class GTFNBuildWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFBuildWorkflow @@ -144,19 +161,7 @@ class Params: compiler.CompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), - ) - - -class GTFNFinalizeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFFinalizeWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - - load = factory.LazyFunction(lambda: compiler.load_artifact) - decoration = factory.LazyAttribute( - lambda o: functools.partial(convert_args, device=o.device_type) + device_type=factory.SelfAttribute("..device_type"), ) @@ -171,9 +176,7 @@ class Params: build = factory.SubFactory( GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) - finalize = factory.SubFactory( - GTFNFinalizeWorkflowFactory, device_type=factory.SelfAttribute("..device_type") - ) + finalize = factory.LazyFunction(lambda: gtfn_finalize) class GTFNBackendFactory(factory.Factory): diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index f088761ffd..cd3bddb19a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -37,8 +37,9 @@ def test_backend_factory_trait_device(): assert cpu_version.executor.build.translation.device_type is 
core_defs.DeviceType.CPU assert gpu_version.executor.build.translation.device_type is core_defs.DeviceType.CUDA - assert cpu_version.executor.finalize.decoration.keywords["device"] is core_defs.DeviceType.CPU - assert gpu_version.executor.finalize.decoration.keywords["device"] is core_defs.DeviceType.CUDA + # The compilation step now also carries device_type so it can stamp the artifact. + assert cpu_version.executor.build.compilation.device_type is core_defs.DeviceType.CPU + assert gpu_version.executor.build.compilation.device_type is core_defs.DeviceType.CUDA assert custom_layout_allocators.is_field_allocator_for( cpu_version.allocator, core_defs.DeviceType.CPU From b8efa44840f24ab2a7ba7d057371d8bdc926b1d6 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 15:28:14 +0200 Subject: [PATCH 03/24] BuildArtifact can materialize() itself --- src/gt4py/next/otf/compilation/compiler.py | 27 +++++- src/gt4py/next/otf/definitions.py | 17 ++-- src/gt4py/next/otf/recipes.py | 60 ++++++++------ src/gt4py/next/otf/stages.py | 35 ++++++++ .../runners/dace/workflow/compilation.py | 83 ++++++++++--------- .../runners/dace/workflow/factory.py | 5 +- .../next/program_processors/runners/gtfn.py | 24 +----- 7 files changed, 151 insertions(+), 100 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 9d77d50b07..f1b076b55c 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -17,7 +17,7 @@ from gt4py._core import definitions as core_defs, locking from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow -from gt4py.next.otf.compilation import build_data, cache +from gt4py.next.otf.compilation import build_data, cache, importer T = TypeVar("T") @@ -48,9 +48,10 @@ def __call__( class GTFNBuildArtifact: """On-disk result of a GTFN compilation: a Python extension module. 
- Bindings are baked into the .so via nanobind, so the load step is just an - ``importlib`` import + entry-point symbol lookup. ``device_type`` is - intrinsic to the artifact: a CPU-built .so cannot be loaded as GPU. + Bindings are baked into the .so via nanobind, so materialization is just + an ``importlib`` import + entry-point symbol lookup, plus a wrapping in + gt4py's calling convention. ``device_type`` is intrinsic to the artifact: + a CPU-built .so cannot be loaded as GPU. """ src_dir: pathlib.Path @@ -58,6 +59,24 @@ class GTFNBuildArtifact: entry_point_name: str device_type: core_defs.DeviceType + def materialize(self) -> stages.ExecutableProgram: + """Bring the artifact up as a directly-callable program. + + Must run in the process that will ultimately call the returned + program; the imported module is registered in that process's + ``sys.modules`` under the ``gt4py.__compiled_programs__.`` prefix. + """ + # Imported lazily to avoid a circular module dependency: ``runners.gtfn`` + # imports this module to construct the workflow, while the + # gt4py-shaped argument-conversion lives there. 
+ from gt4py.next.program_processors.runners.gtfn import convert_args + + m = importer.import_from_path( + self.src_dir / self.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return convert_args(getattr(m, self.entry_point_name), device=self.device_type) + @dataclasses.dataclass(frozen=True) class Compiler( diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index c242e02aa2..1fe56a1f11 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Protocol, TypeAlias, TypeVar +from typing import Protocol, TypeAlias, TypeVar from gt4py.next.ffront import stages as ffront_stages from gt4py.next.iterator import ir as itir @@ -56,17 +56,18 @@ def __call__( class CompilationStep( - workflow.Workflow[stages.CompilableProject[CodeSpecT, TargetCodeSpecT], Any], + workflow.Workflow[ + stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact + ], Protocol[CodeSpecT, TargetCodeSpecT], ): - """Run the build system and produce an on-disk, backend-specific build artifact. + """Run the build system and produce a :class:`stages.BuildArtifact`. - The artifact type is intentionally :class:`Any` here — each backend defines - its own concrete dataclass (frozen, picklable). The build/finalize boundary - in :class:`recipes.OTFCompileWorkflow` only requires that whatever - ``CompilationStep`` produces is what the backend's ``finalize`` consumes. + Each backend defines its own concrete artifact dataclass (frozen, + picklable, self-materializing); they all satisfy the + :class:`stages.BuildArtifact` Protocol structurally. """ def __call__( self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT] - ) -> Any: ... + ) -> stages.BuildArtifact: ... 
diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 0ded8426aa..88188aebbd 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -9,53 +9,63 @@ from __future__ import annotations import dataclasses -from typing import Any from gt4py.next.otf import definitions, stages, workflow +def materialize_artifact(artifact: stages.BuildArtifact) -> stages.ExecutableProgram: + """Default ``finalize`` step for :class:`OTFCompileWorkflow`. + + Universal across backends: dispatches into the artifact's own + :meth:`stages.BuildArtifact.materialize` method. The dispatch happens + through ordinary Python method resolution on the artifact's concrete + type — no separate registry, no backend-specific finalize plumbing. + """ + return artifact.materialize() + + @dataclasses.dataclass(frozen=True) class OTFBuildWorkflow( - workflow.NamedStepSequence[definitions.CompilableProgramDef, Any] + workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.BuildArtifact] ): - """Translation + bindings + build system; ends at an on-disk artifact. + """Translation + bindings + build system; ends at a :class:`stages.BuildArtifact`. - The artifact type is backend-specific (e.g. a ``GTFNBuildArtifact`` or a - ``DaCeBuildArtifact``); a workflow only ever pairs a backend's build with - that same backend's finalize, so no cross-backend artifact protocol is - needed. + The artifact's concrete type is backend-specific (e.g. ``GTFNBuildArtifact`` + or ``DaCeBuildArtifact``); both share only the + :class:`stages.BuildArtifact` Protocol — frozen, picklable, self- + materializing. Grouped as a sub-workflow so the ``cached=True`` backend trait can wrap just this sub-workflow in a :class:`workflow.CachedStep` — caching keys - on :class:`definitions.CompilableProgramDef` and values on a picklable, - backend-specific artifact dataclass. + on :class:`definitions.CompilableProgramDef` and values on a picklable + artifact. 
""" translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] - compilation: workflow.Workflow[stages.CompilableProject, Any] + compilation: workflow.Workflow[stages.CompilableProject, stages.BuildArtifact] @dataclasses.dataclass(frozen=True) class OTFCompileWorkflow( workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.ExecutableProgram] ): - """Full OTF pipeline: two phases separated by an on-disk artifact boundary. + """Full OTF pipeline: build an artifact, then materialize a callable. - 1. ``build`` — produces a picklable, backend-specific build artifact. - Heavy, idempotent, parallelizable across processes; the natural cache - target. + 1. ``build`` — produces a picklable :class:`stages.BuildArtifact`. Heavy, + idempotent, parallelizable; the natural cache target. 2. ``finalize`` — rehydrates the artifact into a directly-callable - :class:`stages.ExecutableProgram`. Backend-internal; whatever - sequence of "load the .so / wrap with gt4py calling convention / - attach metrics" the backend needs. - - The artifact dataclass is the contract between these two phases. By - convention, artifacts are frozen dataclasses, picklable across process - boundaries, and self-describing (carry every property finalize needs, - e.g. ``device_type``). Each backend defines its own; nothing about that - contract is enforced by this module — it is per-backend convention. + :class:`stages.ExecutableProgram`. Defaults to + :func:`materialize_artifact`, which dispatches through the artifact's + own :meth:`stages.BuildArtifact.materialize` — backend-specific code + lives on the artifact, not in a sibling free function. + + Backends typically only configure ``build``; ``finalize`` falls through + to the artifact's own materialization logic. Override ``finalize`` only + to wrap the entire post-build phase (e.g. add a tracing wrapper). 
""" - build: workflow.Workflow[definitions.CompilableProgramDef, Any] - finalize: workflow.Workflow[Any, stages.ExecutableProgram] + build: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] + finalize: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] = ( + materialize_artifact + ) diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index b6816b1cc3..8448db9334 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -129,6 +129,41 @@ def build(self) -> None: ... ExecutableProgram: TypeAlias = Callable +class BuildArtifact(Protocol): + """A picklable, self-contained, compiled gt4py program in transit. + + A *build artifact* is the output of the ``build`` phase of + :class:`recipes.OTFCompileWorkflow` — the explicit boundary between the + build phase (heavy, idempotent, parallelizable, picklable output) and the + live-callable phase (cheap, process-bound). + + Each backend defines its own concrete artifact dataclass, carrying + whatever fields it needs to bring up a runnable callable in any process + that has the backend module on the import path. Conventions: + + 1. **Frozen dataclass.** Implementations are + ``@dataclasses.dataclass(frozen=True)`` so they have value semantics + (hashable, structurally equatable) for use as cache keys. + + 2. **Picklable.** Implementations round-trip safely through :mod:`pickle` + so they can cross process boundaries: process-pool / distributed + compilation, AOT pipelines that build now and run later from a + different process, persistent caches keyed on the artifact. Live, + process-bound state (open files, ``ctypes`` handles, imported Python + modules) is therefore not allowed in the artifact — that is what + :meth:`materialize` rehydrates. + + 3. **Self-materializing.** Calling :meth:`materialize` returns a + directly-callable :class:`ExecutableProgram` taking gt4py-shaped + arguments. 
The method body is the backend's full post-build + sequence (load the .so, wrap with the calling convention, attach + metric hooks, etc.). Receivers don't need to know which backend + produced the artifact — they just call ``artifact.materialize()``. + """ + + def materialize(self) -> ExecutableProgram: ... + + def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: """ Filter out multiple occurrences of the same ``interface.LibraryDependency``. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 3215d44c10..9bbb035997 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -122,8 +122,8 @@ class DaCeBuildArtifact: Carries the ``device_type`` the artifact was built for; a CPU-built .so cannot be loaded as GPU. Also carries the bindings (Python source code - that the loader ``exec``\\ s to materialize the SDFG argument-marshalling - function). + that materialization ``exec``\\ s to bring the SDFG argument-marshalling + function into existence). """ build_folder: pathlib.Path @@ -131,6 +131,46 @@ class DaCeBuildArtifact: bind_func_name: str device_type: core_defs.DeviceType + def materialize(self) -> stages.ExecutableProgram: + """Bring the artifact up as a directly-callable program. + + Re-deserializes the SDFG dump from the build folder, links against + the pre-built .so via ``compiler.use_cache=True`` (no re-codegen), + wraps it in a :class:`CompiledDaceProgram`, and applies gt4py's + calling convention via :func:`decoration.convert_args`. + + Must run in the process that will ultimately call the returned + program; the imported binding code is bound into a per-call namespace + within the produced :class:`CompiledDaceProgram`. 
+ """ + # Imported lazily to avoid a circular module dependency: + # ``decoration`` imports this module. + from gt4py.next.program_processors.runners.dace.workflow import ( + decoration as gtx_wfddecoration, + ) + + for dump_name in ("program.sdfgz", "program.sdfg"): + sdfg_dump = self.build_folder / dump_name + if sdfg_dump.exists(): + break + else: + raise RuntimeError( + f"No SDFG dump (program.sdfgz / program.sdfg) found in '{self.build_folder}'." + ) + + sdfg = dace.SDFG.from_file(str(sdfg_dump)) + sdfg.build_folder = str(self.build_folder) + + with gtx_wfdcommon.dace_context(device_type=self.device_type): + # use_cache=True forces DaCe to load the existing .so without re-codegen. + with dace.config.set_temporary("compiler", "use_cache", value=True): + sdfg_program = sdfg.compile(validate=False) + + program = CompiledDaceProgram( + sdfg_program, self.bind_func_name, self.binding_source_code + ) + return gtx_wfddecoration.convert_args(program, device=self.device_type) + @dataclasses.dataclass(frozen=True) class DaCeCompiler( @@ -176,45 +216,6 @@ def __call__( ) -def dace_finalize(artifact: DaCeBuildArtifact) -> stages.ExecutableProgram: - """Turn a :class:`DaCeBuildArtifact` into a directly-callable program. - - Re-deserializes the SDFG dump from the build folder, links against the - pre-built .so via ``compiler.use_cache=True`` (no re-codegen), wraps it - in a :class:`CompiledDaceProgram`, and applies gt4py's calling convention - via :func:`decoration.convert_args`. Reads the target device from the - artifact. - - Must run in the process that will ultimately call the returned program. - """ - # Local import to avoid a circular reference (decoration imports compilation). 
- from gt4py.next.program_processors.runners.dace.workflow import ( - decoration as gtx_wfddecoration, - ) - - for dump_name in ("program.sdfgz", "program.sdfg"): - sdfg_dump = artifact.build_folder / dump_name - if sdfg_dump.exists(): - break - else: - raise RuntimeError( - f"No SDFG dump (program.sdfgz / program.sdfg) found in '{artifact.build_folder}'." - ) - - sdfg = dace.SDFG.from_file(str(sdfg_dump)) - sdfg.build_folder = str(artifact.build_folder) - - with gtx_wfdcommon.dace_context(device_type=artifact.device_type): - # use_cache=True forces DaCe to load the existing .so without re-codegen. - with dace.config.set_temporary("compiler", "use_cache", value=True): - sdfg_program = sdfg.compile(validate=False) - - program = CompiledDaceProgram( - sdfg_program, artifact.bind_func_name, artifact.binding_source_code - ) - return gtx_wfddecoration.convert_args(program, device=artifact.device_type) - - class DaCeCompilationStepFactory(factory.Factory): class Meta: model = DaCeCompiler diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 5855ef5cc4..4eeced2341 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -18,7 +18,6 @@ from gt4py.next.otf import recipes, stages, workflow from gt4py.next.program_processors.runners.dace.workflow import ( bindings as bindings_step, - compilation as compilation_step, ) from gt4py.next.program_processors.runners.dace.workflow.compilation import ( DaCeCompilationStepFactory, @@ -92,4 +91,6 @@ class Params: auto_optimize=factory.SelfAttribute("..auto_optimize"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) - finalize = factory.LazyFunction(lambda: compilation_step.dace_finalize) + # ``finalize`` is left at its OTFCompileWorkflow default + # (``stages.materialize_artifact``), which dispatches via the artifact's + # 
own :meth:`stages.BuildArtifact.materialize` method. diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 01610ec0c4..80c918df34 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -20,7 +20,7 @@ from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import compiler, importer +from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module @@ -106,24 +106,6 @@ def extract_connectivity_args( return args -def gtfn_finalize(artifact: compiler.GTFNBuildArtifact) -> stages.ExecutableProgram: - """Turn a :class:`compiler.GTFNBuildArtifact` into a directly-callable program. - - Imports the .so as a Python extension module and wraps the entry point in - gt4py's calling convention (argument conversion, device-aware connectivity - handling, metric collection). Reads the target device from the artifact. - - Must run in the process that will ultimately call the returned program; - the module is registered in that process's ``sys.modules`` under the - ``gt4py.__compiled_programs__.`` prefix. 
- """ - m = importer.import_from_path( - artifact.src_dir / artifact.module, - sys_modules_prefix="gt4py.__compiled_programs__.", - ) - return convert_args(getattr(m, artifact.entry_point_name), device=artifact.device_type) - - class GTFNBuildWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFBuildWorkflow @@ -176,7 +158,9 @@ class Params: build = factory.SubFactory( GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) - finalize = factory.LazyFunction(lambda: gtfn_finalize) + # ``finalize`` is left at its OTFCompileWorkflow default + # (``stages.materialize_artifact``), which dispatches via the artifact's + # own :meth:`stages.BuildArtifact.materialize` method. class GTFNBackendFactory(factory.Factory): From 915ce278a2962d48c889724fbbcf4135cabbbf9e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 16:36:14 +0200 Subject: [PATCH 04/24] remove OTFCompileWorkflow --- src/gt4py/next/backend.py | 5 +- src/gt4py/next/otf/recipes.py | 49 +++---------------- src/gt4py/next/otf/stages.py | 8 +-- .../program_processors/formatters/gtfn.py | 2 +- .../runners/dace/program.py | 16 +++--- .../runners/dace/workflow/__init__.py | 2 +- .../runners/dace/workflow/backend.py | 22 +++------ .../runners/dace/workflow/factory.py | 25 +--------- .../next/program_processors/runners/gtfn.py | 34 +++---------- .../test_temporaries_with_sizes.py | 14 +++--- .../iterator_tests/test_builtins.py | 6 +-- .../otf_tests/test_compiled_program.py | 13 +++-- .../gtfn_tests/test_gtfn_module.py | 4 +- .../runners_tests/test_gtfn.py | 18 +++---- 14 files changed, 69 insertions(+), 149 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 6123da97e0..b7ad2b2d2c 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -147,16 +147,17 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]: @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): name: str - 
executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram] + executor: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] transforms: workflow.Workflow[definitions.ConcreteProgramDef, definitions.CompilableProgramDef] def compile( self, program: definitions.IRDefinitionT, compile_time_args: arguments.CompileTimeArgs ) -> stages.ExecutableProgram: - return self.executor( + artifact = self.executor( self.transforms(definitions.ConcreteProgramDef(data=program, args=compile_time_args)) ) + return artifact.materialize() @property def __gt_allocator__( diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 88188aebbd..c1af3388ae 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -13,17 +13,6 @@ from gt4py.next.otf import definitions, stages, workflow -def materialize_artifact(artifact: stages.BuildArtifact) -> stages.ExecutableProgram: - """Default ``finalize`` step for :class:`OTFCompileWorkflow`. - - Universal across backends: dispatches into the artifact's own - :meth:`stages.BuildArtifact.materialize` method. The dispatch happens - through ordinary Python method resolution on the artifact's concrete - type — no separate registry, no backend-specific finalize plumbing. - """ - return artifact.materialize() - - @dataclasses.dataclass(frozen=True) class OTFBuildWorkflow( workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.BuildArtifact] @@ -33,39 +22,17 @@ class OTFBuildWorkflow( The artifact's concrete type is backend-specific (e.g. ``GTFNBuildArtifact`` or ``DaCeBuildArtifact``); both share only the :class:`stages.BuildArtifact` Protocol — frozen, picklable, self- - materializing. 
- - Grouped as a sub-workflow so the ``cached=True`` backend trait can wrap - just this sub-workflow in a :class:`workflow.CachedStep` — caching keys - on :class:`definitions.CompilableProgramDef` and values on a picklable + materializing. The whole post-build phase lives on the artifact itself + (``artifact.materialize()`` returns the directly-callable program); this + workflow's job is just to produce the artifact. + + Used directly as :attr:`gt4py.next.backend.Backend.executor`. The + ``cached=True`` backend trait wraps this whole workflow in a + :class:`workflow.CachedStep` — caching keys on + :class:`definitions.CompilableProgramDef` and values on a picklable artifact. """ translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] compilation: workflow.Workflow[stages.CompilableProject, stages.BuildArtifact] - - -@dataclasses.dataclass(frozen=True) -class OTFCompileWorkflow( - workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.ExecutableProgram] -): - """Full OTF pipeline: build an artifact, then materialize a callable. - - 1. ``build`` — produces a picklable :class:`stages.BuildArtifact`. Heavy, - idempotent, parallelizable; the natural cache target. - 2. ``finalize`` — rehydrates the artifact into a directly-callable - :class:`stages.ExecutableProgram`. Defaults to - :func:`materialize_artifact`, which dispatches through the artifact's - own :meth:`stages.BuildArtifact.materialize` — backend-specific code - lives on the artifact, not in a sibling free function. - - Backends typically only configure ``build``; ``finalize`` falls through - to the artifact's own materialization logic. Override ``finalize`` only - to wrap the entire post-build phase (e.g. add a tracing wrapper). 
- """ - - build: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] - finalize: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] = ( - materialize_artifact - ) diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 8448db9334..35cfe0e425 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -132,10 +132,10 @@ def build(self) -> None: ... class BuildArtifact(Protocol): """A picklable, self-contained, compiled gt4py program in transit. - A *build artifact* is the output of the ``build`` phase of - :class:`recipes.OTFCompileWorkflow` — the explicit boundary between the - build phase (heavy, idempotent, parallelizable, picklable output) and the - live-callable phase (cheap, process-bound). + A *build artifact* is the output of an :class:`recipes.OTFBuildWorkflow` + (the value of :attr:`gt4py.next.backend.Backend.executor`) — the explicit + boundary between the build phase (heavy, idempotent, parallelizable, + picklable output) and the live-callable phase (cheap, process-bound). Each backend defines its own concrete artifact dataclass, carrying whatever fields it needs to bring up a runnable callable in any process diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index c20f7a8555..1d65b8d8d0 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -17,7 +17,7 @@ @program_formatter.program_formatter def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str: # TODO(tehrengruber): This is a little ugly. Revisit. 
- gtfn_translation = gtfn.GTFNBackendFactory().executor.build.translation # type: ignore[attr-defined] + gtfn_translation = gtfn.GTFNBackendFactory().executor.translation # type: ignore[attr-defined] assert isinstance(gtfn_translation, GTFNTranslationStep) return gtfn_translation.generate_stencil_source( program, diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index c13daa249f..310b9634ba 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -76,17 +76,19 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: gt4py_program_args=[p.type for p in program.params], ) - compile_workflow = typing.cast(recipes.OTFCompileWorkflow, self.backend.executor) - build_workflow = ( - compile_workflow.build.step - if hasattr(compile_workflow.build, "step") - else compile_workflow.build - ) # the `build` phase may be wrapped in a `CachedStep` depending on backend configuration. + # ``backend.executor`` is an :class:`recipes.OTFBuildWorkflow`, optionally wrapped + # in a :class:`workflow.CachedStep` when ``cached=True``. + build_workflow = typing.cast( + recipes.OTFBuildWorkflow, + self.backend.executor.step + if hasattr(self.backend.executor, "step") + else self.backend.executor, + ) compile_workflow_translation = ( build_workflow.translation.step if hasattr(build_workflow.translation, "step") else build_workflow.translation - ) # Same for the translation stage, which could be a `CachedStep` depending on backend configuration. + ) # the translation stage could also be a `CachedStep` depending on backend configuration. # TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with # the other parts of the workaround when possible. 
diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py index 4d825c0c9b..f822709cd2 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py @@ -10,7 +10,7 @@ The main module is `backend`, that exports the backends for CPU and GPU devices. The `backend` module uses `factory` to define a workflow that implements the -`OTFCompileWorkflow` recipe. The different stages are implemeted in separate modules: +`OTFBuildWorkflow` recipe. The different stages are implemeted in separate modules: - `translation` for lowering of GTIR to SDFG and applying SDFG transformations - `compilation` for compiling the SDFG into a program - `decoration` to parse the program arguments and pass them to the program call diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index 935655a422..de6778a750 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -8,7 +8,6 @@ from __future__ import annotations -import dataclasses import warnings from typing import Any, Final @@ -45,12 +44,7 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: dataclasses.replace( - o.otf_workflow, - build=workflow.CachedStep( - o.otf_workflow.build, hash_function=o.hash_function - ), - ) + lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) ), name_cached="_cached", ) @@ -133,13 +127,13 @@ def make_dace_backend( gpu=gpu, cached=cached, auto_optimize=auto_optimize, - otf_workflow__build__cached_translation=cached, - otf_workflow__build__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False), - 
otf_workflow__build__bare_translation__auto_optimize_args=optimization_args, - otf_workflow__build__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, - otf_workflow__build__bare_translation__use_metrics=use_metrics, - otf_workflow__build__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin, - otf_workflow__build__bare_translation__use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, + otf_workflow__cached_translation=cached, + otf_workflow__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False), + otf_workflow__bare_translation__auto_optimize_args=optimization_args, + otf_workflow__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, + otf_workflow__bare_translation__use_metrics=use_metrics, + otf_workflow__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin, + otf_workflow__bare_translation__use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 4eeced2341..f022cdea64 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -30,7 +30,7 @@ _GT_DACE_BINDING_FUNCTION_NAME: Final[str] = "update_sdfg_args" -class DaCeBuildWorkflowFactory(factory.Factory): +class DaCeWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFBuildWorkflow @@ -71,26 +71,3 @@ class Params: device_type=factory.SelfAttribute("..device_type"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) - - -class DaCeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow - - class Params: - auto_optimize: bool = False - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - 
cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough - lambda: config.CMAKE_BUILD_TYPE - ) - cached_translation = factory.Trait(build__cached_translation=True) - - build = factory.SubFactory( - DaCeBuildWorkflowFactory, - device_type=factory.SelfAttribute("..device_type"), - auto_optimize=factory.SelfAttribute("..auto_optimize"), - cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) - # ``finalize`` is left at its OTFCompileWorkflow default - # (``stages.materialize_artifact``), which dispatches via the artifact's - # own :meth:`stages.BuildArtifact.materialize` method. diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 80c918df34..f865384271 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import dataclasses from typing import Any import factory @@ -147,22 +146,6 @@ class Params: ) -class GTFNCompileWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cached_translation = factory.Trait(build__cached_translation=True) - - build = factory.SubFactory( - GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") - ) - # ``finalize`` is left at its OTFCompileWorkflow default - # (``stages.materialize_artifact``), which dispatches via the artifact's - # own :meth:`stages.BuildArtifact.materialize` method. 
- - class GTFNBackendFactory(factory.Factory): class Meta: model = backend.Backend @@ -179,19 +162,14 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: dataclasses.replace( - o.otf_workflow, - build=workflow.CachedStep( - o.otf_workflow.build, hash_function=o.hash_function - ), - ) + lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) ), name_cached="_cached", ) device_type = core_defs.DeviceType.CPU hash_function = stages.compilation_hash otf_workflow = factory.SubFactory( - GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") + GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) name = factory.LazyAttribute( @@ -207,17 +185,17 @@ class Params: run_gtfn_imperative = GTFNBackendFactory( name_postfix="_imperative", - otf_workflow__build__translation__use_imperative_backend=True, + otf_workflow__translation__use_imperative_backend=True, ) -run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__build__cached_translation=True) +run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) run_gtfn_gpu = GTFNBackendFactory(gpu=True) run_gtfn_gpu_cached = GTFNBackendFactory( - gpu=True, cached=True, otf_workflow__build__cached_translation=True + gpu=True, cached=True, otf_workflow__cached_translation=True ) run_gtfn_no_transforms = GTFNBackendFactory( - otf_workflow__build__bare_translation__enable_itir_transforms=False + otf_workflow__bare_translation__enable_itir_transforms=False ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index c3058f33a8..90c0d775f2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -34,14 
+34,12 @@ def exec_alloc_descriptor(): name="run_gtfn_with_temporaries_and_sizes", transforms=backend.DEFAULT_TRANSFORMS, executor=run_gtfn.executor.replace( - build=run_gtfn.executor.build.replace( - translation=run_gtfn.executor.build.translation.replace( - symbolic_domain_sizes={ - "Cell": "num_cells", - "Edge": "num_edges", - "Vertex": "num_vertices", - } - ) + translation=run_gtfn.executor.translation.replace( + symbolic_domain_sizes={ + "Cell": "num_cells", + "Edge": "num_edges", + "Vertex": "num_vertices", + } ) ), allocator=run_gtfn.allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 9d39b5d63a..4fff5192aa 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -192,11 +192,7 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): gtfn_without_transforms = dataclasses.replace( run_gtfn, executor=run_gtfn.executor.replace( - build=run_gtfn.executor.build.replace( - translation=run_gtfn.executor.build.translation.replace( - enable_itir_transforms=False - ), - ), + translation=run_gtfn.executor.translation.replace(enable_itir_transforms=False), ), # avoid inlining the function ) diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index def8800c98..233e5a2f6e 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -114,12 +114,19 @@ def test_inlining_of_scalar_works_integration(testee_prog): hijacked_program = None + @dataclasses.dataclass(frozen=True) + class _NoOpArtifact: + """A trivial BuildArtifact that materializes to a no-op callable.""" + + def materialize(self): + return lambda *args, 
**kwargs: None + def pirate(program: toolchain.ConcreteArtifact): - # Replaces the gtfn otf_workflow: and steals the compilable program, - # then returns a dummy "CompiledProgram" that does nothing. + # Replaces the gtfn otf_workflow: steals the compilable program, then + # returns a dummy artifact whose materialization is a no-op callable. nonlocal hijacked_program hijacked_program = program - return lambda *args, **kwargs: None + return _NoOpArtifact() hacked_gtfn_backend = gtfn.GTFNBackendFactory(name_postfix="_custom", executor=pirate) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 3611abad89..d027c9dcb1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -135,11 +135,11 @@ def test_gtfn_file_cache(program_example): ) cached_gtfn_translation_step = gtfn.GTFNBackendFactory( gpu=False, cached=True, otf_workflow__cached_translation=True - ).executor.build.step.translation + ).executor.step.translation bare_gtfn_translation_step = gtfn.GTFNBackendFactory( gpu=False, cached=True, otf_workflow__cached_translation=False - ).executor.build.step.translation + ).executor.step.translation cache_key = stages.fingerprint_compilable_program(compilable_program) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index cd3bddb19a..ab4697ed73 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -34,12 +34,12 @@ def test_backend_factory_trait_device(): assert cpu_version.name == "run_gtfn_cpu" assert 
gpu_version.name == "run_gtfn_gpu" - assert cpu_version.executor.build.translation.device_type is core_defs.DeviceType.CPU - assert gpu_version.executor.build.translation.device_type is core_defs.DeviceType.CUDA + assert cpu_version.executor.translation.device_type is core_defs.DeviceType.CPU + assert gpu_version.executor.translation.device_type is core_defs.DeviceType.CUDA # The compilation step now also carries device_type so it can stamp the artifact. - assert cpu_version.executor.build.compilation.device_type is core_defs.DeviceType.CPU - assert gpu_version.executor.build.compilation.device_type is core_defs.DeviceType.CUDA + assert cpu_version.executor.compilation.device_type is core_defs.DeviceType.CPU + assert gpu_version.executor.compilation.device_type is core_defs.DeviceType.CUDA assert custom_layout_allocators.is_field_allocator_for( cpu_version.allocator, core_defs.DeviceType.CPU @@ -51,7 +51,7 @@ def test_backend_factory_trait_device(): def test_backend_factory_trait_cached(): cached_version = gtfn.GTFNBackendFactory(gpu=False, cached=True) - assert isinstance(cached_version.executor.build, workflow.CachedStep) + assert isinstance(cached_version.executor, workflow.CachedStep) assert cached_version.name == "run_gtfn_cpu_cached" @@ -62,11 +62,11 @@ def test_backend_factory_build_cache_config(monkeypatch): persistent_version = gtfn.GTFNBackendFactory() assert ( - session_version.executor.build.compilation.cache_lifetime + session_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.SESSION ) assert ( - persistent_version.executor.build.compilation.cache_lifetime + persistent_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.PERSISTENT ) @@ -78,10 +78,10 @@ def test_backend_factory_build_type_config(monkeypatch): min_size_version = gtfn.GTFNBackendFactory() assert ( - release_version.executor.build.compilation.builder_factory.cmake_build_type + release_version.executor.compilation.builder_factory.cmake_build_type 
is config.CMakeBuildType.RELEASE ) assert ( - min_size_version.executor.build.compilation.builder_factory.cmake_build_type + min_size_version.executor.compilation.builder_factory.cmake_build_type is config.CMakeBuildType.MIN_SIZE_REL ) From c51ac5592cb182b4537dfcc5e840fb416a716105 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 16:47:24 +0200 Subject: [PATCH 05/24] cleanup docstrings --- src/gt4py/next/otf/compilation/compiler.py | 19 +++++----- src/gt4py/next/otf/recipes.py | 15 ++------ src/gt4py/next/otf/stages.py | 36 ++++--------------- .../runners/dace/program.py | 6 ++-- .../runners/dace/workflow/compilation.py | 23 +++--------- 5 files changed, 25 insertions(+), 74 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index f1b076b55c..66bd74556f 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -48,10 +48,9 @@ def __call__( class GTFNBuildArtifact: """On-disk result of a GTFN compilation: a Python extension module. - Bindings are baked into the .so via nanobind, so materialization is just - an ``importlib`` import + entry-point symbol lookup, plus a wrapping in - gt4py's calling convention. ``device_type`` is intrinsic to the artifact: - a CPU-built .so cannot be loaded as GPU. + Bindings are baked into the .so via nanobind, so :meth:`materialize` is + just an ``importlib`` import + entry-point symbol lookup, plus a wrap in + gt4py's calling convention. """ src_dir: pathlib.Path @@ -60,15 +59,13 @@ class GTFNBuildArtifact: device_type: core_defs.DeviceType def materialize(self) -> stages.ExecutableProgram: - """Bring the artifact up as a directly-callable program. + """Import the module and wrap its entry point in gt4py's calling convention. 
- Must run in the process that will ultimately call the returned - program; the imported module is registered in that process's - ``sys.modules`` under the ``gt4py.__compiled_programs__.`` prefix. + Must run in the process that will call the returned program: the + module is registered in that process's ``sys.modules`` under the + ``gt4py.__compiled_programs__.`` prefix. """ - # Imported lazily to avoid a circular module dependency: ``runners.gtfn`` - # imports this module to construct the workflow, while the - # gt4py-shaped argument-conversion lives there. + # Lazy import: ``runners.gtfn`` imports this module to construct the workflow. from gt4py.next.program_processors.runners.gtfn import convert_args m = importer.import_from_path( diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index c1af3388ae..f784a20a12 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -19,18 +19,9 @@ class OTFBuildWorkflow( ): """Translation + bindings + build system; ends at a :class:`stages.BuildArtifact`. - The artifact's concrete type is backend-specific (e.g. ``GTFNBuildArtifact`` - or ``DaCeBuildArtifact``); both share only the - :class:`stages.BuildArtifact` Protocol — frozen, picklable, self- - materializing. The whole post-build phase lives on the artifact itself - (``artifact.materialize()`` returns the directly-callable program); this - workflow's job is just to produce the artifact. - - Used directly as :attr:`gt4py.next.backend.Backend.executor`. The - ``cached=True`` backend trait wraps this whole workflow in a - :class:`workflow.CachedStep` — caching keys on - :class:`definitions.CompilableProgramDef` and values on a picklable - artifact. + Used as :attr:`gt4py.next.backend.Backend.executor`. The ``cached=True`` + backend trait wraps it in a :class:`workflow.CachedStep` keyed on + :class:`definitions.CompilableProgramDef`. 
""" translation: definitions.TranslationStep diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 35cfe0e425..4a735a76aa 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -130,35 +130,13 @@ def build(self) -> None: ... class BuildArtifact(Protocol): - """A picklable, self-contained, compiled gt4py program in transit. - - A *build artifact* is the output of an :class:`recipes.OTFBuildWorkflow` - (the value of :attr:`gt4py.next.backend.Backend.executor`) — the explicit - boundary between the build phase (heavy, idempotent, parallelizable, - picklable output) and the live-callable phase (cheap, process-bound). - - Each backend defines its own concrete artifact dataclass, carrying - whatever fields it needs to bring up a runnable callable in any process - that has the backend module on the import path. Conventions: - - 1. **Frozen dataclass.** Implementations are - ``@dataclasses.dataclass(frozen=True)`` so they have value semantics - (hashable, structurally equatable) for use as cache keys. - - 2. **Picklable.** Implementations round-trip safely through :mod:`pickle` - so they can cross process boundaries: process-pool / distributed - compilation, AOT pipelines that build now and run later from a - different process, persistent caches keyed on the artifact. Live, - process-bound state (open files, ``ctypes`` handles, imported Python - modules) is therefore not allowed in the artifact — that is what - :meth:`materialize` rehydrates. - - 3. **Self-materializing.** Calling :meth:`materialize` returns a - directly-callable :class:`ExecutableProgram` taking gt4py-shaped - arguments. The method body is the backend's full post-build - sequence (load the .so, wrap with the calling convention, attach - metric hooks, etc.). Receivers don't need to know which backend - produced the artifact — they just call ``artifact.materialize()``. + """The output of an :class:`recipes.OTFBuildWorkflow`. 
+ + Each backend defines its own concrete artifact dataclass; all share this + Protocol. Implementations are frozen dataclasses, picklable, and have no + live process-bound state — that is reconstructed by :meth:`materialize`, + which returns a directly-callable :class:`ExecutableProgram` taking + gt4py-shaped arguments. """ def materialize(self) -> ExecutableProgram: ... diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index 310b9634ba..1435080f52 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -76,8 +76,8 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: gt4py_program_args=[p.type for p in program.params], ) - # ``backend.executor`` is an :class:`recipes.OTFBuildWorkflow`, optionally wrapped - # in a :class:`workflow.CachedStep` when ``cached=True``. + # The executor and the translation stage may each be wrapped in a `CachedStep` + # depending on backend configuration; unwrap when so. build_workflow = typing.cast( recipes.OTFBuildWorkflow, self.backend.executor.step @@ -88,7 +88,7 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: build_workflow.translation.step if hasattr(build_workflow.translation, "step") else build_workflow.translation - ) # the translation stage could also be a `CachedStep` depending on backend configuration. + ) # TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with # the other parts of the workaround when possible. 
diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 9bbb035997..d79de03233 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -118,13 +118,7 @@ def __call__(self, **kwargs: Any) -> None: @dataclasses.dataclass(frozen=True) class DaCeBuildArtifact: - """On-disk result of a DaCe compilation. - - Carries the ``device_type`` the artifact was built for; a CPU-built .so - cannot be loaded as GPU. Also carries the bindings (Python source code - that materialization ``exec``\\ s to bring the SDFG argument-marshalling - function into existence). - """ + """On-disk result of a DaCe compilation: a build folder + the SDFG bindings.""" build_folder: pathlib.Path binding_source_code: str @@ -132,19 +126,11 @@ class DaCeBuildArtifact: device_type: core_defs.DeviceType def materialize(self) -> stages.ExecutableProgram: - """Bring the artifact up as a directly-callable program. - - Re-deserializes the SDFG dump from the build folder, links against - the pre-built .so via ``compiler.use_cache=True`` (no re-codegen), - wraps it in a :class:`CompiledDaceProgram`, and applies gt4py's - calling convention via :func:`decoration.convert_args`. + """Re-deserialize the SDFG, link the .so, and wrap in gt4py's calling convention. - Must run in the process that will ultimately call the returned - program; the imported binding code is bound into a per-call namespace - within the produced :class:`CompiledDaceProgram`. + Must run in the process that will call the returned program. """ - # Imported lazily to avoid a circular module dependency: - # ``decoration`` imports this module. + # Lazy import: ``decoration`` imports this module. 
from gt4py.next.program_processors.runners.dace.workflow import ( decoration as gtx_wfddecoration, ) @@ -162,7 +148,6 @@ def materialize(self) -> stages.ExecutableProgram: sdfg.build_folder = str(self.build_folder) with gtx_wfdcommon.dace_context(device_type=self.device_type): - # use_cache=True forces DaCe to load the existing .so without re-codegen. with dace.config.set_temporary("compiler", "use_cache", value=True): sdfg_program = sdfg.compile(validate=False) From 9b842e2ab79ed8f053fb9dbe60f578e1fc1b5739 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 17:09:58 +0200 Subject: [PATCH 06/24] refactor: remove lazy import --- src/gt4py/next/otf/compilation/compiler.py | 8 +- .../runners/dace/workflow/compilation.py | 110 ++---------------- .../runners/dace/workflow/compiled_program.py | 110 ++++++++++++++++++ .../runners/dace/workflow/decoration.py | 8 +- .../next/program_processors/runners/gtfn.py | 88 +------------- .../runners/gtfn_decoration.py | 105 +++++++++++++++++ .../runners_tests/dace_tests/test_dace.py | 8 +- 7 files changed, 236 insertions(+), 201 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py create mode 100644 src/gt4py/next/program_processors/runners/gtfn_decoration.py diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 66bd74556f..a858a0b02d 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -18,6 +18,7 @@ from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import build_data, cache, importer +from gt4py.next.program_processors.runners import gtfn_decoration T = TypeVar("T") @@ -65,14 +66,13 @@ def materialize(self) -> stages.ExecutableProgram: module is registered in that process's ``sys.modules`` under the ``gt4py.__compiled_programs__.`` prefix. 
""" - # Lazy import: ``runners.gtfn`` imports this module to construct the workflow. - from gt4py.next.program_processors.runners.gtfn import convert_args - m = importer.import_from_path( self.src_dir / self.module, sys_modules_prefix="gt4py.__compiled_programs__.", ) - return convert_args(getattr(m, self.entry_point_name), device=self.device_type) + return gtfn_decoration.convert_args( + getattr(m, self.entry_point_name), device=self.device_type + ) @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index d79de03233..a7c628daea 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -9,111 +9,22 @@ from __future__ import annotations import dataclasses -import os import pathlib -import warnings -from collections.abc import Callable, MutableSequence, Sequence -from typing import Any import dace import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import common, config +from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache -from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon - - -class CompiledDaceProgram: - sdfg_program: dace.CompiledSDFG - - # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; - # scalar arguments that are not used in the SDFG will not be present. - sdfg_argtypes: list[dace.dtypes.Data] - - # The compiled program contains a callable object to update the SDFG arguments list. 
- update_sdfg_ctype_arglist: Callable[ - [ - core_defs.DeviceType, - Sequence[dace.dtypes.Data], - Sequence[Any], - MutableSequence[Any], - common.OffsetProvider, - ], - None, - ] - - # Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None` - # means that it has not been initialized, i.e. no call was ever performed. - # - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated. - # - csdfg_init_argv: Arguments used for initialization; used only the first time and - # never updated. - csdfg_argv: MutableSequence[Any] | None - csdfg_init_argv: Sequence[Any] | None - - def __init__( - self, - program: dace.CompiledSDFG, - bind_func_name: str, - binding_source_code: str, - ): - self.sdfg_program = program - - # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument - # name to its data type, in the same order as arguments appear in the program ABI. - # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. - self.sdfg_argtypes = list(program.sdfg.arglist().values()) - - # The binding source code is Python tailored to this specific SDFG. - # We dynamically compile that function and add it to the compiled program. - global_namespace: dict[str, Any] = {} - exec(binding_source_code, global_namespace) - self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] - # For debug purpose, we set a unique module name on the compiled function. - self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) - - # Since the SDFG hasn't been called yet. - self.csdfg_argv = None - self.csdfg_init_argv = None - - def construct_arguments(self, **kwargs: Any) -> None: - """ - This function will process the arguments and store the processed argument - vectors in `self.csdfg_args`, to call them use `self.fast_call()`. 
- """ - with dace.config.set_temporary("compiler", "allow_view_arguments", value=True): - csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs) - # Note we only care about `csdfg_argv` (normal call), since we have to update it, - # we ensure that it is a `list`. - self.csdfg_argv = [*csdfg_argv] - self.csdfg_init_argv = csdfg_init_argv - - def fast_call(self) -> None: - """ - Perform a call to the compiled SDFG using the previously generated argument - vectors, see `self.construct_arguments()`. - """ - assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, ( - "Argument vector was not set properly." - ) - self.sdfg_program.fast_call( - self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.DEBUG - ) - - def __call__(self, **kwargs: Any) -> None: - """Call the compiled SDFG with the given arguments. - - Note that this function will not update the argument vectors stored inside - `self`. Furthermore, it is not recommended to use this function as it is - very slow. - """ - warnings.warn( - "Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.", - stacklevel=1, - ) - result = self.sdfg_program(**kwargs) - assert result is None +from gt4py.next.program_processors.runners.dace.workflow import ( + common as gtx_wfdcommon, + decoration as gtx_wfddecoration, +) +from gt4py.next.program_processors.runners.dace.workflow.compiled_program import ( + CompiledDaceProgram, +) @dataclasses.dataclass(frozen=True) @@ -130,11 +41,6 @@ def materialize(self) -> stages.ExecutableProgram: Must run in the process that will call the returned program. """ - # Lazy import: ``decoration`` imports this module. 
- from gt4py.next.program_processors.runners.dace.workflow import ( - decoration as gtx_wfddecoration, - ) - for dump_name in ("program.sdfgz", "program.sdfg"): sdfg_dump = self.build_folder / dump_name if sdfg_dump.exists(): diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py b/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py new file mode 100644 index 0000000000..5e28853902 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py @@ -0,0 +1,110 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import os +import warnings +from collections.abc import Callable, MutableSequence, Sequence +from typing import Any + +import dace + +from gt4py._core import definitions as core_defs +from gt4py.next import common, config + + +class CompiledDaceProgram: + sdfg_program: dace.CompiledSDFG + + # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; + # scalar arguments that are not used in the SDFG will not be present. + sdfg_argtypes: list[dace.dtypes.Data] + + # The compiled program contains a callable object to update the SDFG arguments list. + update_sdfg_ctype_arglist: Callable[ + [ + core_defs.DeviceType, + Sequence[dace.dtypes.Data], + Sequence[Any], + MutableSequence[Any], + common.OffsetProvider, + ], + None, + ] + + # Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None` + # means that it has not been initialized, i.e. no call was ever performed. + # - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated. + # - csdfg_init_argv: Arguments used for initialization; used only the first time and + # never updated. 
+ csdfg_argv: MutableSequence[Any] | None + csdfg_init_argv: Sequence[Any] | None + + def __init__( + self, + program: dace.CompiledSDFG, + bind_func_name: str, + binding_source_code: str, + ): + self.sdfg_program = program + + # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument + # name to its data type, in the same order as arguments appear in the program ABI. + # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. + self.sdfg_argtypes = list(program.sdfg.arglist().values()) + + # The binding source code is Python tailored to this specific SDFG. + # We dynamically compile that function and add it to the compiled program. + global_namespace: dict[str, Any] = {} + exec(binding_source_code, global_namespace) + self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] + # For debug purpose, we set a unique module name on the compiled function. + self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) + + # Since the SDFG hasn't been called yet. + self.csdfg_argv = None + self.csdfg_init_argv = None + + def construct_arguments(self, **kwargs: Any) -> None: + """ + This function will process the arguments and store the processed argument + vectors in `self.csdfg_args`, to call them use `self.fast_call()`. + """ + with dace.config.set_temporary("compiler", "allow_view_arguments", value=True): + csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs) + # Note we only care about `csdfg_argv` (normal call), since we have to update it, + # we ensure that it is a `list`. + self.csdfg_argv = [*csdfg_argv] + self.csdfg_init_argv = csdfg_init_argv + + def fast_call(self) -> None: + """ + Perform a call to the compiled SDFG using the previously generated argument + vectors, see `self.construct_arguments()`. + """ + assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, ( + "Argument vector was not set properly." 
+ ) + self.sdfg_program.fast_call( + self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.DEBUG + ) + + def __call__(self, **kwargs: Any) -> None: + """Call the compiled SDFG with the given arguments. + + Note that this function will not update the argument vectors stored inside + `self`. Furthermore, it is not recommended to use this function as it is + very slow. + """ + warnings.warn( + "Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.", + stacklevel=1, + ) + result = self.sdfg_program(**kwargs) + assert result is None diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 103e7af33b..9681785206 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -18,14 +18,14 @@ from gt4py.next.instrumentation import metrics from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable -from gt4py.next.program_processors.runners.dace.workflow import ( - common as gtx_wfdcommon, - compilation as gtx_wfdcompilation, +from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon +from gt4py.next.program_processors.runners.dace.workflow.compiled_program import ( + CompiledDaceProgram, ) def convert_args( - fun: gtx_wfdcompilation.CompiledDaceProgram, + fun: CompiledDaceProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU, ) -> stages.ExecutableProgram: # Retieve metrics level from GT4Py environment variable. 
diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index f865384271..bddf06a1f3 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,17 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any - import factory -import numpy as np import gt4py._core.definitions as core_defs import gt4py.next.custom_layout_allocators as next_allocators from gt4py._core import filecache -from gt4py.next import backend, common, config, field_utils -from gt4py.next.embedded import nd_array_field -from gt4py.next.instrumentation import metrics +from gt4py.next import backend, config from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -24,87 +19,6 @@ from gt4py.next.program_processors.codegens.gtfn import gtfn_module -def convert_arg(arg: Any) -> Any: - # Note: this function is on the hot path and needs to have minimal overhead. - if (origin := getattr(arg, "__gt_origin__", None)) is not None: - # `Field` is the most likely case, we use `__gt_origin__` as the property is needed anyway - # and (currently) uniquely identifies a `NDArrayField` (which is the only supported `Field`) - assert isinstance(arg, nd_array_field.NdArrayField) - return arg.ndarray, origin - if isinstance(arg, tuple): - return tuple(convert_arg(a) for a in arg) - if isinstance(arg, np.bool_): - # nanobind does not support implicit conversion of `np.bool` to `bool` - return bool(arg) - # TODO(havogt): if this function still appears in profiles, - # we should avoid going through the previous isinstance checks for detecting a scalar. - # E.g. 
functools.cache on the arg type, returning a function that does the conversion - return arg - - -def convert_args( - inp: stages.ExecutableProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU -) -> stages.ExecutableProgram: - def decorated_program( - *args: Any, - offset_provider: dict[str, common.Connectivity | common.Dimension], - out: Any = None, - ) -> None: - # Note: this function is on the hot path and needs to have minimal overhead. - if out is not None: - args = (*args, out) - converted_args = (convert_arg(arg) for arg in args) - conn_args = extract_connectivity_args(offset_provider, device) - - opt_kwargs: dict[str, Any] = {} - if collect_metrics := metrics.is_level_enabled(metrics.PERFORMANCE): - # If we are collecting metrics, we need to add the `exec_info` argument - # to the `inp` call, which will be used to collect performance metrics. - exec_info: dict[str, float] = {} - opt_kwargs["exec_info"] = exec_info - - # generate implicit domain size arguments only if necessary, using `iter_size_args()` - inp( - *converted_args, - *conn_args, - **opt_kwargs, - ) - - if collect_metrics: - metrics.add_sample_to_current_source( - metrics.COMPUTE_METRIC, exec_info["run_cpp_duration"] - ) - - return decorated_program - - -def extract_connectivity_args( - offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType -) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: - # Note: this function is on the hot path and needs to have minimal overhead. - zero_origin = (0, 0) - assert all( - hasattr(conn, "ndarray") or isinstance(conn, common.Dimension) - for conn in offset_provider.values() - ) - # Note: the order here needs to agree with the order of the generated bindings. - # This is currently true only because when hashing offset provider dicts, - # the keys' order is taken into account. Any modification to the hashing - # of offset providers may break this assumption here. 
- args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [ - (ndarray, zero_origin) - for conn in offset_provider.values() - if (ndarray := getattr(conn, "ndarray", None)) is not None - ] - assert all( - common.is_neighbor_table(conn) and field_utils.verify_device_field_type(conn, device) - for conn in offset_provider.values() - if hasattr(conn, "ndarray") - ) - - return args - - class GTFNBuildWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFBuildWorkflow diff --git a/src/gt4py/next/program_processors/runners/gtfn_decoration.py b/src/gt4py/next/program_processors/runners/gtfn_decoration.py new file mode 100644 index 0000000000..1ea2b222ca --- /dev/null +++ b/src/gt4py/next/program_processors/runners/gtfn_decoration.py @@ -0,0 +1,105 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Calling-convention adapter for GTFN-compiled programs. + +Wraps a freshly-imported GTFN entry point with gt4py's user-facing +argument convention: unpacks fields, splits offset_provider into +connectivity args, attaches metric collection. +""" + +from typing import Any + +import numpy as np + +import gt4py._core.definitions as core_defs +from gt4py.next import common, field_utils +from gt4py.next.embedded import nd_array_field +from gt4py.next.instrumentation import metrics +from gt4py.next.otf import stages + + +def convert_arg(arg: Any) -> Any: + # Note: this function is on the hot path and needs to have minimal overhead. 
+ if (origin := getattr(arg, "__gt_origin__", None)) is not None: + # `Field` is the most likely case, we use `__gt_origin__` as the property is needed anyway + # and (currently) uniquely identifies a `NDArrayField` (which is the only supported `Field`) + assert isinstance(arg, nd_array_field.NdArrayField) + return arg.ndarray, origin + if isinstance(arg, tuple): + return tuple(convert_arg(a) for a in arg) + if isinstance(arg, np.bool_): + # nanobind does not support implicit conversion of `np.bool` to `bool` + return bool(arg) + # TODO(havogt): if this function still appears in profiles, + # we should avoid going through the previous isinstance checks for detecting a scalar. + # E.g. functools.cache on the arg type, returning a function that does the conversion + return arg + + +def convert_args( + inp: stages.ExecutableProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU +) -> stages.ExecutableProgram: + def decorated_program( + *args: Any, + offset_provider: dict[str, common.Connectivity | common.Dimension], + out: Any = None, + ) -> None: + # Note: this function is on the hot path and needs to have minimal overhead. + if out is not None: + args = (*args, out) + converted_args = (convert_arg(arg) for arg in args) + conn_args = extract_connectivity_args(offset_provider, device) + + opt_kwargs: dict[str, Any] = {} + if collect_metrics := metrics.is_level_enabled(metrics.PERFORMANCE): + # If we are collecting metrics, we need to add the `exec_info` argument + # to the `inp` call, which will be used to collect performance metrics. 
+ exec_info: dict[str, float] = {} + opt_kwargs["exec_info"] = exec_info + + # generate implicit domain size arguments only if necessary, using `iter_size_args()` + inp( + *converted_args, + *conn_args, + **opt_kwargs, + ) + + if collect_metrics: + metrics.add_sample_to_current_source( + metrics.COMPUTE_METRIC, exec_info["run_cpp_duration"] + ) + + return decorated_program + + +def extract_connectivity_args( + offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType +) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: + # Note: this function is on the hot path and needs to have minimal overhead. + zero_origin = (0, 0) + assert all( + hasattr(conn, "ndarray") or isinstance(conn, common.Dimension) + for conn in offset_provider.values() + ) + # Note: the order here needs to agree with the order of the generated bindings. + # This is currently true only because when hashing offset provider dicts, + # the keys' order is taken into account. Any modification to the hashing + # of offset providers may break this assumption here. 
+ args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [ + (ndarray, zero_origin) + for conn in offset_provider.values() + if (ndarray := getattr(conn, "ndarray", None)) is not None + ] + assert all( + common.is_neighbor_table(conn) and field_utils.verify_device_field_type(conn, device) + for conn in offset_provider.values() + if hasattr(conn, "ndarray") + ) + + return args diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index a204886690..06b3d428bb 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -82,7 +82,7 @@ def make_mocks(monkeypatch): # Wrap `compiled_sdfg.CompiledSDFG.fast_call` with mock object mock_fast_call = unittest.mock.MagicMock() gt4py_fast_call = ( - gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.fast_call + gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram.fast_call ) def mocked_fast_call(self): @@ -99,21 +99,21 @@ def mocked_fast_call(self): return fast_call_result monkeypatch.setattr( - gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram, + gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram, "fast_call", mocked_fast_call, ) # Wrap `compiled_sdfg.CompiledSDFG.construct_arguments` with mock object mock_construct_arguments = unittest.mock.MagicMock() - gt4py_construct_arguments = gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.construct_arguments + gt4py_construct_arguments = gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram.construct_arguments def mocked_construct_arguments(self, *args, **kwargs): mock_construct_arguments.__call__(*args, **kwargs) return 
gt4py_construct_arguments(self, *args, **kwargs) monkeypatch.setattr( - gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram, + gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram, "construct_arguments", mocked_construct_arguments, ) From c022bde47ee8e54a4d72750385bd3a9328cb489e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 17:24:52 +0200 Subject: [PATCH 07/24] apply pre-commit --- src/gt4py/next/otf/definitions.py | 4 +--- .../runners/dace/workflow/compilation.py | 8 ++------ .../runners/dace/workflow/decoration.py | 4 +--- .../program_processors/runners/dace/workflow/factory.py | 4 +--- .../program_processor_tests/runners_tests/test_gtfn.py | 5 +---- 5 files changed, 6 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index 1fe56a1f11..b5d6a0ecfa 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -56,9 +56,7 @@ def __call__( class CompilationStep( - workflow.Workflow[ - stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact - ], + workflow.Workflow[stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact], Protocol[CodeSpecT, TargetCodeSpecT], ): """Run the build system and produce a :class:`stages.BuildArtifact`. 
diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index a7c628daea..71b7934fc1 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -22,9 +22,7 @@ common as gtx_wfdcommon, decoration as gtx_wfddecoration, ) -from gt4py.next.program_processors.runners.dace.workflow.compiled_program import ( - CompiledDaceProgram, -) +from gt4py.next.program_processors.runners.dace.workflow.compiled_program import CompiledDaceProgram @dataclasses.dataclass(frozen=True) @@ -57,9 +55,7 @@ def materialize(self) -> stages.ExecutableProgram: with dace.config.set_temporary("compiler", "use_cache", value=True): sdfg_program = sdfg.compile(validate=False) - program = CompiledDaceProgram( - sdfg_program, self.bind_func_name, self.binding_source_code - ) + program = CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) return gtx_wfddecoration.convert_args(program, device=self.device_type) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 9681785206..6b828d5a97 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -19,9 +19,7 @@ from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon -from gt4py.next.program_processors.runners.dace.workflow.compiled_program import ( - CompiledDaceProgram, -) +from gt4py.next.program_processors.runners.dace.workflow.compiled_program import CompiledDaceProgram def convert_args( diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py 
b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index f022cdea64..9f6c80fd07 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -16,9 +16,7 @@ from gt4py._core import definitions as core_defs, filecache from gt4py.next import config from gt4py.next.otf import recipes, stages, workflow -from gt4py.next.program_processors.runners.dace.workflow import ( - bindings as bindings_step, -) +from gt4py.next.program_processors.runners.dace.workflow import bindings as bindings_step from gt4py.next.program_processors.runners.dace.workflow.compilation import ( DaCeCompilationStepFactory, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index ab4697ed73..712e0500f5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -61,10 +61,7 @@ def test_backend_factory_build_cache_config(monkeypatch): monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.PERSISTENT) persistent_version = gtfn.GTFNBackendFactory() - assert ( - session_version.executor.compilation.cache_lifetime - is config.BuildCacheLifetime.SESSION - ) + assert session_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.SESSION assert ( persistent_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.PERSISTENT From a107dd61b03aa3fc0aef5817318f542d8ae66bbf Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 17:31:08 +0200 Subject: [PATCH 08/24] update roundtrip --- .../program_processors/runners/roundtrip.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py 
b/src/gt4py/next/program_processors/runners/roundtrip.py index 5ee0a67f25..d97b3ab238 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -213,13 +213,28 @@ def fencil_generator( @dataclasses.dataclass(frozen=True) -class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram]): +class RoundtripArtifact: + """In-memory artifact for the roundtrip backend. + + Roundtrip generates and ``exec``\\ s a Python module per program, so its + output is a live callable rather than something on disk. Not picklable — + roundtrip is in-process only. + """ + + program: stages.ExecutableProgram + + def materialize(self) -> stages.ExecutableProgram: + return self.program + + +@dataclasses.dataclass(frozen=True) +class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, RoundtripArtifact]): debug: Optional[bool] = None use_embedded: bool = True dispatch_backend: Optional[next_backend.Backend] = None transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` - def __call__(self, inp: definitions.CompilableProgramDef) -> stages.ExecutableProgram: + def __call__(self, inp: definitions.CompilableProgramDef) -> RoundtripArtifact: debug = config.DEBUG if self.debug is None else self.debug fencil = fencil_generator( @@ -249,7 +264,7 @@ def decorated_fencil( **kwargs, ) - return decorated_fencil + return RoundtripArtifact(program=decorated_fencil) # TODO(tehrengruber): introduce factory From efecdf5b26da61897b4d5ad8146d324865f5c4e1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 18:01:01 +0200 Subject: [PATCH 09/24] separate gtfn from generic compiler --- .../otf/compilation/build_orchestrator.py | 81 +++++++++++++++++++ .../next/otf/compilation/build_system.py | 33 ++++++++ .../otf/compilation/build_systems/cmake.py | 4 +- 
.../compilation/build_systems/compiledb.py | 4 +- .../next/program_processors/runners/gtfn.py | 7 +- .../runners/gtfn_compiler.py} | 62 +++----------- .../otf_tests/test_nanobind_build.py | 34 ++++++-- 7 files changed, 163 insertions(+), 62 deletions(-) create mode 100644 src/gt4py/next/otf/compilation/build_orchestrator.py create mode 100644 src/gt4py/next/otf/compilation/build_system.py rename src/gt4py/next/{otf/compilation/compiler.py => program_processors/runners/gtfn_compiler.py} (54%) diff --git a/src/gt4py/next/otf/compilation/build_orchestrator.py b/src/gt4py/next/otf/compilation/build_orchestrator.py new file mode 100644 index 0000000000..4aea7b4a0d --- /dev/null +++ b/src/gt4py/next/otf/compilation/build_orchestrator.py @@ -0,0 +1,81 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Generic build orchestration for backends that produce a Python extension module. + +Wraps the lock + build-data tracking + builder-factory invocation + post-build +validation into a single :func:`run_build` call. Returns a :class:`BuildResult` +descriptor (paths + entry-point name) that backends wrap into their own +artifact type. +""" + +from __future__ import annotations + +import dataclasses +import pathlib +from typing import TypeVar + +from gt4py._core import locking +from gt4py.next import config +from gt4py.next.otf import code_specs, stages +from gt4py.next.otf.compilation import build_data, build_system, cache + + +CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) +TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) + + +@dataclasses.dataclass(frozen=True) +class BuildResult: + """On-disk descriptor of a successful build.""" + + src_dir: pathlib.Path + module: pathlib.Path + entry_point_name: str + + +class CompilationError(RuntimeError): ... 
+ + +def is_compiled(data: build_data.BuildData) -> bool: + return data.status >= build_data.BuildStatus.COMPILED + + +def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: + return (src_dir / data.module).exists() + + +def run_build( + inp: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], + cache_lifetime: config.BuildCacheLifetime, + builder_factory: build_system.BuildSystemProjectGenerator[CodeSpecT, TargetCodeSpecT], + force_recompile: bool = False, +) -> BuildResult: + """Drive ``builder_factory`` to produce a Python extension module on disk.""" + src_dir = cache.get_cache_folder(inp, cache_lifetime) + + # If we are compiling the same program at the same time (e.g. multiple MPI ranks), + # we need to make sure that only one of them accesses the same build directory for compilation. + with locking.lock(src_dir): + data = build_data.read_data(src_dir) + + if not data or not is_compiled(data) or force_recompile: + builder_factory(inp, cache_lifetime).build() + + new_data = build_data.read_data(src_dir) + + if not new_data or not is_compiled(new_data) or not module_exists(new_data, src_dir): + raise CompilationError( + f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." + ) + + return BuildResult( + src_dir=src_dir, + module=new_data.module, + entry_point_name=new_data.entry_point_name, + ) diff --git a/src/gt4py/next/otf/compilation/build_system.py b/src/gt4py/next/otf/compilation/build_system.py new file mode 100644 index 0000000000..a7ce6d957e --- /dev/null +++ b/src/gt4py/next/otf/compilation/build_system.py @@ -0,0 +1,33 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from typing import Protocol, TypeVar + +from gt4py.next import config +from gt4py.next.otf import code_specs, stages + + +CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) +TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) + + +class BuildSystemProjectGenerator(Protocol[CodeSpecT, TargetCodeSpecT]): + """Factory protocol for build-system implementations. + + Given a :class:`stages.CompilableProject` and a cache lifetime, returns a + :class:`stages.BuildSystemProject` that drives the actual build (e.g. + cmake, compiledb). + """ + + def __call__( + self, + source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], + cache_lifetime: config.BuildCacheLifetime, + ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index 1b79cad6e4..dd9158d8c4 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -18,7 +18,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import config, errors from gt4py.next.otf import code_specs, stages -from gt4py.next.otf.compilation import build_data, cache, common, compiler +from gt4py.next.otf.compilation import build_data, build_system, cache, common from gt4py.next.otf.compilation.build_systems import cmake_lists @@ -64,7 +64,7 @@ def get_cmake_device_arch_option() -> str: @dataclasses.dataclass class CMakeFactory( - compiler.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + build_system.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] ): """Create a CMakeProject from a ``CompilableSource`` stage object with given CMake settings.""" diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py 
b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 347b0e25e9..756a24ee38 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -20,7 +20,7 @@ from gt4py.next import config, errors from gt4py.next.otf import code_specs, stages from gt4py.next.otf.binding import interface -from gt4py.next.otf.compilation import build_data, cache, compiler +from gt4py.next.otf.compilation import build_data, build_system, cache from gt4py.next.otf.compilation.build_systems import cmake @@ -29,7 +29,7 @@ @dataclasses.dataclass class CompiledbFactory( - compiler.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + build_system.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] ): """ Create a CompiledbProject from a ``CompilableSource`` stage object with given CMake settings. diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index bddf06a1f3..95bacf6dba 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -14,9 +14,10 @@ from gt4py.next import backend, config from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import compiler +from gt4py.next.otf.compilation import build_system from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module +from gt4py.next.program_processors.runners import gtfn_compiler class GTFNBuildWorkflowFactory(factory.Factory): @@ -28,7 +29,7 @@ class Params: cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough lambda: config.CMAKE_BUILD_TYPE ) - builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing 
not precise enough + builder_factory: build_system.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing not precise enough lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) ) @@ -53,7 +54,7 @@ class Params: nanobind.bind_source ) compilation = factory.SubFactory( - compiler.CompilerFactory, + gtfn_compiler.CompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), device_type=factory.SelfAttribute("..device_type"), diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/program_processors/runners/gtfn_compiler.py similarity index 54% rename from src/gt4py/next/otf/compilation/compiler.py rename to src/gt4py/next/program_processors/runners/gtfn_compiler.py index a858a0b02d..73e683a7fb 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/program_processors/runners/gtfn_compiler.py @@ -10,41 +10,20 @@ import dataclasses import pathlib -from typing import Protocol, TypeVar +from typing import TypeVar import factory -from gt4py._core import definitions as core_defs, locking +from gt4py._core import definitions as core_defs from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow -from gt4py.next.otf.compilation import build_data, cache, importer +from gt4py.next.otf.compilation import build_orchestrator, build_system, importer from gt4py.next.program_processors.runners import gtfn_decoration -T = TypeVar("T") - - -def is_compiled(data: build_data.BuildData) -> bool: - return data.status >= build_data.BuildStatus.COMPILED - - -def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: - return (src_dir / data.module).exists() - - -CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) -TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) CPPLikeCodeSpecT = 
TypeVar("CPPLikeCodeSpecT", bound=code_specs.CPPLikeCodeSpec) -class BuildSystemProjectGenerator(Protocol[CodeSpecT, TargetCodeSpecT]): - def __call__( - self, - source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], - cache_lifetime: config.BuildCacheLifetime, - ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... - - @dataclasses.dataclass(frozen=True) class GTFNBuildArtifact: """On-disk result of a GTFN compilation: a Python extension module. @@ -87,10 +66,12 @@ class Compiler( ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Use any build system (via configured factory) to compile a GT4Py program into a :class:`GTFNBuildArtifact`.""" + """Drive a build system and wrap the result in a :class:`GTFNBuildArtifact`.""" cache_lifetime: config.BuildCacheLifetime - builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + builder_factory: build_system.BuildSystemProjectGenerator[ + CPPLikeCodeSpecT, code_specs.PythonCodeSpec + ] device_type: core_defs.DeviceType force_recompile: bool = False @@ -98,27 +79,13 @@ def __call__( self, inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ) -> GTFNBuildArtifact: - src_dir = cache.get_cache_folder(inp, self.cache_lifetime) - - # If we are compiling the same program at the same time (e.g. multiple MPI ranks), - # we need to make sure that only one of them accesses the same build directory for compilation. - with locking.lock(src_dir): - data = build_data.read_data(src_dir) - - if not data or not is_compiled(data) or self.force_recompile: - self.builder_factory(inp, self.cache_lifetime).build() - - new_data = build_data.read_data(src_dir) - - if not new_data or not is_compiled(new_data) or not module_exists(new_data, src_dir): - raise CompilationError( - f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." 
- ) - + result = build_orchestrator.run_build( + inp, self.cache_lifetime, self.builder_factory, self.force_recompile + ) return GTFNBuildArtifact( - src_dir=src_dir, - module=new_data.module, - entry_point_name=new_data.entry_point_name, + src_dir=result.src_dir, + module=result.module, + entry_point_name=result.entry_point_name, device_type=self.device_type, ) @@ -126,6 +93,3 @@ def __call__( class CompilerFactory(factory.Factory): class Meta: model = Compiler - - -class CompilationError(RuntimeError): ... diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 49bd7b8f87..5967c75544 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -10,25 +10,45 @@ import numpy as np +from gt4py._core import definitions as core_defs from gt4py.next import config from gt4py.next.otf import workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import compiler +from gt4py.next.otf.compilation import importer from gt4py.next.otf.compilation.build_systems import cmake, compiledb +from gt4py.next.program_processors.runners import gtfn_compiler from next_tests.unit_tests.otf_tests.compilation_tests.build_systems_tests.conftest import ( program_source_with_name, ) +def _import_artifact_entry_point(artifact: gtfn_compiler.GTFNBuildArtifact): + """Import the .so directly and return the raw entry point. + + Bypasses :meth:`GTFNBuildArtifact.materialize` so the test can call the + nanobind-bound function with raw arguments rather than gt4py-shaped ones — + this is a build-system / binding integration test, not an end-to-end + program test. 
+ """ + m = importer.import_from_path( + artifact.src_dir / artifact.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return getattr(m, artifact.entry_point_name) + + def test_gtfn_cpp_with_cmake(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_cmake") build_the_program = workflow.make_step(nanobind.bind_source).chain( - compiler.Compiler( - cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=cmake.CMakeFactory() + gtfn_compiler.Compiler( + cache_lifetime=config.BuildCacheLifetime.SESSION, + builder_factory=cmake.CMakeFactory(), + device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source) + artifact = build_the_program(example_program_source) + compiled_program = _import_artifact_entry_point(artifact) buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), @@ -42,12 +62,14 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): def test_gtfn_cpp_with_compiledb(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_compiledb") build_the_program = workflow.make_step(nanobind.bind_source).chain( - compiler.Compiler( + gtfn_compiler.Compiler( cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=compiledb.CompiledbFactory(), + device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source) + artifact = build_the_program(example_program_source) + compiled_program = _import_artifact_entry_point(artifact) buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), From 68ead697c2ff57cbd84cf7de4144f8a142a589e4 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 18:39:28 +0200 Subject: [PATCH 10/24] don't serialize/deserialize dace in the same process --- .../runners/dace/workflow/compilation.py | 51 ++++++++++++++++--- 1 file 
changed, 43 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 71b7934fc1..1a374491e7 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -10,12 +10,13 @@ import dataclasses import pathlib +from typing import Optional import dace import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import config +from gt4py.next import config, utils as gtx_utils from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache from gt4py.next.program_processors.runners.dace.workflow import ( @@ -26,7 +27,7 @@ @dataclasses.dataclass(frozen=True) -class DaCeBuildArtifact: +class DaCeBuildArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a DaCe compilation: a build folder + the SDFG bindings.""" build_folder: pathlib.Path @@ -34,11 +35,35 @@ class DaCeBuildArtifact: bind_func_name: str device_type: core_defs.DeviceType + # Process-local cache of the live :class:`CompiledDaceProgram`. Populated by + # ``DaCeCompiler`` after a fresh compile so :meth:`materialize` can skip the + # SDFG re-deserialize + .so re-link round-trip in the same process. Marked + # ``pickle=False`` via :func:`gtx_utils.gt4py_metadata` so a receiver of the + # artifact in a different process sees ``None`` and falls back to the + # disk-based path. + _live_program: Optional[CompiledDaceProgram] = dataclasses.field( + init=False, + default=None, + compare=False, + repr=False, + metadata=gtx_utils.gt4py_metadata(pickle=False), + ) + def materialize(self) -> stages.ExecutableProgram: - """Re-deserialize the SDFG, link the .so, and wrap in gt4py's calling convention. + """Wrap the compiled program in gt4py's calling convention. 
- Must run in the process that will call the returned program. + Uses the live program cached on the artifact when available; otherwise + re-deserializes the SDFG, re-links the .so via ``compiler.use_cache``, + and caches the result for subsequent calls. Must run in the process + that will call the returned program. """ + program = self._live_program + if program is None: + program = self._load_compiled_program() + object.__setattr__(self, "_live_program", program) + return gtx_wfddecoration.convert_args(program, device=self.device_type) + + def _load_compiled_program(self) -> CompiledDaceProgram: for dump_name in ("program.sdfgz", "program.sdfg"): sdfg_dump = self.build_folder / dump_name if sdfg_dump.exists(): @@ -55,8 +80,7 @@ def materialize(self) -> stages.ExecutableProgram: with dace.config.set_temporary("compiler", "use_cache", value=True): sdfg_program = sdfg.compile(validate=False) - program = CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) - return gtx_wfddecoration.convert_args(program, device=self.device_type) + return CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) @dataclasses.dataclass(frozen=True) @@ -92,15 +116,26 @@ def __call__( sdfg = dace.SDFG.from_json(inp.program_source.source_code) sdfg.build_folder = sdfg_build_folder with locking.lock(sdfg_build_folder): - sdfg.compile(validate=False, return_program_handle=False) + # Keep the program handle so the artifact's materialize() can + # skip the SDFG re-deserialize + .so re-link round-trip when + # used in this same process. 
+ sdfg_program = sdfg.compile(validate=False) assert inp.binding_source is not None - return DaCeBuildArtifact( + artifact = DaCeBuildArtifact( build_folder=pathlib.Path(sdfg_build_folder), binding_source_code=inp.binding_source.source_code, bind_func_name=self.bind_func_name, device_type=self.device_type, ) + object.__setattr__( + artifact, + "_live_program", + CompiledDaceProgram( + sdfg_program, artifact.bind_func_name, artifact.binding_source_code + ), + ) + return artifact class DaCeCompilationStepFactory(factory.Factory): From 4702660ba9889bb8224e38d70b0551def93ebd3d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 19:21:46 +0200 Subject: [PATCH 11/24] don't search for sdfg in materialize --- .../runners/dace/workflow/compilation.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 1a374491e7..2f8365e289 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -31,6 +31,7 @@ class DaCeBuildArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a DaCe compilation: a build folder + the SDFG bindings.""" build_folder: pathlib.Path + sdfg_dump: pathlib.Path binding_source_code: str bind_func_name: str device_type: core_defs.DeviceType @@ -64,16 +65,7 @@ def materialize(self) -> stages.ExecutableProgram: return gtx_wfddecoration.convert_args(program, device=self.device_type) def _load_compiled_program(self) -> CompiledDaceProgram: - for dump_name in ("program.sdfgz", "program.sdfg"): - sdfg_dump = self.build_folder / dump_name - if sdfg_dump.exists(): - break - else: - raise RuntimeError( - f"No SDFG dump (program.sdfgz / program.sdfg) found in '{self.build_folder}'." 
- ) - - sdfg = dace.SDFG.from_file(str(sdfg_dump)) + sdfg = dace.SDFG.from_file(str(self.sdfg_dump)) sdfg.build_folder = str(self.build_folder) with gtx_wfdcommon.dace_context(device_type=self.device_type): @@ -110,20 +102,30 @@ def __call__( device_type=self.device_type, cmake_build_type=self.cmake_build_type, ): - sdfg_build_folder = gtx_cache.get_cache_folder(inp, self.cache_lifetime) + sdfg_build_folder = pathlib.Path(gtx_cache.get_cache_folder(inp, self.cache_lifetime)) sdfg_build_folder.mkdir(parents=True, exist_ok=True) sdfg = dace.SDFG.from_json(inp.program_source.source_code) - sdfg.build_folder = sdfg_build_folder + sdfg.build_folder = str(sdfg_build_folder) with locking.lock(sdfg_build_folder): # Keep the program handle so the artifact's materialize() can # skip the SDFG re-deserialize + .so re-link round-trip when # used in this same process. sdfg_program = sdfg.compile(validate=False) + for dump_name in ("program.sdfgz", "program.sdfg"): + sdfg_dump = sdfg_build_folder / dump_name + if sdfg_dump.exists(): + break + else: + raise RuntimeError( + f"No SDFG dump (program.sdfgz / program.sdfg) found in '{sdfg_build_folder}'." 
+ ) + assert inp.binding_source is not None artifact = DaCeBuildArtifact( - build_folder=pathlib.Path(sdfg_build_folder), + build_folder=sdfg_build_folder, + sdfg_dump=sdfg_dump, binding_source_code=inp.binding_source.source_code, bind_func_name=self.bind_func_name, device_type=self.device_type, From 77b6235c01fb07987d5f9386b9760c48c947ddd4 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:21:30 +0200 Subject: [PATCH 12/24] restore generic compiler --- .../otf/compilation/build_orchestrator.py | 81 ----------- .../next/otf/compilation/build_system.py | 33 ----- .../otf/compilation/build_systems/cmake.py | 4 +- .../compilation/build_systems/compiledb.py | 4 +- src/gt4py/next/otf/compilation/compiler.py | 135 ++++++++++++++++++ .../next/program_processors/runners/gtfn.py | 9 +- .../runners/gtfn_compiler.py | 95 ------------ .../otf_tests/test_nanobind_build.py | 17 ++- 8 files changed, 155 insertions(+), 223 deletions(-) delete mode 100644 src/gt4py/next/otf/compilation/build_orchestrator.py delete mode 100644 src/gt4py/next/otf/compilation/build_system.py create mode 100644 src/gt4py/next/otf/compilation/compiler.py delete mode 100644 src/gt4py/next/program_processors/runners/gtfn_compiler.py diff --git a/src/gt4py/next/otf/compilation/build_orchestrator.py b/src/gt4py/next/otf/compilation/build_orchestrator.py deleted file mode 100644 index 4aea7b4a0d..0000000000 --- a/src/gt4py/next/otf/compilation/build_orchestrator.py +++ /dev/null @@ -1,81 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -"""Generic build orchestration for backends that produce a Python extension module. - -Wraps the lock + build-data tracking + builder-factory invocation + post-build -validation into a single :func:`run_build` call. 
Returns a :class:`BuildResult` -descriptor (paths + entry-point name) that backends wrap into their own -artifact type. -""" - -from __future__ import annotations - -import dataclasses -import pathlib -from typing import TypeVar - -from gt4py._core import locking -from gt4py.next import config -from gt4py.next.otf import code_specs, stages -from gt4py.next.otf.compilation import build_data, build_system, cache - - -CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) -TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) - - -@dataclasses.dataclass(frozen=True) -class BuildResult: - """On-disk descriptor of a successful build.""" - - src_dir: pathlib.Path - module: pathlib.Path - entry_point_name: str - - -class CompilationError(RuntimeError): ... - - -def is_compiled(data: build_data.BuildData) -> bool: - return data.status >= build_data.BuildStatus.COMPILED - - -def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: - return (src_dir / data.module).exists() - - -def run_build( - inp: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], - cache_lifetime: config.BuildCacheLifetime, - builder_factory: build_system.BuildSystemProjectGenerator[CodeSpecT, TargetCodeSpecT], - force_recompile: bool = False, -) -> BuildResult: - """Drive ``builder_factory`` to produce a Python extension module on disk.""" - src_dir = cache.get_cache_folder(inp, cache_lifetime) - - # If we are compiling the same program at the same time (e.g. multiple MPI ranks), - # we need to make sure that only one of them accesses the same build directory for compilation. 
- with locking.lock(src_dir): - data = build_data.read_data(src_dir) - - if not data or not is_compiled(data) or force_recompile: - builder_factory(inp, cache_lifetime).build() - - new_data = build_data.read_data(src_dir) - - if not new_data or not is_compiled(new_data) or not module_exists(new_data, src_dir): - raise CompilationError( - f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." - ) - - return BuildResult( - src_dir=src_dir, - module=new_data.module, - entry_point_name=new_data.entry_point_name, - ) diff --git a/src/gt4py/next/otf/compilation/build_system.py b/src/gt4py/next/otf/compilation/build_system.py deleted file mode 100644 index a7ce6d957e..0000000000 --- a/src/gt4py/next/otf/compilation/build_system.py +++ /dev/null @@ -1,33 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -from typing import Protocol, TypeVar - -from gt4py.next import config -from gt4py.next.otf import code_specs, stages - - -CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) -TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) - - -class BuildSystemProjectGenerator(Protocol[CodeSpecT, TargetCodeSpecT]): - """Factory protocol for build-system implementations. - - Given a :class:`stages.CompilableProject` and a cache lifetime, returns a - :class:`stages.BuildSystemProject` that drives the actual build (e.g. - cmake, compiledb). - """ - - def __call__( - self, - source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], - cache_lifetime: config.BuildCacheLifetime, - ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... 
diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index dd9158d8c4..1b79cad6e4 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -18,7 +18,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import config, errors from gt4py.next.otf import code_specs, stages -from gt4py.next.otf.compilation import build_data, build_system, cache, common +from gt4py.next.otf.compilation import build_data, cache, common, compiler from gt4py.next.otf.compilation.build_systems import cmake_lists @@ -64,7 +64,7 @@ def get_cmake_device_arch_option() -> str: @dataclasses.dataclass class CMakeFactory( - build_system.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + compiler.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] ): """Create a CMakeProject from a ``CompilableSource`` stage object with given CMake settings.""" diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 756a24ee38..347b0e25e9 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -20,7 +20,7 @@ from gt4py.next import config, errors from gt4py.next.otf import code_specs, stages from gt4py.next.otf.binding import interface -from gt4py.next.otf.compilation import build_data, build_system, cache +from gt4py.next.otf.compilation import build_data, cache, compiler from gt4py.next.otf.compilation.build_systems import cmake @@ -29,7 +29,7 @@ @dataclasses.dataclass class CompiledbFactory( - build_system.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + compiler.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] ): """ Create a CompiledbProject from a ``CompilableSource`` stage object with given CMake 
settings. diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py new file mode 100644 index 0000000000..b128416a48 --- /dev/null +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -0,0 +1,135 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import dataclasses +import pathlib +from typing import Callable, Protocol, TypeVar + +import factory + +from gt4py._core import definitions as core_defs, locking +from gt4py.next import config, utils as gtx_utils +from gt4py.next.otf import code_specs, definitions, stages, workflow +from gt4py.next.otf.compilation import build_data, cache, importer + + +CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) +TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) +CPPLikeCodeSpecT = TypeVar("CPPLikeCodeSpecT", bound=code_specs.CPPLikeCodeSpec) + + +class BuildSystemProjectGenerator(Protocol[CodeSpecT, TargetCodeSpecT]): + def __call__( + self, + source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], + cache_lifetime: config.BuildCacheLifetime, + ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... + + +def is_compiled(data: build_data.BuildData) -> bool: + return data.status >= build_data.BuildStatus.COMPILED + + +def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: + return (src_dir / data.module).exists() + + +class CompilationError(RuntimeError): ... + + +# Signature of the per-backend wrapping applied to a freshly imported entry point. 
+ProgramDecorator = Callable[ + [stages.ExecutableProgram, core_defs.DeviceType], stages.ExecutableProgram +] + + +@dataclasses.dataclass(frozen=True) +class CPPBuildArtifact(gtx_utils.MetadataBasedPickling): + """On-disk result of a CPP-style compilation: a Python extension module. + + Bindings are baked into the .so (e.g. via nanobind), so :meth:`materialize` + is just an ``importlib`` import + entry-point lookup, plus a per-backend + :attr:`decorator` that adapts the raw callable to the backend's calling + convention. + """ + + src_dir: pathlib.Path + module: pathlib.Path + entry_point_name: str + device_type: core_defs.DeviceType + decorator: ProgramDecorator + + def materialize(self) -> stages.ExecutableProgram: + """Import the module and apply the configured per-backend decorator. + + Must run in the process that will call the returned program: the + module is registered in that process's ``sys.modules`` under the + ``gt4py.__compiled_programs__.`` prefix. + """ + m = importer.import_from_path( + self.src_dir / self.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return self.decorator(getattr(m, self.entry_point_name), self.device_type) + + +@dataclasses.dataclass(frozen=True) +class CPPCompiler( + workflow.ChainableWorkflowMixin[ + stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], + CPPBuildArtifact, + ], + workflow.ReplaceEnabledWorkflowMixin[ + stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], + CPPBuildArtifact, + ], + definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], +): + """Drive a CPP-style build system and wrap the result in a :class:`CPPBuildArtifact`.""" + + cache_lifetime: config.BuildCacheLifetime + builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + device_type: core_defs.DeviceType + decorator: ProgramDecorator + force_recompile: bool = False + + def __call__( + self, + inp: stages.CompilableProject[CPPLikeCodeSpecT, 
code_specs.PythonCodeSpec], + ) -> CPPBuildArtifact: + src_dir = cache.get_cache_folder(inp, self.cache_lifetime) + + # If we are compiling the same program at the same time (e.g. multiple MPI ranks), + # we need to make sure that only one of them accesses the same build directory for compilation. + with locking.lock(src_dir): + data = build_data.read_data(src_dir) + + if not data or not is_compiled(data) or self.force_recompile: + self.builder_factory(inp, self.cache_lifetime).build() + + new_data = build_data.read_data(src_dir) + + if not new_data or not is_compiled(new_data) or not module_exists(new_data, src_dir): + raise CompilationError( + f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." + ) + + return CPPBuildArtifact( + src_dir=src_dir, + module=new_data.module, + entry_point_name=new_data.entry_point_name, + device_type=self.device_type, + decorator=self.decorator, + ) + + +class CompilerFactory(factory.Factory): + class Meta: + model = CPPCompiler diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 95bacf6dba..b7d277a383 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -14,10 +14,10 @@ from gt4py.next import backend, config from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import build_system +from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module -from gt4py.next.program_processors.runners import gtfn_compiler +from gt4py.next.program_processors.runners import gtfn_decoration class GTFNBuildWorkflowFactory(factory.Factory): @@ -29,7 +29,7 @@ class Params: cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise 
enough lambda: config.CMAKE_BUILD_TYPE ) - builder_factory: build_system.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing not precise enough + builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing not precise enough lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) ) @@ -54,10 +54,11 @@ class Params: nanobind.bind_source ) compilation = factory.SubFactory( - gtfn_compiler.CompilerFactory, + compiler.CompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), device_type=factory.SelfAttribute("..device_type"), + decorator=gtfn_decoration.convert_args, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn_compiler.py b/src/gt4py/next/program_processors/runners/gtfn_compiler.py deleted file mode 100644 index 73e683a7fb..0000000000 --- a/src/gt4py/next/program_processors/runners/gtfn_compiler.py +++ /dev/null @@ -1,95 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import dataclasses -import pathlib -from typing import TypeVar - -import factory - -from gt4py._core import definitions as core_defs -from gt4py.next import config -from gt4py.next.otf import code_specs, definitions, stages, workflow -from gt4py.next.otf.compilation import build_orchestrator, build_system, importer -from gt4py.next.program_processors.runners import gtfn_decoration - - -CPPLikeCodeSpecT = TypeVar("CPPLikeCodeSpecT", bound=code_specs.CPPLikeCodeSpec) - - -@dataclasses.dataclass(frozen=True) -class GTFNBuildArtifact: - """On-disk result of a GTFN compilation: a Python extension module. 
- - Bindings are baked into the .so via nanobind, so :meth:`materialize` is - just an ``importlib`` import + entry-point symbol lookup, plus a wrap in - gt4py's calling convention. - """ - - src_dir: pathlib.Path - module: pathlib.Path - entry_point_name: str - device_type: core_defs.DeviceType - - def materialize(self) -> stages.ExecutableProgram: - """Import the module and wrap its entry point in gt4py's calling convention. - - Must run in the process that will call the returned program: the - module is registered in that process's ``sys.modules`` under the - ``gt4py.__compiled_programs__.`` prefix. - """ - m = importer.import_from_path( - self.src_dir / self.module, - sys_modules_prefix="gt4py.__compiled_programs__.", - ) - return gtfn_decoration.convert_args( - getattr(m, self.entry_point_name), device=self.device_type - ) - - -@dataclasses.dataclass(frozen=True) -class Compiler( - workflow.ChainableWorkflowMixin[ - stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - GTFNBuildArtifact, - ], - workflow.ReplaceEnabledWorkflowMixin[ - stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - GTFNBuildArtifact, - ], - definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], -): - """Drive a build system and wrap the result in a :class:`GTFNBuildArtifact`.""" - - cache_lifetime: config.BuildCacheLifetime - builder_factory: build_system.BuildSystemProjectGenerator[ - CPPLikeCodeSpecT, code_specs.PythonCodeSpec - ] - device_type: core_defs.DeviceType - force_recompile: bool = False - - def __call__( - self, - inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - ) -> GTFNBuildArtifact: - result = build_orchestrator.run_build( - inp, self.cache_lifetime, self.builder_factory, self.force_recompile - ) - return GTFNBuildArtifact( - src_dir=result.src_dir, - module=result.module, - entry_point_name=result.entry_point_name, - device_type=self.device_type, - ) - - -class 
CompilerFactory(factory.Factory): - class Meta: - model = Compiler diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 5967c75544..b4b864a8d2 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -14,19 +14,18 @@ from gt4py.next import config from gt4py.next.otf import workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import importer +from gt4py.next.otf.compilation import compiler, importer from gt4py.next.otf.compilation.build_systems import cmake, compiledb -from gt4py.next.program_processors.runners import gtfn_compiler from next_tests.unit_tests.otf_tests.compilation_tests.build_systems_tests.conftest import ( program_source_with_name, ) -def _import_artifact_entry_point(artifact: gtfn_compiler.GTFNBuildArtifact): +def _import_artifact_entry_point(artifact: compiler.CPPBuildArtifact): """Import the .so directly and return the raw entry point. - Bypasses :meth:`GTFNBuildArtifact.materialize` so the test can call the + Bypasses :meth:`CPPBuildArtifact.materialize` so the test can call the nanobind-bound function with raw arguments rather than gt4py-shaped ones — this is a build-system / binding integration test, not an end-to-end program test. 
@@ -38,13 +37,18 @@ def _import_artifact_entry_point(artifact: gtfn_compiler.GTFNBuildArtifact): return getattr(m, artifact.entry_point_name) +def _identity(raw, _device): + return raw + + def test_gtfn_cpp_with_cmake(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_cmake") build_the_program = workflow.make_step(nanobind.bind_source).chain( - gtfn_compiler.Compiler( + compiler.CPPCompiler( cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=cmake.CMakeFactory(), device_type=core_defs.DeviceType.CPU, + decorator=_identity, ) ) artifact = build_the_program(example_program_source) @@ -62,10 +66,11 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): def test_gtfn_cpp_with_compiledb(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_compiledb") build_the_program = workflow.make_step(nanobind.bind_source).chain( - gtfn_compiler.Compiler( + compiler.CPPCompiler( cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=compiledb.CompiledbFactory(), device_type=core_defs.DeviceType.CPU, + decorator=_identity, ) ) artifact = build_the_program(example_program_source) From 3869d67443bf4afc7156255e27118f62e60b5563 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:27:58 +0200 Subject: [PATCH 13/24] cleanup --- src/gt4py/next/otf/compilation/compiler.py | 22 ++++++++--------- .../next/program_processors/runners/gtfn.py | 3 +-- .../otf_tests/test_nanobind_build.py | 24 ++++--------------- 3 files changed, 16 insertions(+), 33 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index b128416a48..5063fb1c1f 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -20,6 +20,14 @@ from gt4py.next.otf.compilation import build_data, cache, importer +def is_compiled(data: build_data.BuildData) -> bool: + return data.status >= 
build_data.BuildStatus.COMPILED + + +def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: + return (src_dir / data.module).exists() + + CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) CPPLikeCodeSpecT = TypeVar("CPPLikeCodeSpecT", bound=code_specs.CPPLikeCodeSpec) @@ -33,17 +41,6 @@ def __call__( ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... -def is_compiled(data: build_data.BuildData) -> bool: - return data.status >= build_data.BuildStatus.COMPILED - - -def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: - return (src_dir / data.module).exists() - - -class CompilationError(RuntimeError): ... - - # Signature of the per-backend wrapping applied to a freshly imported entry point. ProgramDecorator = Callable[ [stages.ExecutableProgram, core_defs.DeviceType], stages.ExecutableProgram @@ -133,3 +130,6 @@ def __call__( class CompilerFactory(factory.Factory): class Meta: model = CPPCompiler + + +class CompilationError(RuntimeError): ... 
diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index b7d277a383..dc009c376b 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -100,8 +100,7 @@ class Params: run_gtfn = GTFNBackendFactory() run_gtfn_imperative = GTFNBackendFactory( - name_postfix="_imperative", - otf_workflow__translation__use_imperative_backend=True, + name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True ) run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index b4b864a8d2..27ca7c16f1 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -14,7 +14,7 @@ from gt4py.next import config from gt4py.next.otf import workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import compiler, importer +from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import cmake, compiledb from next_tests.unit_tests.otf_tests.compilation_tests.build_systems_tests.conftest import ( @@ -22,22 +22,8 @@ ) -def _import_artifact_entry_point(artifact: compiler.CPPBuildArtifact): - """Import the .so directly and return the raw entry point. - - Bypasses :meth:`CPPBuildArtifact.materialize` so the test can call the - nanobind-bound function with raw arguments rather than gt4py-shaped ones — - this is a build-system / binding integration test, not an end-to-end - program test. 
- """ - m = importer.import_from_path( - artifact.src_dir / artifact.module, - sys_modules_prefix="gt4py.__compiled_programs__.", - ) - return getattr(m, artifact.entry_point_name) - - def _identity(raw, _device): + """Pass-through decorator: this test calls the nanobind-bound function with raw args.""" return raw @@ -51,8 +37,7 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): decorator=_identity, ) ) - artifact = build_the_program(example_program_source) - compiled_program = _import_artifact_entry_point(artifact) + compiled_program = build_the_program(example_program_source).materialize() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), @@ -73,8 +58,7 @@ def test_gtfn_cpp_with_compiledb(program_source_with_name): decorator=_identity, ) ) - artifact = build_the_program(example_program_source) - compiled_program = _import_artifact_entry_point(artifact) + compiled_program = build_the_program(example_program_source).materialize() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), From 9017b33dcd9431629da712244c5a218e0ffd8d24 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:30:06 +0200 Subject: [PATCH 14/24] more cleanup --- .../runners/dace/program.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index 1435080f52..e1ddbee455 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -76,19 +76,17 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: gt4py_program_args=[p.type for p in program.params], ) - # The executor and the translation stage may each be wrapped in a `CachedStep` - # depending on backend configuration; unwrap when so. 
- build_workflow = typing.cast( + compile_workflow = typing.cast( recipes.OTFBuildWorkflow, - self.backend.executor.step - if hasattr(self.backend.executor, "step") - else self.backend.executor, - ) + self.backend.executor + if not hasattr(self.backend.executor, "step") + else self.backend.executor.step, + ) # We know which backend we are using, but we don't know if the compile workflow is cached. compile_workflow_translation = ( - build_workflow.translation.step - if hasattr(build_workflow.translation, "step") - else build_workflow.translation - ) + compile_workflow.translation + if not hasattr(compile_workflow.translation, "step") + else compile_workflow.translation.step + ) # Same for the translation stage, which could be a `CachedStep` depending on backend configuration. # TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with # the other parts of the workaround when possible. From 5bffec05e154497216d073a5a2d9f36b1534e860 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:32:39 +0200 Subject: [PATCH 15/24] add tests --- .../compilation_tests/test_compiler.py | 32 +++++++++++++++++ .../dace_tests/test_dace_compilation.py | 36 +++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py new file mode 100644 index 0000000000..250c026920 --- /dev/null +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py @@ -0,0 +1,32 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. 
+# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Minimal contract tests for :class:`compiler.CPPBuildArtifact`.""" + +import pathlib +import pickle + +from gt4py._core import definitions as core_defs +from gt4py.next.otf.compilation import compiler + + +def _identity_decorator(raw, _device): + return raw + + +def test_cpp_build_artifact_pickle_round_trip(): + artifact = compiler.CPPBuildArtifact( + src_dir=pathlib.Path("/tmp/build"), + module=pathlib.Path("entry.so"), + entry_point_name="entry", + device_type=core_defs.DeviceType.CPU, + decorator=_identity_decorator, + ) + restored = pickle.loads(pickle.dumps(artifact)) + assert restored == artifact + assert restored.decorator is _identity_decorator diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py new file mode 100644 index 0000000000..1149f3e131 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -0,0 +1,36 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +"""Minimal contract tests for :class:`compilation.DaCeBuildArtifact`.""" + +import pathlib +import pickle + +import pytest + +pytest.importorskip("dace") + +from gt4py._core import definitions as core_defs # noqa: E402 +from gt4py.next.program_processors.runners.dace.workflow import compilation # noqa: E402 + + +def test_dace_build_artifact_pickle_round_trip_drops_live_program(): + artifact = compilation.DaCeBuildArtifact( + build_folder=pathlib.Path("/tmp/build"), + sdfg_dump=pathlib.Path("/tmp/build/program.sdfgz"), + binding_source_code="def update_sdfg_args(*a, **k): ...", + bind_func_name="update_sdfg_args", + device_type=core_defs.DeviceType.CPU, + ) + object.__setattr__(artifact, "_live_program", "") + + restored = pickle.loads(pickle.dumps(artifact)) + + # The data fields round-trip, the live in-process handle does not. + assert restored == artifact + assert restored._live_program is None From 7f219d7ea6bb449936ec4b94261e5fe3efa57635 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:34:04 +0200 Subject: [PATCH 16/24] cleanup --- src/gt4py/next/program_processors/runners/roundtrip.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index d97b3ab238..7d7075157e 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -216,9 +216,9 @@ def fencil_generator( class RoundtripArtifact: """In-memory artifact for the roundtrip backend. - Roundtrip generates and ``exec``\\ s a Python module per program, so its - output is a live callable rather than something on disk. Not picklable — - roundtrip is in-process only. + Roundtrip generates a Python module per program and executes it directly, + so its output is a live callable rather than something on disk. 
Not + picklable — roundtrip is in-process only. """ program: stages.ExecutableProgram From 69c5ae86d6a5cfbc10bfd854c64f3a03574f63c5 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:49:26 +0200 Subject: [PATCH 17/24] avoid decoration needs picklable --- src/gt4py/next/otf/compilation/compiler.py | 39 +++--- .../next/program_processors/runners/gtfn.py | 118 +++++++++++++++++- .../runners/gtfn_decoration.py | 105 ---------------- .../otf_tests/test_nanobind_build.py | 7 -- .../compilation_tests/test_compiler.py | 6 - 5 files changed, 134 insertions(+), 141 deletions(-) delete mode 100644 src/gt4py/next/program_processors/runners/gtfn_decoration.py diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 5063fb1c1f..25ce0aba09 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -10,7 +10,7 @@ import dataclasses import pathlib -from typing import Callable, Protocol, TypeVar +from typing import Protocol, TypeVar import factory @@ -41,30 +41,24 @@ def __call__( ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... -# Signature of the per-backend wrapping applied to a freshly imported entry point. -ProgramDecorator = Callable[ - [stages.ExecutableProgram, core_defs.DeviceType], stages.ExecutableProgram -] - - @dataclasses.dataclass(frozen=True) class CPPBuildArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a CPP-style compilation: a Python extension module. - Bindings are baked into the .so (e.g. via nanobind), so :meth:`materialize` - is just an ``importlib`` import + entry-point lookup, plus a per-backend - :attr:`decorator` that adapts the raw callable to the backend's calling - convention. + Bindings are baked into the .so (e.g. via nanobind), so the default + :meth:`materialize` is just an ``importlib`` import + entry-point lookup, + returning the raw imported callable. 
Backends that need to wrap the + callable in a calling convention (e.g. GTFN's gt4py-shaped argument + conversion) subclass and override :meth:`materialize`. """ src_dir: pathlib.Path module: pathlib.Path entry_point_name: str device_type: core_defs.DeviceType - decorator: ProgramDecorator def materialize(self) -> stages.ExecutableProgram: - """Import the module and apply the configured per-backend decorator. + """Import the .so and return the raw entry point. Must run in the process that will call the returned program: the module is registered in that process's ``sys.modules`` under the @@ -74,7 +68,7 @@ def materialize(self) -> stages.ExecutableProgram: self.src_dir / self.module, sys_modules_prefix="gt4py.__compiled_programs__.", ) - return self.decorator(getattr(m, self.entry_point_name), self.device_type) + return getattr(m, self.entry_point_name) @dataclasses.dataclass(frozen=True) @@ -89,12 +83,15 @@ class CPPCompiler( ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Drive a CPP-style build system and wrap the result in a :class:`CPPBuildArtifact`.""" + """Drive a CPP-style build system and wrap the result in a :class:`CPPBuildArtifact`. + + Backends that need a different artifact subclass (e.g. with a wrapped + ``materialize``) subclass and override :meth:`_make_artifact`. + """ cache_lifetime: config.BuildCacheLifetime builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] device_type: core_defs.DeviceType - decorator: ProgramDecorator force_recompile: bool = False def __call__( @@ -118,12 +115,16 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." 
) + return self._make_artifact(src_dir, new_data.module, new_data.entry_point_name) + + def _make_artifact( + self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str + ) -> CPPBuildArtifact: return CPPBuildArtifact( src_dir=src_dir, - module=new_data.module, - entry_point_name=new_data.entry_point_name, + module=module, + entry_point_name=entry_point_name, device_type=self.device_type, - decorator=self.decorator, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index dc009c376b..fd652c1ee8 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,18 +6,129 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import dataclasses +import pathlib +from typing import Any + import factory +import numpy as np import gt4py._core.definitions as core_defs import gt4py.next.custom_layout_allocators as next_allocators from gt4py._core import filecache -from gt4py.next import backend, config +from gt4py.next import backend, common, config, field_utils +from gt4py.next.embedded import nd_array_field +from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module -from gt4py.next.program_processors.runners import gtfn_decoration + + +def convert_arg(arg: Any) -> Any: + # Note: this function is on the hot path and needs to have minimal overhead. 
+ if (origin := getattr(arg, "__gt_origin__", None)) is not None: + # `Field` is the most likely case, we use `__gt_origin__` as the property is needed anyway + # and (currently) uniquely identifies a `NDArrayField` (which is the only supported `Field`) + assert isinstance(arg, nd_array_field.NdArrayField) + return arg.ndarray, origin + if isinstance(arg, tuple): + return tuple(convert_arg(a) for a in arg) + if isinstance(arg, np.bool_): + # nanobind does not support implicit conversion of `np.bool` to `bool` + return bool(arg) + # TODO(havogt): if this function still appears in profiles, + # we should avoid going through the previous isinstance checks for detecting a scalar. + # E.g. functools.cache on the arg type, returning a function that does the conversion + return arg + + +def convert_args( + inp: stages.ExecutableProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU +) -> stages.ExecutableProgram: + def decorated_program( + *args: Any, + offset_provider: dict[str, common.Connectivity | common.Dimension], + out: Any = None, + ) -> None: + # Note: this function is on the hot path and needs to have minimal overhead. + if out is not None: + args = (*args, out) + converted_args = (convert_arg(arg) for arg in args) + conn_args = extract_connectivity_args(offset_provider, device) + + opt_kwargs: dict[str, Any] = {} + if collect_metrics := metrics.is_level_enabled(metrics.PERFORMANCE): + # If we are collecting metrics, we need to add the `exec_info` argument + # to the `inp` call, which will be used to collect performance metrics. 
+ exec_info: dict[str, float] = {} + opt_kwargs["exec_info"] = exec_info + + # generate implicit domain size arguments only if necessary, using `iter_size_args()` + inp( + *converted_args, + *conn_args, + **opt_kwargs, + ) + + if collect_metrics: + metrics.add_sample_to_current_source( + metrics.COMPUTE_METRIC, exec_info["run_cpp_duration"] + ) + + return decorated_program + + +def extract_connectivity_args( + offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType +) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: + # Note: this function is on the hot path and needs to have minimal overhead. + zero_origin = (0, 0) + assert all( + hasattr(conn, "ndarray") or isinstance(conn, common.Dimension) + for conn in offset_provider.values() + ) + # Note: the order here needs to agree with the order of the generated bindings. + # This is currently true only because when hashing offset provider dicts, + # the keys' order is taken into account. Any modification to the hashing + # of offset providers may break this assumption here. 
+ args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [ + (ndarray, zero_origin) + for conn in offset_provider.values() + if (ndarray := getattr(conn, "ndarray", None)) is not None + ] + assert all( + common.is_neighbor_table(conn) and field_utils.verify_device_field_type(conn, device) + for conn in offset_provider.values() + if hasattr(conn, "ndarray") + ) + + return args + + +@dataclasses.dataclass(frozen=True) +class GTFNBuildArtifact(compiler.CPPBuildArtifact): + def materialize(self) -> stages.ExecutableProgram: + return convert_args(super().materialize(), device=self.device_type) + + +@dataclasses.dataclass(frozen=True) +class GTFNCompiler(compiler.CPPCompiler): + def _make_artifact( + self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str + ) -> GTFNBuildArtifact: + return GTFNBuildArtifact( + src_dir=src_dir, + module=module, + entry_point_name=entry_point_name, + device_type=self.device_type, + ) + + +class GTFNCompilerFactory(factory.Factory): + class Meta: + model = GTFNCompiler class GTFNBuildWorkflowFactory(factory.Factory): @@ -54,11 +165,10 @@ class Params: nanobind.bind_source ) compilation = factory.SubFactory( - compiler.CompilerFactory, + GTFNCompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), device_type=factory.SelfAttribute("..device_type"), - decorator=gtfn_decoration.convert_args, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn_decoration.py b/src/gt4py/next/program_processors/runners/gtfn_decoration.py deleted file mode 100644 index 1ea2b222ca..0000000000 --- a/src/gt4py/next/program_processors/runners/gtfn_decoration.py +++ /dev/null @@ -1,105 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -"""Calling-convention adapter for GTFN-compiled programs. - -Wraps a freshly-imported GTFN entry point with gt4py's user-facing -argument convention: unpacks fields, splits offset_provider into -connectivity args, attaches metric collection. -""" - -from typing import Any - -import numpy as np - -import gt4py._core.definitions as core_defs -from gt4py.next import common, field_utils -from gt4py.next.embedded import nd_array_field -from gt4py.next.instrumentation import metrics -from gt4py.next.otf import stages - - -def convert_arg(arg: Any) -> Any: - # Note: this function is on the hot path and needs to have minimal overhead. - if (origin := getattr(arg, "__gt_origin__", None)) is not None: - # `Field` is the most likely case, we use `__gt_origin__` as the property is needed anyway - # and (currently) uniquely identifies a `NDArrayField` (which is the only supported `Field`) - assert isinstance(arg, nd_array_field.NdArrayField) - return arg.ndarray, origin - if isinstance(arg, tuple): - return tuple(convert_arg(a) for a in arg) - if isinstance(arg, np.bool_): - # nanobind does not support implicit conversion of `np.bool` to `bool` - return bool(arg) - # TODO(havogt): if this function still appears in profiles, - # we should avoid going through the previous isinstance checks for detecting a scalar. - # E.g. functools.cache on the arg type, returning a function that does the conversion - return arg - - -def convert_args( - inp: stages.ExecutableProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU -) -> stages.ExecutableProgram: - def decorated_program( - *args: Any, - offset_provider: dict[str, common.Connectivity | common.Dimension], - out: Any = None, - ) -> None: - # Note: this function is on the hot path and needs to have minimal overhead. 
- if out is not None: - args = (*args, out) - converted_args = (convert_arg(arg) for arg in args) - conn_args = extract_connectivity_args(offset_provider, device) - - opt_kwargs: dict[str, Any] = {} - if collect_metrics := metrics.is_level_enabled(metrics.PERFORMANCE): - # If we are collecting metrics, we need to add the `exec_info` argument - # to the `inp` call, which will be used to collect performance metrics. - exec_info: dict[str, float] = {} - opt_kwargs["exec_info"] = exec_info - - # generate implicit domain size arguments only if necessary, using `iter_size_args()` - inp( - *converted_args, - *conn_args, - **opt_kwargs, - ) - - if collect_metrics: - metrics.add_sample_to_current_source( - metrics.COMPUTE_METRIC, exec_info["run_cpp_duration"] - ) - - return decorated_program - - -def extract_connectivity_args( - offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType -) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: - # Note: this function is on the hot path and needs to have minimal overhead. - zero_origin = (0, 0) - assert all( - hasattr(conn, "ndarray") or isinstance(conn, common.Dimension) - for conn in offset_provider.values() - ) - # Note: the order here needs to agree with the order of the generated bindings. - # This is currently true only because when hashing offset provider dicts, - # the keys' order is taken into account. Any modification to the hashing - # of offset providers may break this assumption here. 
- args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [ - (ndarray, zero_origin) - for conn in offset_provider.values() - if (ndarray := getattr(conn, "ndarray", None)) is not None - ] - assert all( - common.is_neighbor_table(conn) and field_utils.verify_device_field_type(conn, device) - for conn in offset_provider.values() - if hasattr(conn, "ndarray") - ) - - return args diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 27ca7c16f1..77db222a11 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -22,11 +22,6 @@ ) -def _identity(raw, _device): - """Pass-through decorator: this test calls the nanobind-bound function with raw args.""" - return raw - - def test_gtfn_cpp_with_cmake(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_cmake") build_the_program = workflow.make_step(nanobind.bind_source).chain( @@ -34,7 +29,6 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=cmake.CMakeFactory(), device_type=core_defs.DeviceType.CPU, - decorator=_identity, ) ) compiled_program = build_the_program(example_program_source).materialize() @@ -55,7 +49,6 @@ def test_gtfn_cpp_with_compiledb(program_source_with_name): cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=compiledb.CompiledbFactory(), device_type=core_defs.DeviceType.CPU, - decorator=_identity, ) ) compiled_program = build_the_program(example_program_source).materialize() diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py index 250c026920..806ea94c93 100644 --- 
a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py @@ -15,18 +15,12 @@ from gt4py.next.otf.compilation import compiler -def _identity_decorator(raw, _device): - return raw - - def test_cpp_build_artifact_pickle_round_trip(): artifact = compiler.CPPBuildArtifact( src_dir=pathlib.Path("/tmp/build"), module=pathlib.Path("entry.so"), entry_point_name="entry", device_type=core_defs.DeviceType.CPU, - decorator=_identity_decorator, ) restored = pickle.loads(pickle.dumps(artifact)) assert restored == artifact - assert restored.decorator is _identity_decorator From e05a33624ba8fe9fabeac18bb667701f74b44ee0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:52:03 +0200 Subject: [PATCH 18/24] cleanup --- .../runners/dace/workflow/compilation.py | 99 +++++++++++++++- .../runners/dace/workflow/compiled_program.py | 110 ------------------ .../runners/dace/workflow/decoration.py | 11 +- .../runners_tests/dace_tests/test_dace.py | 8 +- 4 files changed, 109 insertions(+), 119 deletions(-) delete mode 100644 src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 2f8365e289..2a70628ae8 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -9,21 +9,114 @@ from __future__ import annotations import dataclasses +import os import pathlib -from typing import Optional +import warnings +from collections.abc import Callable, MutableSequence, Sequence +from typing import Any, Optional import dace import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import config, utils as gtx_utils +from gt4py.next import common, config, utils as gtx_utils from 
gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache from gt4py.next.program_processors.runners.dace.workflow import ( common as gtx_wfdcommon, decoration as gtx_wfddecoration, ) -from gt4py.next.program_processors.runners.dace.workflow.compiled_program import CompiledDaceProgram + + +class CompiledDaceProgram: + sdfg_program: dace.CompiledSDFG + + # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; + # scalar arguments that are not used in the SDFG will not be present. + sdfg_argtypes: list[dace.dtypes.Data] + + # The compiled program contains a callable object to update the SDFG arguments list. + update_sdfg_ctype_arglist: Callable[ + [ + core_defs.DeviceType, + Sequence[dace.dtypes.Data], + Sequence[Any], + MutableSequence[Any], + common.OffsetProvider, + ], + None, + ] + + # Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None` + # means that it has not been initialized, i.e. no call was ever performed. + # - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated. + # - csdfg_init_argv: Arguments used for initialization; used only the first time and + # never updated. + csdfg_argv: MutableSequence[Any] | None + csdfg_init_argv: Sequence[Any] | None + + def __init__( + self, + program: dace.CompiledSDFG, + bind_func_name: str, + binding_source_code: str, + ): + self.sdfg_program = program + + # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument + # name to its data type, in the same order as arguments appear in the program ABI. + # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. + self.sdfg_argtypes = list(program.sdfg.arglist().values()) + + # The binding source code is Python tailored to this specific SDFG. + # We dynamically compile that function and add it to the compiled program. 
+ global_namespace: dict[str, Any] = {} + exec(binding_source_code, global_namespace) + self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] + # For debug purpose, we set a unique module name on the compiled function. + self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) + + # Since the SDFG hasn't been called yet. + self.csdfg_argv = None + self.csdfg_init_argv = None + + def construct_arguments(self, **kwargs: Any) -> None: + """ + This function will process the arguments and store the processed argument + vectors in `self.csdfg_args`, to call them use `self.fast_call()`. + """ + with dace.config.set_temporary("compiler", "allow_view_arguments", value=True): + csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs) + # Note we only care about `csdfg_argv` (normal call), since we have to update it, + # we ensure that it is a `list`. + self.csdfg_argv = [*csdfg_argv] + self.csdfg_init_argv = csdfg_init_argv + + def fast_call(self) -> None: + """ + Perform a call to the compiled SDFG using the previously generated argument + vectors, see `self.construct_arguments()`. + """ + assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, ( + "Argument vector was not set properly." + ) + self.sdfg_program.fast_call( + self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.DEBUG + ) + + def __call__(self, **kwargs: Any) -> None: + """Call the compiled SDFG with the given arguments. + + Note that this function will not update the argument vectors stored inside + `self`. Furthermore, it is not recommended to use this function as it is + very slow. 
+ """ + warnings.warn( + "Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.", + stacklevel=1, + ) + result = self.sdfg_program(**kwargs) + assert result is None @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py b/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py deleted file mode 100644 index 5e28853902..0000000000 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py +++ /dev/null @@ -1,110 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import os -import warnings -from collections.abc import Callable, MutableSequence, Sequence -from typing import Any - -import dace - -from gt4py._core import definitions as core_defs -from gt4py.next import common, config - - -class CompiledDaceProgram: - sdfg_program: dace.CompiledSDFG - - # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; - # scalar arguments that are not used in the SDFG will not be present. - sdfg_argtypes: list[dace.dtypes.Data] - - # The compiled program contains a callable object to update the SDFG arguments list. - update_sdfg_ctype_arglist: Callable[ - [ - core_defs.DeviceType, - Sequence[dace.dtypes.Data], - Sequence[Any], - MutableSequence[Any], - common.OffsetProvider, - ], - None, - ] - - # Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None` - # means that it has not been initialized, i.e. no call was ever performed. - # - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated. - # - csdfg_init_argv: Arguments used for initialization; used only the first time and - # never updated. 
- csdfg_argv: MutableSequence[Any] | None - csdfg_init_argv: Sequence[Any] | None - - def __init__( - self, - program: dace.CompiledSDFG, - bind_func_name: str, - binding_source_code: str, - ): - self.sdfg_program = program - - # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument - # name to its data type, in the same order as arguments appear in the program ABI. - # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. - self.sdfg_argtypes = list(program.sdfg.arglist().values()) - - # The binding source code is Python tailored to this specific SDFG. - # We dynamically compile that function and add it to the compiled program. - global_namespace: dict[str, Any] = {} - exec(binding_source_code, global_namespace) - self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] - # For debug purpose, we set a unique module name on the compiled function. - self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) - - # Since the SDFG hasn't been called yet. - self.csdfg_argv = None - self.csdfg_init_argv = None - - def construct_arguments(self, **kwargs: Any) -> None: - """ - This function will process the arguments and store the processed argument - vectors in `self.csdfg_args`, to call them use `self.fast_call()`. - """ - with dace.config.set_temporary("compiler", "allow_view_arguments", value=True): - csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs) - # Note we only care about `csdfg_argv` (normal call), since we have to update it, - # we ensure that it is a `list`. - self.csdfg_argv = [*csdfg_argv] - self.csdfg_init_argv = csdfg_init_argv - - def fast_call(self) -> None: - """ - Perform a call to the compiled SDFG using the previously generated argument - vectors, see `self.construct_arguments()`. - """ - assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, ( - "Argument vector was not set properly." 
- ) - self.sdfg_program.fast_call( - self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.DEBUG - ) - - def __call__(self, **kwargs: Any) -> None: - """Call the compiled SDFG with the given arguments. - - Note that this function will not update the argument vectors stored inside - `self`. Furthermore, it is not recommended to use this function as it is - very slow. - """ - warnings.warn( - "Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.", - stacklevel=1, - ) - result = self.sdfg_program(**kwargs) - assert result is None diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 6b828d5a97..27eb57a82b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -9,7 +9,7 @@ from __future__ import annotations import functools -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any, Sequence import numpy as np @@ -19,7 +19,14 @@ from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon -from gt4py.next.program_processors.runners.dace.workflow.compiled_program import CompiledDaceProgram + + +if TYPE_CHECKING: + # Type-only: evaluating ``compilation`` at module load would create a cycle + # (compilation imports this module for the materialize body). 
+ from gt4py.next.program_processors.runners.dace.workflow.compilation import ( + CompiledDaceProgram, + ) def convert_args( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index 06b3d428bb..a204886690 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -82,7 +82,7 @@ def make_mocks(monkeypatch): # Wrap `compiled_sdfg.CompiledSDFG.fast_call` with mock object mock_fast_call = unittest.mock.MagicMock() gt4py_fast_call = ( - gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram.fast_call + gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.fast_call ) def mocked_fast_call(self): @@ -99,21 +99,21 @@ def mocked_fast_call(self): return fast_call_result monkeypatch.setattr( - gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram, + gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram, "fast_call", mocked_fast_call, ) # Wrap `compiled_sdfg.CompiledSDFG.construct_arguments` with mock object mock_construct_arguments = unittest.mock.MagicMock() - gt4py_construct_arguments = gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram.construct_arguments + gt4py_construct_arguments = gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.construct_arguments def mocked_construct_arguments(self, *args, **kwargs): mock_construct_arguments.__call__(*args, **kwargs) return gt4py_construct_arguments(self, *args, **kwargs) monkeypatch.setattr( - gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram, + gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram, "construct_arguments", 
mocked_construct_arguments, ) From 7e9fa08e2d38f49827476a424e51fa0f824406f1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:54:25 +0200 Subject: [PATCH 19/24] cleanup --- .../codegens_tests/gtfn_tests/test_gtfn_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index d027c9dcb1..3d50fbaf52 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -178,7 +178,7 @@ def testee(a: cases.IJKField) -> cases.IJKField: # first call: this generates the cache file cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) - # clearing the OTFCompileWorkflow cache such that the OTFCompileWorkflow step is executed again + # clearing the OTFBuildWorkflow cache such that the OTFBuildWorkflow step is executed again object.__setattr__(cartesian_case.backend.executor, "cache", {}) # second call: the cache file is used cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) From 8f35cb7794813ea154a2f4fc6135873453c400dd Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 22:03:13 +0200 Subject: [PATCH 20/24] Build->Compile, materialize->load --- src/gt4py/next/backend.py | 4 ++-- src/gt4py/next/otf/compilation/compiler.py | 22 +++++++++---------- src/gt4py/next/otf/definitions.py | 12 +++++----- src/gt4py/next/otf/recipes.py | 8 +++---- src/gt4py/next/otf/stages.py | 8 +++---- .../runners/dace/program.py | 2 +- .../runners/dace/workflow/__init__.py | 2 +- .../runners/dace/workflow/compilation.py | 20 ++++++++--------- .../runners/dace/workflow/decoration.py | 2 +- .../runners/dace/workflow/factory.py | 2 +- 
.../next/program_processors/runners/gtfn.py | 16 +++++++------- .../program_processors/runners/roundtrip.py | 2 +- .../otf_tests/test_nanobind_build.py | 4 ++-- .../compilation_tests/test_compiler.py | 4 ++-- .../otf_tests/test_compiled_program.py | 4 ++-- .../gtfn_tests/test_gtfn_module.py | 2 +- .../dace_tests/test_dace_compilation.py | 4 ++-- 17 files changed, 60 insertions(+), 58 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index b7ad2b2d2c..ae599ece6d 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -147,7 +147,7 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]: @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): name: str - executor: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] + executor: workflow.Workflow[definitions.CompilableProgramDef, stages.CompilationArtifact] allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] transforms: workflow.Workflow[definitions.ConcreteProgramDef, definitions.CompilableProgramDef] @@ -157,7 +157,7 @@ def compile( artifact = self.executor( self.transforms(definitions.ConcreteProgramDef(data=program, args=compile_time_args)) ) - return artifact.materialize() + return artifact.load() @property def __gt_allocator__( diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 25ce0aba09..4c0b2681aa 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -42,14 +42,14 @@ def __call__( @dataclasses.dataclass(frozen=True) -class CPPBuildArtifact(gtx_utils.MetadataBasedPickling): +class CPPCompilationArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a CPP-style compilation: a Python extension module. Bindings are baked into the .so (e.g. 
via nanobind), so the default - :meth:`materialize` is just an ``importlib`` import + entry-point lookup, + :meth:`load` is just an ``importlib`` import + entry-point lookup, returning the raw imported callable. Backends that need to wrap the callable in a calling convention (e.g. GTFN's gt4py-shaped argument - conversion) subclass and override :meth:`materialize`. + conversion) subclass and override :meth:`load`. """ src_dir: pathlib.Path @@ -57,7 +57,7 @@ class CPPBuildArtifact(gtx_utils.MetadataBasedPickling): entry_point_name: str device_type: core_defs.DeviceType - def materialize(self) -> stages.ExecutableProgram: + def load(self) -> stages.ExecutableProgram: """Import the .so and return the raw entry point. Must run in the process that will call the returned program: the @@ -75,18 +75,18 @@ def materialize(self) -> stages.ExecutableProgram: class CPPCompiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - CPPBuildArtifact, + CPPCompilationArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - CPPBuildArtifact, + CPPCompilationArtifact, ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Drive a CPP-style build system and wrap the result in a :class:`CPPBuildArtifact`. + """Drive a CPP-style build system and wrap the result in a :class:`CPPCompilationArtifact`. Backends that need a different artifact subclass (e.g. with a wrapped - ``materialize``) subclass and override :meth:`_make_artifact`. + ``load``) subclass and override :meth:`_make_artifact`. """ cache_lifetime: config.BuildCacheLifetime @@ -97,7 +97,7 @@ class CPPCompiler( def __call__( self, inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - ) -> CPPBuildArtifact: + ) -> CPPCompilationArtifact: src_dir = cache.get_cache_folder(inp, self.cache_lifetime) # If we are compiling the same program at the same time (e.g. 
multiple MPI ranks), @@ -119,8 +119,8 @@ def __call__( def _make_artifact( self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str - ) -> CPPBuildArtifact: - return CPPBuildArtifact( + ) -> CPPCompilationArtifact: + return CPPCompilationArtifact( src_dir=src_dir, module=module, entry_point_name=entry_point_name, diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index b5d6a0ecfa..6b33465949 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -56,16 +56,18 @@ def __call__( class CompilationStep( - workflow.Workflow[stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact], + workflow.Workflow[ + stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.CompilationArtifact + ], Protocol[CodeSpecT, TargetCodeSpecT], ): - """Run the build system and produce a :class:`stages.BuildArtifact`. + """Run the build system and produce a :class:`stages.CompilationArtifact`. Each backend defines its own concrete artifact dataclass (frozen, - picklable, self-materializing); they all satisfy the - :class:`stages.BuildArtifact` Protocol structurally. + picklable, with a :meth:`stages.CompilationArtifact.load` method); they all + satisfy the :class:`stages.CompilationArtifact` Protocol structurally. """ def __call__( self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT] - ) -> stages.BuildArtifact: ... + ) -> stages.CompilationArtifact: ... 
diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index f784a20a12..13a626926d 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -14,10 +14,10 @@ @dataclasses.dataclass(frozen=True) -class OTFBuildWorkflow( - workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.BuildArtifact] +class OTFCompileWorkflow( + workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.CompilationArtifact] ): - """Translation + bindings + build system; ends at a :class:`stages.BuildArtifact`. + """Translation + bindings + build system; ends at a :class:`stages.CompilationArtifact`. Used as :attr:`gt4py.next.backend.Backend.executor`. The ``cached=True`` backend trait wraps it in a :class:`workflow.CachedStep` keyed on @@ -26,4 +26,4 @@ class OTFBuildWorkflow( translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] - compilation: workflow.Workflow[stages.CompilableProject, stages.BuildArtifact] + compilation: workflow.Workflow[stages.CompilableProject, stages.CompilationArtifact] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 4a735a76aa..27ee8b45a6 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -129,17 +129,17 @@ def build(self) -> None: ... ExecutableProgram: TypeAlias = Callable -class BuildArtifact(Protocol): - """The output of an :class:`recipes.OTFBuildWorkflow`. +class CompilationArtifact(Protocol): + """The output of an :class:`recipes.OTFCompileWorkflow`. Each backend defines its own concrete artifact dataclass; all share this Protocol. Implementations are frozen dataclasses, picklable, and have no - live process-bound state — that is reconstructed by :meth:`materialize`, + live process-bound state — that is reconstructed by :meth:`load`, which returns a directly-callable :class:`ExecutableProgram` taking gt4py-shaped arguments. 
""" - def materialize(self) -> ExecutableProgram: ... + def load(self) -> ExecutableProgram: ... def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index e1ddbee455..f8c8fd84a3 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -77,7 +77,7 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: ) compile_workflow = typing.cast( - recipes.OTFBuildWorkflow, + recipes.OTFCompileWorkflow, self.backend.executor if not hasattr(self.backend.executor, "step") else self.backend.executor.step, diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py index f822709cd2..4d825c0c9b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py @@ -10,7 +10,7 @@ The main module is `backend`, that exports the backends for CPU and GPU devices. The `backend` module uses `factory` to define a workflow that implements the -`OTFBuildWorkflow` recipe. The different stages are implemeted in separate modules: +`OTFCompileWorkflow` recipe. 
The different stages are implemeted in separate modules: - `translation` for lowering of GTIR to SDFG and applying SDFG transformations - `compilation` for compiling the SDFG into a program - `decoration` to parse the program arguments and pass them to the program call diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 2a70628ae8..ae3fee1540 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -120,7 +120,7 @@ def __call__(self, **kwargs: Any) -> None: @dataclasses.dataclass(frozen=True) -class DaCeBuildArtifact(gtx_utils.MetadataBasedPickling): +class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a DaCe compilation: a build folder + the SDFG bindings.""" build_folder: pathlib.Path @@ -130,8 +130,8 @@ class DaCeBuildArtifact(gtx_utils.MetadataBasedPickling): device_type: core_defs.DeviceType # Process-local cache of the live :class:`CompiledDaceProgram`. Populated by - # ``DaCeCompiler`` after a fresh compile so :meth:`materialize` can skip the - # SDFG re-deserialize + .so re-link round-trip in the same process. Marked + # ``DaCeCompiler`` after a fresh compile so :meth:`load` can skip the SDFG + # re-deserialize + .so re-link round-trip in the same process. Marked # ``pickle=False`` via :func:`gtx_utils.gt4py_metadata` so a receiver of the # artifact in a different process sees ``None`` and falls back to the # disk-based path. @@ -143,7 +143,7 @@ class DaCeBuildArtifact(gtx_utils.MetadataBasedPickling): metadata=gtx_utils.gt4py_metadata(pickle=False), ) - def materialize(self) -> stages.ExecutableProgram: + def load(self) -> stages.ExecutableProgram: """Wrap the compiled program in gt4py's calling convention. 
Uses the live program cached on the artifact when available; otherwise @@ -172,15 +172,15 @@ def _load_compiled_program(self) -> CompiledDaceProgram: class DaCeCompiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - DaCeBuildArtifact, + DaCeCompilationArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - DaCeBuildArtifact, + DaCeCompilationArtifact, ], definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], ): - """Run the DaCe build system and produce an on-disk :class:`DaCeBuildArtifact`.""" + """Run the DaCe build system and produce an on-disk :class:`DaCeCompilationArtifact`.""" bind_func_name: str cache_lifetime: config.BuildCacheLifetime @@ -190,7 +190,7 @@ class DaCeCompiler( def __call__( self, inp: stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - ) -> DaCeBuildArtifact: + ) -> DaCeCompilationArtifact: with gtx_wfdcommon.dace_context( device_type=self.device_type, cmake_build_type=self.cmake_build_type, @@ -201,7 +201,7 @@ def __call__( sdfg = dace.SDFG.from_json(inp.program_source.source_code) sdfg.build_folder = str(sdfg_build_folder) with locking.lock(sdfg_build_folder): - # Keep the program handle so the artifact's materialize() can + # Keep the program handle so the artifact's load() can # skip the SDFG re-deserialize + .so re-link round-trip when # used in this same process. 
sdfg_program = sdfg.compile(validate=False) @@ -216,7 +216,7 @@ def __call__( ) assert inp.binding_source is not None - artifact = DaCeBuildArtifact( + artifact = DaCeCompilationArtifact( build_folder=sdfg_build_folder, sdfg_dump=sdfg_dump, binding_source_code=inp.binding_source.source_code, diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 27eb57a82b..8e520856ac 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: # Type-only: evaluating ``compilation`` at module load would create a cycle - # (compilation imports this module for the materialize body). + # (compilation imports this module for the load body). from gt4py.next.program_processors.runners.dace.workflow.compilation import ( CompiledDaceProgram, ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 9f6c80fd07..069854a586 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -30,7 +30,7 @@ class DaCeWorkflowFactory(factory.Factory): class Meta: - model = recipes.OTFBuildWorkflow + model = recipes.OTFCompileWorkflow class Params: auto_optimize: bool = False diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index fd652c1ee8..6a8ff1fc69 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -108,17 +108,17 @@ def extract_connectivity_args( @dataclasses.dataclass(frozen=True) -class GTFNBuildArtifact(compiler.CPPBuildArtifact): - def materialize(self) -> stages.ExecutableProgram: - return convert_args(super().materialize(), 
device=self.device_type) +class GTFNCompilationArtifact(compiler.CPPCompilationArtifact): + def load(self) -> stages.ExecutableProgram: + return convert_args(super().load(), device=self.device_type) @dataclasses.dataclass(frozen=True) class GTFNCompiler(compiler.CPPCompiler): def _make_artifact( self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str - ) -> GTFNBuildArtifact: - return GTFNBuildArtifact( + ) -> GTFNCompilationArtifact: + return GTFNCompilationArtifact( src_dir=src_dir, module=module, entry_point_name=entry_point_name, @@ -131,9 +131,9 @@ class Meta: model = GTFNCompiler -class GTFNBuildWorkflowFactory(factory.Factory): +class GTFNCompileWorkflowFactory(factory.Factory): class Meta: - model = recipes.OTFBuildWorkflow + model = recipes.OTFCompileWorkflow class Params: device_type: core_defs.DeviceType = core_defs.DeviceType.CPU @@ -195,7 +195,7 @@ class Params: device_type = core_defs.DeviceType.CPU hash_function = stages.compilation_hash otf_workflow = factory.SubFactory( - GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") + GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) name = factory.LazyAttribute( diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 7d7075157e..f076018571 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -223,7 +223,7 @@ class RoundtripArtifact: program: stages.ExecutableProgram - def materialize(self) -> stages.ExecutableProgram: + def load(self) -> stages.ExecutableProgram: return self.program diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 77db222a11..84226c4e03 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ 
b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -31,7 +31,7 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source).materialize() + compiled_program = build_the_program(example_program_source).load() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), @@ -51,7 +51,7 @@ def test_gtfn_cpp_with_compiledb(program_source_with_name): device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source).materialize() + compiled_program = build_the_program(example_program_source).load() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py index 806ea94c93..42a6687699 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -"""Minimal contract tests for :class:`compiler.CPPBuildArtifact`.""" +"""Minimal contract tests for :class:`compiler.CPPCompilationArtifact`.""" import pathlib import pickle @@ -16,7 +16,7 @@ def test_cpp_build_artifact_pickle_round_trip(): - artifact = compiler.CPPBuildArtifact( + artifact = compiler.CPPCompilationArtifact( src_dir=pathlib.Path("/tmp/build"), module=pathlib.Path("entry.so"), entry_point_name="entry", diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 233e5a2f6e..ed881c9495 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -116,9 +116,9 @@ def test_inlining_of_scalar_works_integration(testee_prog): @dataclasses.dataclass(frozen=True) class _NoOpArtifact: - """A trivial BuildArtifact that materializes to a no-op callable.""" + """A trivial CompilationArtifact that loads to a no-op callable.""" - def materialize(self): + def load(self): return lambda *args, **kwargs: None def pirate(program: toolchain.ConcreteArtifact): diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 3d50fbaf52..d027c9dcb1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -178,7 +178,7 @@ def testee(a: cases.IJKField) -> cases.IJKField: # first call: this generates the cache file cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) - # clearing the OTFBuildWorkflow cache such that the OTFBuildWorkflow step is executed again + # clearing the OTFCompileWorkflow cache such that the OTFCompileWorkflow step 
is executed again object.__setattr__(cartesian_case.backend.executor, "cache", {}) # second call: the cache file is used cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index 1149f3e131..ec1a926e4a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -"""Minimal contract tests for :class:`compilation.DaCeBuildArtifact`.""" +"""Minimal contract tests for :class:`compilation.DaCeCompilationArtifact`.""" import pathlib import pickle @@ -20,7 +20,7 @@ def test_dace_build_artifact_pickle_round_trip_drops_live_program(): - artifact = compilation.DaCeBuildArtifact( + artifact = compilation.DaCeCompilationArtifact( build_folder=pathlib.Path("/tmp/build"), sdfg_dump=pathlib.Path("/tmp/build/program.sdfgz"), binding_source_code="def update_sdfg_args(*a, **k): ...", From 29508bf421dfd799f1e9174681358b6198bc5139 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 22:09:18 +0200 Subject: [PATCH 21/24] cleanup --- src/gt4py/next/otf/compilation/compiler.py | 12 +++------ src/gt4py/next/otf/recipes.py | 7 +----- .../runners/dace/workflow/compilation.py | 25 ++++++++----------- .../runners/dace/workflow/decoration.py | 3 +-- 4 files changed, 17 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 4c0b2681aa..8f5da88b77 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -45,11 +45,8 @@ def __call__( class 
CPPCompilationArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a CPP-style compilation: a Python extension module. - Bindings are baked into the .so (e.g. via nanobind), so the default - :meth:`load` is just an ``importlib`` import + entry-point lookup, - returning the raw imported callable. Backends that need to wrap the - callable in a calling convention (e.g. GTFN's gt4py-shaped argument - conversion) subclass and override :meth:`load`. + The default :meth:`load` is an ``importlib`` import + entry-point lookup; + backends override to apply their own calling convention. """ src_dir: pathlib.Path @@ -83,10 +80,9 @@ class CPPCompiler( ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Drive a CPP-style build system and wrap the result in a :class:`CPPCompilationArtifact`. + """Drive a CPP-style build system into a :class:`CPPCompilationArtifact`. - Backends that need a different artifact subclass (e.g. with a wrapped - ``load``) subclass and override :meth:`_make_artifact`. + Backends override :meth:`_make_artifact` to use their own artifact subclass. """ cache_lifetime: config.BuildCacheLifetime diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 13a626926d..0b809e4731 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -17,12 +17,7 @@ class OTFCompileWorkflow( workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.CompilationArtifact] ): - """Translation + bindings + build system; ends at a :class:`stages.CompilationArtifact`. - - Used as :attr:`gt4py.next.backend.Backend.executor`. The ``cached=True`` - backend trait wraps it in a :class:`workflow.CachedStep` keyed on - :class:`definitions.CompilableProgramDef`. 
- """ + """The typical compiled backend steps composed into a workflow.""" translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index ae3fee1540..2734a67161 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -13,7 +13,7 @@ import pathlib import warnings from collections.abc import Callable, MutableSequence, Sequence -from typing import Any, Optional +from typing import Any import dace import factory @@ -130,12 +130,11 @@ class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): device_type: core_defs.DeviceType # Process-local cache of the live :class:`CompiledDaceProgram`. Populated by - # ``DaCeCompiler`` after a fresh compile so :meth:`load` can skip the SDFG - # re-deserialize + .so re-link round-trip in the same process. Marked - # ``pickle=False`` via :func:`gtx_utils.gt4py_metadata` so a receiver of the - # artifact in a different process sees ``None`` and falls back to the - # disk-based path. - _live_program: Optional[CompiledDaceProgram] = dataclasses.field( + # ``DaCeCompiler`` to skip the disk round-trip when the artifact stays in + # the same process. Excluded from pickle (``pickle=False`` metadata) so + # receivers in other processes see ``None`` and fall through to the + # disk-based load. + _live_program: CompiledDaceProgram | None = dataclasses.field( init=False, default=None, compare=False, @@ -146,10 +145,9 @@ class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): def load(self) -> stages.ExecutableProgram: """Wrap the compiled program in gt4py's calling convention. 
- Uses the live program cached on the artifact when available; otherwise - re-deserializes the SDFG, re-links the .so via ``compiler.use_cache``, - and caches the result for subsequent calls. Must run in the process - that will call the returned program. + On a miss, re-deserializes the SDFG and re-links the .so via + ``compiler.use_cache``. Must run in the process that will call the + returned program. """ program = self._live_program if program is None: @@ -201,9 +199,8 @@ def __call__( sdfg = dace.SDFG.from_json(inp.program_source.source_code) sdfg.build_folder = str(sdfg_build_folder) with locking.lock(sdfg_build_folder): - # Keep the program handle so the artifact's load() can - # skip the SDFG re-deserialize + .so re-link round-trip when - # used in this same process. + # Keep the handle so the artifact's load() can skip the disk + # round-trip in the same process. sdfg_program = sdfg.compile(validate=False) for dump_name in ("program.sdfgz", "program.sdfg"): diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 8e520856ac..07707b1f1a 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -22,8 +22,7 @@ if TYPE_CHECKING: - # Type-only: evaluating ``compilation`` at module load would create a cycle - # (compilation imports this module for the load body). + # Type-only: a top-level import would cycle with ``compilation``. 
from gt4py.next.program_processors.runners.dace.workflow.compilation import ( CompiledDaceProgram, ) From 9f4e776e6f1bfa5dd733e8fe3fdf6715a29a4e1e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 28 Apr 2026 08:23:17 +0200 Subject: [PATCH 22/24] use tmp_path fixture --- .../program_processors/runners/dace/workflow/decoration.py | 4 +--- .../unit_tests/otf_tests/compilation_tests/test_compiler.py | 4 ++-- .../runners_tests/dace_tests/test_dace_compilation.py | 6 +++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 07707b1f1a..f9e9f7181b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -23,9 +23,7 @@ if TYPE_CHECKING: # Type-only: a top-level import would cycle with ``compilation``. - from gt4py.next.program_processors.runners.dace.workflow.compilation import ( - CompiledDaceProgram, - ) + from gt4py.next.program_processors.runners.dace.workflow.compilation import CompiledDaceProgram def convert_args( diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py index 42a6687699..7dbaeaf719 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py @@ -15,9 +15,9 @@ from gt4py.next.otf.compilation import compiler -def test_cpp_build_artifact_pickle_round_trip(): +def test_cpp_compilation_artifact_pickle_round_trip(tmp_path: pathlib.Path): artifact = compiler.CPPCompilationArtifact( - src_dir=pathlib.Path("/tmp/build"), + src_dir=tmp_path, module=pathlib.Path("entry.so"), entry_point_name="entry", device_type=core_defs.DeviceType.CPU, diff --git 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index ec1a926e4a..acb4ea24a8 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -19,10 +19,10 @@ from gt4py.next.program_processors.runners.dace.workflow import compilation # noqa: E402 -def test_dace_build_artifact_pickle_round_trip_drops_live_program(): +def test_dace_compilation_artifact_pickle_round_trip_drops_live_program(tmp_path: pathlib.Path): artifact = compilation.DaCeCompilationArtifact( - build_folder=pathlib.Path("/tmp/build"), - sdfg_dump=pathlib.Path("/tmp/build/program.sdfgz"), + build_folder=tmp_path, + sdfg_dump=tmp_path / "program.sdfgz", binding_source_code="def update_sdfg_args(*a, **k): ...", bind_func_name="update_sdfg_args", device_type=core_defs.DeviceType.CPU, From 9c9234d82eb453dd6aa9b0144046195f5cd1ac41 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 28 Apr 2026 08:36:12 +0200 Subject: [PATCH 23/24] refactor roundtrip to respect picklability --- .../program_processors/runners/roundtrip.py | 161 +++++++++--------- 1 file changed, 85 insertions(+), 76 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index f076018571..396eecc173 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -11,11 +11,10 @@ import dataclasses import functools import importlib.util -import pathlib import tempfile import textwrap -import typing -from collections.abc import Callable, Iterable +import types +from collections.abc import Iterable from typing import Any, Optional from gt4py.eve import codegen @@ -106,28
+105,20 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: return f"{node.id} = {_create_tmp(axes, origin, shape, node.dtype)}" -_FENCIL_CACHE: dict[int, Callable] = {} +# Caches the generated source by IR hash so re-codegen is skipped within a process. +_SOURCE_CACHE: dict[int, tuple[str, str]] = {} +# Caches the loaded module by source string so re-exec is skipped within a process. +_MODULE_CACHE: dict[str, types.ModuleType] = {} -def fencil_generator( +def _generate_source( ir: itir.Program, debug: bool, use_embedded: bool, offset_provider: common.OffsetProvider, transforms: itir_transforms.GTIRTransform, -) -> stages.ExecutableProgram: - """ - Generate a directly executable fencil from an ITIR node. - - Arguments: - ir: The iterator IR (ITIR) node. - debug: Keep module source containing fencil implementation. - extract_temporaries: Extract intermediate field values into temporaries. - use_embedded: Directly use builtins from embedded backend instead of - generic dispatcher. Gives faster performance and is easier - to debug. - offset_provider: A mapping from offset names to offset providers. - """ +) -> tuple[str, str]: + """Generate the Python source for an ITIR program. 
Returns ``(source_code, entry_point_name)``.""" # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism cache_key = hash( @@ -139,10 +130,10 @@ def fencil_generator( tuple(common.offset_provider_to_type(offset_provider).items()), ) ) - if cache_key in _FENCIL_CACHE: + if cache_key in _SOURCE_CACHE: if debug: - print(f"Using cached fencil for key {cache_key}") - return _FENCIL_CACHE[cache_key] # A CompiledProgram is just a Callable + print(f"Using cached source for key {cache_key}") + return _SOURCE_CACHE[cache_key] ir = transforms(ir, offset_provider=offset_provider) @@ -178,53 +169,84 @@ def fencil_generator( """ ) - with tempfile.NamedTemporaryFile( - mode="w", suffix=".py", encoding="utf-8", delete=False - ) as source_file: - source_file_name = source_file.name - if debug: - print(source_file_name) - offset_literals = [f'{o} = offset("{o}")' for o in offset_literals] - axis_literals = [ - f'{o.value} = gtx.Dimension("{o.value}", kind=gtx.DimensionKind("{o.kind}"))' - for o in axis_literals_set - ] - source_file.write(header) - source_file.write("\n".join(offset_literals)) - source_file.write("\n") - source_file.write("\n".join(axis_literals)) - source_file.write("\n") - source_file.write(program) - try: - spec = importlib.util.spec_from_file_location("module.name", source_file_name) - mod = importlib.util.module_from_spec(spec) # type: ignore - spec.loader.exec_module(mod) # type: ignore - finally: - if not debug: - pathlib.Path(source_file_name).unlink(missing_ok=True) + offset_literals_src = "\n".join(f'{o} = offset("{o}")' for o in offset_literals) + axis_literals_src = "\n".join( + f'{o.value} = gtx.Dimension("{o.value}", kind=gtx.DimensionKind("{o.kind}"))' + for o in axis_literals_set + ) + source_code = f"{header}{offset_literals_src}\n{axis_literals_src}\n{program}" assert isinstance(ir, itir.Program) - fencil_name = ir.id - fencil = getattr(mod, fencil_name) + entry_point_name = ir.id + + 
_SOURCE_CACHE[cache_key] = (source_code, entry_point_name) + return source_code, entry_point_name - _FENCIL_CACHE[cache_key] = fencil - return typing.cast(stages.ExecutableProgram, fencil) +def _load_module(source_code: str, debug: bool) -> types.ModuleType: + if source_code in _MODULE_CACHE: + return _MODULE_CACHE[source_code] + + if debug: + # Write to a real .py so debuggers/tracebacks have file/line info. + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", encoding="utf-8", delete=False + ) as source_file: + source_file.write(source_code) + source_file_name = source_file.name + print(source_file_name) + spec = importlib.util.spec_from_file_location("module.name", source_file_name) + mod = importlib.util.module_from_spec(spec) # type: ignore[arg-type] + spec.loader.exec_module(mod) # type: ignore[union-attr] + else: + mod = types.ModuleType("roundtrip_module") + exec(compile(source_code, "", "exec"), mod.__dict__) + + _MODULE_CACHE[source_code] = mod + return mod @dataclasses.dataclass(frozen=True) class RoundtripArtifact: - """In-memory artifact for the roundtrip backend. + """Source-string artifact for the roundtrip backend. - Roundtrip generates a Python module per program and executes it directly, - so its output is a live callable rather than something on disk. Not - picklable — roundtrip is in-process only. + The generated Python source is the artifact: picklable, re-execed on + :meth:`load`. When ``debug`` is true, ``load`` writes a temporary ``.py`` + so debuggers/tracebacks resolve to source lines. 
""" - program: stages.ExecutableProgram + source_code: str + entry_point_name: str + column_axis: common.Dimension | None + dispatch_backend: next_backend.Backend | None + debug: bool def load(self) -> stages.ExecutableProgram: - return self.program + mod = _load_module(self.source_code, self.debug) + fencil = getattr(mod, self.entry_point_name) + captured_column_axis = self.column_axis + dispatch_backend = self.dispatch_backend + + def decorated_fencil( + *args: Any, + offset_provider: dict[str, common.Connectivity | common.Dimension], + out: Any = None, + column_axis: Optional[ + common.Dimension + ] = None, # TODO(tehrengruber): unused, kept for signature compat + **kwargs: Any, + ) -> None: + if out is not None: + args = (*args, out) + fencil( + *args, + offset_provider=offset_provider, + backend=dispatch_backend, + column_axis=captured_column_axis, + **kwargs, + ) + + return decorated_fencil @dataclasses.dataclass(frozen=True) @@ -237,7 +259,7 @@ class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, RoundtripArt def __call__(self, inp: definitions.CompilableProgramDef) -> RoundtripArtifact: debug = config.DEBUG if self.debug is None else self.debug - fencil = fencil_generator( + source_code, entry_point_name = _generate_source( inp.data, offset_provider=inp.args.offset_provider, debug=debug, @@ -245,26 +267,13 @@ def __call__(self, inp: definitions.CompilableProgramDef) -> RoundtripArtifact: transforms=self.transforms, ) - def decorated_fencil( - *args: Any, - offset_provider: dict[str, common.Connectivity | common.Dimension], - out: Any = None, - column_axis: Optional[common.Dimension] = None, - **kwargs: Any, - ) -> None: - if out is not None: - args = (*args, out) - if not column_axis: # TODO(tehrengruber): This variable is never used. Bug? 
- column_axis = inp.args.column_axis - fencil( - *args, - offset_provider=offset_provider, - backend=self.dispatch_backend, - column_axis=inp.args.column_axis, - **kwargs, - ) - - return RoundtripArtifact(program=decorated_fencil) + return RoundtripArtifact( + source_code=source_code, + entry_point_name=entry_point_name, + column_axis=inp.args.column_axis, + dispatch_backend=self.dispatch_backend, + debug=debug, + ) # TODO(tehrengruber): introduce factory From edd9e9372ff90f6fd13f7017f72b35167c1787db Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 28 Apr 2026 12:38:49 +0200 Subject: [PATCH 24/24] sdfg as part of artifact (because of upcoming change) --- .../runners/dace/workflow/compilation.py | 46 ++++++++++--------- .../dace_tests/test_dace_compilation.py | 2 +- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 2734a67161..1f69f1ad71 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -9,6 +9,7 @@ from __future__ import annotations import dataclasses +import json import os import pathlib import warnings @@ -16,6 +17,7 @@ from typing import Any import dace +import dace.codegen.compiler as dace_compiler import factory from gt4py._core import definitions as core_defs, locking @@ -121,10 +123,16 @@ def __call__(self, **kwargs: Any) -> None: @dataclasses.dataclass(frozen=True) class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): - """On-disk result of a DaCe compilation: a build folder + the SDFG bindings.""" + """Result of a DaCe compilation: build folder + SDFG bindings + the SDFG itself. 
+ + The SDFG is carried inline as JSON because dace's load path + (:func:`get_program_handle`) needs an SDFG instance to wrap into the + returned :class:`CompiledSDFG`, and the build folder may not contain a + ``program.sdfg(z)`` dump under the upcoming minimal-build-dir mode. + """ build_folder: pathlib.Path - sdfg_dump: pathlib.Path + sdfg_json: str binding_source_code: str bind_func_name: str device_type: core_defs.DeviceType @@ -145,9 +153,10 @@ class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): def load(self) -> stages.ExecutableProgram: """Wrap the compiled program in gt4py's calling convention. - On a miss, re-deserializes the SDFG and re-links the .so via - ``compiler.use_cache``. Must run in the process that will call the - returned program. + On a miss, loads the precompiled .so directly via + :func:`dace.codegen.compiler.get_program_handle` — no recompilation, + no ``dace.config`` re-entry. Must run in the process that will call + the returned program. """ program = self._live_program if program is None: @@ -156,13 +165,15 @@ def load(self) -> stages.ExecutableProgram: return gtx_wfddecoration.convert_args(program, device=self.device_type) def _load_compiled_program(self) -> CompiledDaceProgram: - sdfg = dace.SDFG.from_file(str(self.sdfg_dump)) - sdfg.build_folder = str(self.build_folder) - - with gtx_wfdcommon.dace_context(device_type=self.device_type): - with dace.config.set_temporary("compiler", "use_cache", value=True): - sdfg_program = sdfg.compile(validate=False) - + # TODO(phimuell): Drop ``sdfg_json`` from the artifact once dace + # exposes a load path that doesn't require an SDFG instance to wrap + # into the returned ``CompiledSDFG``. 
+ sdfg = dace.SDFG.from_json(json.loads(self.sdfg_json)) + folder_version = dace_compiler.get_folder_version(self.build_folder) + library_path = dace_compiler.get_binary_name( + self.build_folder, sdfg_name=sdfg.name, folder_version=folder_version + ) + sdfg_program = dace_compiler.get_program_handle(library_path, sdfg) return CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) @@ -203,19 +214,10 @@ def __call__( # round-trip in the same process. sdfg_program = sdfg.compile(validate=False) - for dump_name in ("program.sdfgz", "program.sdfg"): - sdfg_dump = sdfg_build_folder / dump_name - if sdfg_dump.exists(): - break - else: - raise RuntimeError( - f"No SDFG dump (program.sdfgz / program.sdfg) found in '{sdfg_build_folder}'." - ) - assert inp.binding_source is not None artifact = DaCeCompilationArtifact( build_folder=sdfg_build_folder, - sdfg_dump=sdfg_dump, + sdfg_json=json.dumps(inp.program_source.source_code), binding_source_code=inp.binding_source.source_code, bind_func_name=self.bind_func_name, device_type=self.device_type, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index acb4ea24a8..29d0ded9e1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -22,7 +22,7 @@ def test_dace_compilation_artifact_pickle_round_trip_drops_live_program(tmp_path: pathlib.Path): artifact = compilation.DaCeCompilationArtifact( build_folder=tmp_path, - sdfg_dump=tmp_path / "program.sdfgz", + sdfg_json="{}", binding_source_code="def update_sdfg_args(*a, **k): ...", bind_func_name="update_sdfg_args", device_type=core_defs.DeviceType.CPU,