From 6a35549d8948e5d607d65d637fdf50a2b287536b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 12:59:46 +0200 Subject: [PATCH 01/35] refactor[next]: otf split build and load --- src/gt4py/next/otf/compilation/compiler.py | 29 +++++--- src/gt4py/next/otf/definitions.py | 6 +- src/gt4py/next/otf/recipes.py | 27 +++++++- src/gt4py/next/otf/stages.py | 10 +++ .../program_processors/formatters/gtfn.py | 2 +- .../runners/dace/program.py | 18 ++--- .../runners/dace/workflow/backend.py | 22 +++--- .../runners/dace/workflow/compilation.py | 69 ++++++++++++++++--- .../runners/dace/workflow/factory.py | 47 ++++++++++++- .../next/program_processors/runners/gtfn.py | 47 +++++++++++-- .../test_temporaries_with_sizes.py | 14 ++-- .../iterator_tests/test_builtins.py | 6 +- .../gtfn_tests/test_gtfn_module.py | 4 +- .../runners_tests/test_gtfn.py | 21 +++--- 14 files changed, 254 insertions(+), 68 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 3748d95192..8fa999bb3c 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -48,15 +48,15 @@ def __call__( class Compiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.ExecutableProgram, + stages.BuildArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.ExecutableProgram, + stages.BuildArtifact, ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" + """Use any build system (via configured factory) to compile a GT4Py program into an on-disk ``BuildArtifact``.""" cache_lifetime: config.BuildCacheLifetime builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] @@ -65,7 +65,7 @@ class Compiler( def __call__( self, inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - ) -> stages.ExecutableProgram: + ) -> stages.BuildArtifact: src_dir = cache.get_cache_folder(inp, self.cache_lifetime) # If we are compiling the same program at the same time (e.g. multiple MPI ranks), @@ -83,12 +83,25 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) - m = importer.import_from_path( - src_dir / new_data.module, sys_modules_prefix="gt4py.__compiled_programs__." + return stages.BuildArtifact( + src_dir=src_dir, + module=new_data.module, + entry_point_name=new_data.entry_point_name, ) - func = getattr(m, new_data.entry_point_name) - return func + +def load_artifact(artifact: stages.BuildArtifact) -> stages.ExecutableProgram: + """Dynamically import a previously-built module and return its entry point. + + Must run in the process that will ultimately call the returned program, since + the module is registered in that process's ``sys.modules`` under the + ``gt4py.__compiled_programs__.`` prefix. + """ + m = importer.import_from_path( + artifact.src_dir / artifact.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return getattr(m, artifact.entry_point_name) class CompilerFactory(factory.Factory): diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index 11b42dc6ce..9e4f7dc586 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -57,12 +57,12 @@ def __call__( class CompilationStep( workflow.Workflow[ - stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.ExecutableProgram + stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact ], Protocol[CodeSpecT, TargetCodeSpecT], ): - """Compile program source code and bindings into a python callable (CompilableSource -> CompiledProgram).""" + """Run the build system and produce an on-disk artifact (CompilableSource -> BuildArtifact).""" def __call__( self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT] - ) -> stages.ExecutableProgram: ... + ) -> stages.BuildArtifact: ... diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 79cd17162b..573c3581fe 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -14,10 +14,31 @@ @dataclasses.dataclass(frozen=True) -class OTFCompileWorkflow(workflow.NamedStepSequence): - """The typical compiled backend steps composed into a workflow.""" +class OTFBuildWorkflow( + workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.BuildArtifact] +): + """Translation + bindings + build system; ends at an on-disk :class:`stages.BuildArtifact`.""" translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] - compilation: workflow.Workflow[stages.CompilableProject, stages.ExecutableProgram] + compilation: workflow.Workflow[stages.CompilableProject, stages.BuildArtifact] + + +@dataclasses.dataclass(frozen=True) +class OTFFinalizeWorkflow( + workflow.NamedStepSequence[stages.BuildArtifact, stages.ExecutableProgram] +): + """Import the built module and apply decoration to get a live callable.""" + + load: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] decoration: workflow.Workflow[stages.ExecutableProgram, stages.ExecutableProgram] + + +@dataclasses.dataclass(frozen=True) +class OTFCompileWorkflow( + workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.ExecutableProgram] +): + """Full OTF pipeline: the ``build`` phase ends at a picklable artifact, ``finalize`` rehydrates it.""" + + build: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] + finalize: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index b6816b1cc3..a0a6c6216e 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -9,6 +9,7 @@ from __future__ import annotations import dataclasses +import pathlib from collections.abc import Callable from typing import Generic, Optional, Protocol, TypeAlias, TypeVar @@ -129,6 +130,15 @@ def build(self) -> None: ... ExecutableProgram: TypeAlias = Callable +@dataclasses.dataclass(frozen=True) +class BuildArtifact: + """On-disk result of a compilation: everything a later step needs to import it.""" + + src_dir: pathlib.Path + module: pathlib.Path + entry_point_name: str + + def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: """ Filter out multiple occurrences of the same ``interface.LibraryDependency``. diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index 1d65b8d8d0..c20f7a8555 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -17,7 +17,7 @@ @program_formatter.program_formatter def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str: # TODO(tehrengruber): This is a little ugly. Revisit. - gtfn_translation = gtfn.GTFNBackendFactory().executor.translation # type: ignore[attr-defined] + gtfn_translation = gtfn.GTFNBackendFactory().executor.build.translation # type: ignore[attr-defined] assert isinstance(gtfn_translation, GTFNTranslationStep) return gtfn_translation.generate_stencil_source( program, diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index f8c8fd84a3..c13daa249f 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -76,16 +76,16 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: gt4py_program_args=[p.type for p in program.params], ) - compile_workflow = typing.cast( - recipes.OTFCompileWorkflow, - self.backend.executor - if not hasattr(self.backend.executor, "step") - else self.backend.executor.step, - ) # We know which backend we are using, but we don't know if the compile workflow is cached. + compile_workflow = typing.cast(recipes.OTFCompileWorkflow, self.backend.executor) + build_workflow = ( + compile_workflow.build.step + if hasattr(compile_workflow.build, "step") + else compile_workflow.build + ) # the `build` phase may be wrapped in a `CachedStep` depending on backend configuration. compile_workflow_translation = ( - compile_workflow.translation - if not hasattr(compile_workflow.translation, "step") - else compile_workflow.translation.step + build_workflow.translation.step + if hasattr(build_workflow.translation, "step") + else build_workflow.translation ) # Same for the translation stage, which could be a `CachedStep` depending on backend configuration. # TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index de6778a750..935655a422 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -8,6 +8,7 @@ from __future__ import annotations +import dataclasses import warnings from typing import Any, Final @@ -44,7 +45,12 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + lambda o: dataclasses.replace( + o.otf_workflow, + build=workflow.CachedStep( + o.otf_workflow.build, hash_function=o.hash_function + ), + ) ), name_cached="_cached", ) @@ -127,13 +133,13 @@ def make_dace_backend( gpu=gpu, cached=cached, auto_optimize=auto_optimize, - otf_workflow__cached_translation=cached, - otf_workflow__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False), - otf_workflow__bare_translation__auto_optimize_args=optimization_args, - otf_workflow__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, - otf_workflow__bare_translation__use_metrics=use_metrics, - otf_workflow__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin, - otf_workflow__bare_translation__use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, + otf_workflow__build__cached_translation=cached, + otf_workflow__build__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False), + otf_workflow__build__bare_translation__auto_optimize_args=optimization_args, + otf_workflow__build__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, + otf_workflow__build__bare_translation__use_metrics=use_metrics, + otf_workflow__build__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin, + otf_workflow__build__bare_translation__use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index e1747b7ac3..dcbe73454f 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -10,6 +10,8 @@ import dataclasses import os +import pathlib +import types import warnings from collections.abc import Callable, MutableSequence, Sequence from typing import Any @@ -114,19 +116,28 @@ def __call__(self, **kwargs: Any) -> None: assert result is None +@dataclasses.dataclass(frozen=True) +class DaCeBuildArtifact: + """On-disk result of a DaCe compilation.""" + + build_folder: pathlib.Path + binding_source_code: str + bind_func_name: str + + @dataclasses.dataclass(frozen=True) class DaCeCompiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - CompiledDaceProgram, + DaCeBuildArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - CompiledDaceProgram, + DaCeBuildArtifact, ], definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], ): - """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" + """Run the DaCe build system and produce an on-disk :class:`DaCeBuildArtifact`.""" bind_func_name: str cache_lifetime: config.BuildCacheLifetime @@ -136,7 +147,7 @@ class DaCeCompiler( def __call__( self, inp: stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - ) -> CompiledDaceProgram: + ) -> DaCeBuildArtifact: with gtx_wfdcommon.dace_context( device_type=self.device_type, cmake_build_type=self.cmake_build_type, @@ -147,16 +158,56 @@ def __call__( sdfg = dace.SDFG.from_json(inp.program_source.source_code) sdfg.build_folder = sdfg_build_folder with locking.lock(sdfg_build_folder): - sdfg_program = sdfg.compile(validate=False) + sdfg.compile(validate=False, return_program_handle=False) assert inp.binding_source is not None - return CompiledDaceProgram( - sdfg_program, - self.bind_func_name, - inp.binding_source, + return DaCeBuildArtifact( + build_folder=pathlib.Path(sdfg_build_folder), + binding_source_code=inp.binding_source.source_code, + bind_func_name=self.bind_func_name, ) +@dataclasses.dataclass(frozen=True) +class DaCeLoader( + workflow.ChainableWorkflowMixin[DaCeBuildArtifact, CompiledDaceProgram], + workflow.ReplaceEnabledWorkflowMixin[DaCeBuildArtifact, CompiledDaceProgram], +): + """Rehydrate a :class:`DaCeBuildArtifact` into a live :class:`CompiledDaceProgram`.""" + + device_type: core_defs.DeviceType + cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG + + def __call__(self, artifact: DaCeBuildArtifact) -> CompiledDaceProgram: + for dump_name in ("program.sdfgz", "program.sdfg"): + sdfg_dump = artifact.build_folder / dump_name + if sdfg_dump.exists(): + break + else: + raise RuntimeError( + f"No SDFG dump (program.sdfgz / program.sdfg) found in '{artifact.build_folder}'." + ) + + sdfg = dace.SDFG.from_file(str(sdfg_dump)) + sdfg.build_folder = str(artifact.build_folder) + + with gtx_wfdcommon.dace_context( + device_type=self.device_type, + cmake_build_type=self.cmake_build_type, + ): + # use_cache=True forces DaCe to load the existing .so without re-codegen. + with dace.config.set_temporary("compiler", "use_cache", value=True): + sdfg_program = sdfg.compile(validate=False) + + binding_source_shim = types.SimpleNamespace(source_code=artifact.binding_source_code) + return CompiledDaceProgram(sdfg_program, artifact.bind_func_name, binding_source_shim) # type: ignore[arg-type] + + class DaCeCompilationStepFactory(factory.Factory): class Meta: model = DaCeCompiler + + +class DaCeLoaderFactory(factory.Factory): + class Meta: + model = DaCeLoader diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 62febd0965..12441587c5 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -22,6 +22,7 @@ ) from gt4py.next.program_processors.runners.dace.workflow.compilation import ( DaCeCompilationStepFactory, + DaCeLoaderFactory, ) from gt4py.next.program_processors.runners.dace.workflow.translation import ( DaCeTranslationStepFactory, @@ -31,9 +32,9 @@ _GT_DACE_BINDING_FUNCTION_NAME: Final[str] = "update_sdfg_args" -class DaCeWorkflowFactory(factory.Factory): +class DaCeBuildWorkflowFactory(factory.Factory): class Meta: - model = recipes.OTFCompileWorkflow + model = recipes.OTFBuildWorkflow class Params: auto_optimize: bool = False @@ -72,9 +73,51 @@ class Params: device_type=factory.SelfAttribute("..device_type"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) + + +class DaCeFinalizeWorkflowFactory(factory.Factory): + class Meta: + model = recipes.OTFFinalizeWorkflow + + class Params: + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough + lambda: config.CMAKE_BUILD_TYPE + ) + + load = factory.SubFactory( + DaCeLoaderFactory, + device_type=factory.SelfAttribute("..device_type"), + cmake_build_type=factory.SelfAttribute("..cmake_build_type"), + ) decoration = factory.LazyAttribute( lambda o: functools.partial( decoration_step.convert_args, device=o.device_type, ) ) + + +class DaCeWorkflowFactory(factory.Factory): + class Meta: + model = recipes.OTFCompileWorkflow + + class Params: + auto_optimize: bool = False + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough + lambda: config.CMAKE_BUILD_TYPE + ) + cached_translation = factory.Trait(build__cached_translation=True) + + build = factory.SubFactory( + DaCeBuildWorkflowFactory, + device_type=factory.SelfAttribute("..device_type"), + auto_optimize=factory.SelfAttribute("..auto_optimize"), + cmake_build_type=factory.SelfAttribute("..cmake_build_type"), + ) + finalize = factory.SubFactory( + DaCeFinalizeWorkflowFactory, + device_type=factory.SelfAttribute("..device_type"), + cmake_build_type=factory.SelfAttribute("..cmake_build_type"), + ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c1743dea6a..6a600d5b5f 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import dataclasses import functools from typing import Any @@ -106,9 +107,9 @@ def extract_connectivity_args( return args -class GTFNCompileWorkflowFactory(factory.Factory): +class GTFNBuildWorkflowFactory(factory.Factory): class Meta: - model = recipes.OTFCompileWorkflow + model = recipes.OTFBuildWorkflow class Params: device_type: core_defs.DeviceType = core_defs.DeviceType.CPU @@ -144,11 +145,37 @@ class Params: cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), ) + + +class GTFNFinalizeWorkflowFactory(factory.Factory): + class Meta: + model = recipes.OTFFinalizeWorkflow + + class Params: + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + + load = factory.LazyFunction(lambda: compiler.load_artifact) decoration = factory.LazyAttribute( lambda o: functools.partial(convert_args, device=o.device_type) ) +class GTFNCompileWorkflowFactory(factory.Factory): + class Meta: + model = recipes.OTFCompileWorkflow + + class Params: + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + cached_translation = factory.Trait(build__cached_translation=True) + + build = factory.SubFactory( + GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") + ) + finalize = factory.SubFactory( + GTFNFinalizeWorkflowFactory, device_type=factory.SelfAttribute("..device_type") + ) + + class GTFNBackendFactory(factory.Factory): class Meta: model = backend.Backend @@ -165,7 +192,12 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + lambda o: dataclasses.replace( + o.otf_workflow, + build=workflow.CachedStep( + o.otf_workflow.build, hash_function=o.hash_function + ), + ) ), name_cached="_cached", ) @@ -187,17 +219,18 @@ class Params: run_gtfn = GTFNBackendFactory() run_gtfn_imperative = GTFNBackendFactory( - name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True + name_postfix="_imperative", + otf_workflow__build__translation__use_imperative_backend=True, ) -run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) +run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__build__cached_translation=True) run_gtfn_gpu = GTFNBackendFactory(gpu=True) run_gtfn_gpu_cached = GTFNBackendFactory( - gpu=True, cached=True, otf_workflow__cached_translation=True + gpu=True, cached=True, otf_workflow__build__cached_translation=True ) run_gtfn_no_transforms = GTFNBackendFactory( - otf_workflow__bare_translation__enable_itir_transforms=False + otf_workflow__build__bare_translation__enable_itir_transforms=False ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 90c0d775f2..c3058f33a8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -34,12 +34,14 @@ def exec_alloc_descriptor(): name="run_gtfn_with_temporaries_and_sizes", transforms=backend.DEFAULT_TRANSFORMS, executor=run_gtfn.executor.replace( - translation=run_gtfn.executor.translation.replace( - symbolic_domain_sizes={ - "Cell": "num_cells", - "Edge": "num_edges", - "Vertex": "num_vertices", - } + build=run_gtfn.executor.build.replace( + translation=run_gtfn.executor.build.translation.replace( + symbolic_domain_sizes={ + "Cell": "num_cells", + "Edge": "num_edges", + "Vertex": "num_vertices", + } + ) ) ), allocator=run_gtfn.allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 4fff5192aa..9d39b5d63a 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -192,7 +192,11 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): gtfn_without_transforms = dataclasses.replace( run_gtfn, executor=run_gtfn.executor.replace( - translation=run_gtfn.executor.translation.replace(enable_itir_transforms=False), + build=run_gtfn.executor.build.replace( + translation=run_gtfn.executor.build.translation.replace( + enable_itir_transforms=False + ), + ), ), # avoid inlining the function ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index d027c9dcb1..3611abad89 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -135,11 +135,11 @@ def test_gtfn_file_cache(program_example): ) cached_gtfn_translation_step = gtfn.GTFNBackendFactory( gpu=False, cached=True, otf_workflow__cached_translation=True - ).executor.step.translation + ).executor.build.step.translation bare_gtfn_translation_step = gtfn.GTFNBackendFactory( gpu=False, cached=True, otf_workflow__cached_translation=False - ).executor.step.translation + ).executor.build.step.translation cache_key = stages.fingerprint_compilable_program(compilable_program) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index 96d8c6e27c..f088761ffd 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -34,11 +34,11 @@ def test_backend_factory_trait_device(): assert cpu_version.name == "run_gtfn_cpu" assert gpu_version.name == "run_gtfn_gpu" - assert cpu_version.executor.translation.device_type is core_defs.DeviceType.CPU - assert gpu_version.executor.translation.device_type is core_defs.DeviceType.CUDA + assert cpu_version.executor.build.translation.device_type is core_defs.DeviceType.CPU + assert gpu_version.executor.build.translation.device_type is core_defs.DeviceType.CUDA - assert cpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CPU - assert gpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CUDA + assert cpu_version.executor.finalize.decoration.keywords["device"] is core_defs.DeviceType.CPU + assert gpu_version.executor.finalize.decoration.keywords["device"] is core_defs.DeviceType.CUDA assert custom_layout_allocators.is_field_allocator_for( cpu_version.allocator, core_defs.DeviceType.CPU @@ -50,7 +50,7 @@ def test_backend_factory_trait_device(): def test_backend_factory_trait_cached(): cached_version = gtfn.GTFNBackendFactory(gpu=False, cached=True) - assert isinstance(cached_version.executor, workflow.CachedStep) + assert isinstance(cached_version.executor.build, workflow.CachedStep) assert cached_version.name == "run_gtfn_cpu_cached" @@ -60,9 +60,12 @@ def test_backend_factory_build_cache_config(monkeypatch): monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.PERSISTENT) persistent_version = gtfn.GTFNBackendFactory() - assert session_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.SESSION assert ( - persistent_version.executor.compilation.cache_lifetime + session_version.executor.build.compilation.cache_lifetime + is config.BuildCacheLifetime.SESSION + ) + assert ( + persistent_version.executor.build.compilation.cache_lifetime is config.BuildCacheLifetime.PERSISTENT ) @@ -74,10 +77,10 @@ def test_backend_factory_build_type_config(monkeypatch): min_size_version = gtfn.GTFNBackendFactory() assert ( - release_version.executor.compilation.builder_factory.cmake_build_type + release_version.executor.build.compilation.builder_factory.cmake_build_type is config.CMakeBuildType.RELEASE ) assert ( - min_size_version.executor.compilation.builder_factory.cmake_build_type + min_size_version.executor.build.compilation.builder_factory.cmake_build_type is config.CMakeBuildType.MIN_SIZE_REL ) From d2510573866e50dba2597b22d72379e7b1653824 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 15:04:56 +0200 Subject: [PATCH 02/35] refactorings to combine load+decoration --- src/gt4py/next/otf/compilation/compiler.py | 45 +++++----- src/gt4py/next/otf/definitions.py | 16 ++-- src/gt4py/next/otf/recipes.py | 51 +++++++---- src/gt4py/next/otf/stages.py | 10 --- .../runners/dace/workflow/compilation.py | 89 ++++++++++--------- .../runners/dace/workflow/factory.py | 32 +------ .../next/program_processors/runners/gtfn.py | 39 ++++---- .../runners_tests/test_gtfn.py | 5 +- 8 files changed, 142 insertions(+), 145 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 8fa999bb3c..9d77d50b07 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -14,10 +14,10 @@ import factory -from gt4py._core import locking +from gt4py._core import definitions as core_defs, locking from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow -from gt4py.next.otf.compilation import build_data, cache, importer +from gt4py.next.otf.compilation import build_data, cache T = TypeVar("T") @@ -44,28 +44,44 @@ def __call__( ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... +@dataclasses.dataclass(frozen=True) +class GTFNBuildArtifact: + """On-disk result of a GTFN compilation: a Python extension module. + + Bindings are baked into the .so via nanobind, so the load step is just an + ``importlib`` import + entry-point symbol lookup. ``device_type`` is + intrinsic to the artifact: a CPU-built .so cannot be loaded as GPU. + """ + + src_dir: pathlib.Path + module: pathlib.Path + entry_point_name: str + device_type: core_defs.DeviceType + + @dataclasses.dataclass(frozen=True) class Compiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.BuildArtifact, + GTFNBuildArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.BuildArtifact, + GTFNBuildArtifact, ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Use any build system (via configured factory) to compile a GT4Py program into an on-disk ``BuildArtifact``.""" + """Use any build system (via configured factory) to compile a GT4Py program into a :class:`GTFNBuildArtifact`.""" cache_lifetime: config.BuildCacheLifetime builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + device_type: core_defs.DeviceType force_recompile: bool = False def __call__( self, inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - ) -> stages.BuildArtifact: + ) -> GTFNBuildArtifact: src_dir = cache.get_cache_folder(inp, self.cache_lifetime) # If we are compiling the same program at the same time (e.g. multiple MPI ranks), @@ -83,27 +99,14 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) - return stages.BuildArtifact( + return GTFNBuildArtifact( src_dir=src_dir, module=new_data.module, entry_point_name=new_data.entry_point_name, + device_type=self.device_type, ) -def load_artifact(artifact: stages.BuildArtifact) -> stages.ExecutableProgram: - """Dynamically import a previously-built module and return its entry point. - - Must run in the process that will ultimately call the returned program, since - the module is registered in that process's ``sys.modules`` under the - ``gt4py.__compiled_programs__.`` prefix. - """ - m = importer.import_from_path( - artifact.src_dir / artifact.module, - sys_modules_prefix="gt4py.__compiled_programs__.", - ) - return getattr(m, artifact.entry_point_name) - - class CompilerFactory(factory.Factory): class Meta: model = Compiler diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index 9e4f7dc586..c242e02aa2 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Protocol, TypeAlias, TypeVar +from typing import Any, Protocol, TypeAlias, TypeVar from gt4py.next.ffront import stages as ffront_stages from gt4py.next.iterator import ir as itir @@ -56,13 +56,17 @@ def __call__( class CompilationStep( - workflow.Workflow[ - stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact - ], + workflow.Workflow[stages.CompilableProject[CodeSpecT, TargetCodeSpecT], Any], Protocol[CodeSpecT, TargetCodeSpecT], ): - """Run the build system and produce an on-disk artifact (CompilableSource -> BuildArtifact).""" + """Run the build system and produce an on-disk, backend-specific build artifact. + + The artifact type is intentionally :class:`Any` here — each backend defines + its own concrete dataclass (frozen, picklable). The build/finalize boundary + in :class:`recipes.OTFCompileWorkflow` only requires that whatever + ``CompilationStep`` produces is what the backend's ``finalize`` consumes. + """ def __call__( self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT] - ) -> stages.BuildArtifact: ... + ) -> Any: ... diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 573c3581fe..0ded8426aa 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -9,36 +9,53 @@ from __future__ import annotations import dataclasses +from typing import Any from gt4py.next.otf import definitions, stages, workflow @dataclasses.dataclass(frozen=True) class OTFBuildWorkflow( - workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.BuildArtifact] + workflow.NamedStepSequence[definitions.CompilableProgramDef, Any] ): - """Translation + bindings + build system; ends at an on-disk :class:`stages.BuildArtifact`.""" + """Translation + bindings + build system; ends at an on-disk artifact. - translation: definitions.TranslationStep - bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] - compilation: workflow.Workflow[stages.CompilableProject, stages.BuildArtifact] + The artifact type is backend-specific (e.g. a ``GTFNBuildArtifact`` or a + ``DaCeBuildArtifact``); a workflow only ever pairs a backend's build with + that same backend's finalize, so no cross-backend artifact protocol is + needed. + Grouped as a sub-workflow so the ``cached=True`` backend trait can wrap + just this sub-workflow in a :class:`workflow.CachedStep` — caching keys + on :class:`definitions.CompilableProgramDef` and values on a picklable, + backend-specific artifact dataclass. + """ -@dataclasses.dataclass(frozen=True) -class OTFFinalizeWorkflow( - workflow.NamedStepSequence[stages.BuildArtifact, stages.ExecutableProgram] -): - """Import the built module and apply decoration to get a live callable.""" - - load: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] - decoration: workflow.Workflow[stages.ExecutableProgram, stages.ExecutableProgram] + translation: definitions.TranslationStep + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] + compilation: workflow.Workflow[stages.CompilableProject, Any] @dataclasses.dataclass(frozen=True) class OTFCompileWorkflow( workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.ExecutableProgram] ): - """Full OTF pipeline: the ``build`` phase ends at a picklable artifact, ``finalize`` rehydrates it.""" - - build: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] - finalize: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] + """Full OTF pipeline: two phases separated by an on-disk artifact boundary. + + 1. ``build`` — produces a picklable, backend-specific build artifact. + Heavy, idempotent, parallelizable across processes; the natural cache + target. + 2. ``finalize`` — rehydrates the artifact into a directly-callable + :class:`stages.ExecutableProgram`. Backend-internal; whatever + sequence of "load the .so / wrap with gt4py calling convention / + attach metrics" the backend needs. + + The artifact dataclass is the contract between these two phases. By + convention, artifacts are frozen dataclasses, picklable across process + boundaries, and self-describing (carry every property finalize needs, + e.g. ``device_type``). Each backend defines its own; nothing about that + contract is enforced by this module — it is per-backend convention. + """ + + build: workflow.Workflow[definitions.CompilableProgramDef, Any] + finalize: workflow.Workflow[Any, stages.ExecutableProgram] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index a0a6c6216e..b6816b1cc3 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -9,7 +9,6 @@ from __future__ import annotations import dataclasses -import pathlib from collections.abc import Callable from typing import Generic, Optional, Protocol, TypeAlias, TypeVar @@ -130,15 +129,6 @@ def build(self) -> None: ... ExecutableProgram: TypeAlias = Callable -@dataclasses.dataclass(frozen=True) -class BuildArtifact: - """On-disk result of a compilation: everything a later step needs to import it.""" - - src_dir: pathlib.Path - module: pathlib.Path - entry_point_name: str - - def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: """ Filter out multiple occurrences of the same ``interface.LibraryDependency``. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index dcbe73454f..3215d44c10 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -11,7 +11,6 @@ import dataclasses import os import pathlib -import types import warnings from collections.abc import Callable, MutableSequence, Sequence from typing import Any @@ -57,7 +56,7 @@ def __init__( self, program: dace.CompiledSDFG, bind_func_name: str, - binding_source: stages.BindingSource[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], + binding_source_code: str, ): self.sdfg_program = program @@ -66,9 +65,10 @@ def __init__( # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. self.sdfg_argtypes = list(program.sdfg.arglist().values()) - # Note that `binding_source` contains Python code tailored to this specific SDFG. - # Here we dinamically compile this function and add it to the compiled program. - exec(binding_source.source_code, global_namespace := {}) # type: ignore[var-annotated] + # The binding source code is Python tailored to this specific SDFG. + # We dynamically compile that function and add it to the compiled program. + global_namespace: dict[str, Any] = {} + exec(binding_source_code, global_namespace) self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] # For debug purpose, we set a unique module name on the compiled function. self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) @@ -118,11 +118,18 @@ def __call__(self, **kwargs: Any) -> None: @dataclasses.dataclass(frozen=True) class DaCeBuildArtifact: - """On-disk result of a DaCe compilation.""" + """On-disk result of a DaCe compilation. + + Carries the ``device_type`` the artifact was built for; a CPU-built .so + cannot be loaded as GPU. Also carries the bindings (Python source code + that the loader ``exec``\\ s to materialize the SDFG argument-marshalling + function). + """ build_folder: pathlib.Path binding_source_code: str bind_func_name: str + device_type: core_defs.DeviceType @dataclasses.dataclass(frozen=True) @@ -165,49 +172,49 @@ def __call__( build_folder=pathlib.Path(sdfg_build_folder), binding_source_code=inp.binding_source.source_code, bind_func_name=self.bind_func_name, + device_type=self.device_type, ) -@dataclasses.dataclass(frozen=True) -class DaCeLoader( - workflow.ChainableWorkflowMixin[DaCeBuildArtifact, CompiledDaceProgram], - workflow.ReplaceEnabledWorkflowMixin[DaCeBuildArtifact, CompiledDaceProgram], -): - """Rehydrate a :class:`DaCeBuildArtifact` into a live :class:`CompiledDaceProgram`.""" - - device_type: core_defs.DeviceType - cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG - - def __call__(self, artifact: DaCeBuildArtifact) -> CompiledDaceProgram: - for dump_name in ("program.sdfgz", "program.sdfg"): - sdfg_dump = artifact.build_folder / dump_name - if sdfg_dump.exists(): - break - else: - raise RuntimeError( - f"No SDFG dump (program.sdfgz / program.sdfg) found in '{artifact.build_folder}'." - ) +def dace_finalize(artifact: DaCeBuildArtifact) -> stages.ExecutableProgram: + """Turn a :class:`DaCeBuildArtifact` into a directly-callable program. + + Re-deserializes the SDFG dump from the build folder, links against the + pre-built .so via ``compiler.use_cache=True`` (no re-codegen), wraps it + in a :class:`CompiledDaceProgram`, and applies gt4py's calling convention + via :func:`decoration.convert_args`. Reads the target device from the + artifact. + + Must run in the process that will ultimately call the returned program. + """ + # Local import to avoid a circular reference (decoration imports compilation). + from gt4py.next.program_processors.runners.dace.workflow import ( + decoration as gtx_wfddecoration, + ) + + for dump_name in ("program.sdfgz", "program.sdfg"): + sdfg_dump = artifact.build_folder / dump_name + if sdfg_dump.exists(): + break + else: + raise RuntimeError( + f"No SDFG dump (program.sdfgz / program.sdfg) found in '{artifact.build_folder}'." + ) - sdfg = dace.SDFG.from_file(str(sdfg_dump)) - sdfg.build_folder = str(artifact.build_folder) + sdfg = dace.SDFG.from_file(str(sdfg_dump)) + sdfg.build_folder = str(artifact.build_folder) - with gtx_wfdcommon.dace_context( - device_type=self.device_type, - cmake_build_type=self.cmake_build_type, - ): - # use_cache=True forces DaCe to load the existing .so without re-codegen. - with dace.config.set_temporary("compiler", "use_cache", value=True): - sdfg_program = sdfg.compile(validate=False) + with gtx_wfdcommon.dace_context(device_type=artifact.device_type): + # use_cache=True forces DaCe to load the existing .so without re-codegen. + with dace.config.set_temporary("compiler", "use_cache", value=True): + sdfg_program = sdfg.compile(validate=False) - binding_source_shim = types.SimpleNamespace(source_code=artifact.binding_source_code) - return CompiledDaceProgram(sdfg_program, artifact.bind_func_name, binding_source_shim) # type: ignore[arg-type] + program = CompiledDaceProgram( + sdfg_program, artifact.bind_func_name, artifact.binding_source_code + ) + return gtx_wfddecoration.convert_args(program, device=artifact.device_type) class DaCeCompilationStepFactory(factory.Factory): class Meta: model = DaCeCompiler - - -class DaCeLoaderFactory(factory.Factory): - class Meta: - model = DaCeLoader diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 12441587c5..5855ef5cc4 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -18,11 +18,10 @@ from gt4py.next.otf import recipes, stages, workflow from gt4py.next.program_processors.runners.dace.workflow import ( bindings as bindings_step, - decoration as decoration_step, + compilation as compilation_step, ) from gt4py.next.program_processors.runners.dace.workflow.compilation import ( DaCeCompilationStepFactory, - DaCeLoaderFactory, ) from gt4py.next.program_processors.runners.dace.workflow.translation import ( DaCeTranslationStepFactory, @@ -75,29 +74,6 @@ class Params: ) -class DaCeFinalizeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFFinalizeWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough - lambda: config.CMAKE_BUILD_TYPE - ) - - load = factory.SubFactory( - DaCeLoaderFactory, - device_type=factory.SelfAttribute("..device_type"), - cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - decoration_step.convert_args, - device=o.device_type, - ) - ) - - class DaCeWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFCompileWorkflow @@ -116,8 +92,4 @@ class Params: auto_optimize=factory.SelfAttribute("..auto_optimize"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) - finalize = factory.SubFactory( - DaCeFinalizeWorkflowFactory, - device_type=factory.SelfAttribute("..device_type"), - cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) + finalize = factory.LazyFunction(lambda: compilation_step.dace_finalize) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 6a600d5b5f..01610ec0c4 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -import functools from typing import Any import factory @@ -21,7 +20,7 @@ from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import compiler +from gt4py.next.otf.compilation import compiler, importer from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module @@ -107,6 +106,24 @@ def extract_connectivity_args( return args +def gtfn_finalize(artifact: compiler.GTFNBuildArtifact) -> stages.ExecutableProgram: + """Turn a :class:`compiler.GTFNBuildArtifact` into a directly-callable program. + + Imports the .so as a Python extension module and wraps the entry point in + gt4py's calling convention (argument conversion, device-aware connectivity + handling, metric collection). Reads the target device from the artifact. + + Must run in the process that will ultimately call the returned program; + the module is registered in that process's ``sys.modules`` under the + ``gt4py.__compiled_programs__.`` prefix. + """ + m = importer.import_from_path( + artifact.src_dir / artifact.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return convert_args(getattr(m, artifact.entry_point_name), device=artifact.device_type) + + class GTFNBuildWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFBuildWorkflow @@ -144,19 +161,7 @@ class Params: compiler.CompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), - ) - - -class GTFNFinalizeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFFinalizeWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - - load = factory.LazyFunction(lambda: compiler.load_artifact) - decoration = factory.LazyAttribute( - lambda o: functools.partial(convert_args, device=o.device_type) + device_type=factory.SelfAttribute("..device_type"), ) @@ -171,9 +176,7 @@ class Params: build = factory.SubFactory( GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) - finalize = factory.SubFactory( - GTFNFinalizeWorkflowFactory, device_type=factory.SelfAttribute("..device_type") - ) + finalize = factory.LazyFunction(lambda: gtfn_finalize) class GTFNBackendFactory(factory.Factory): diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index f088761ffd..cd3bddb19a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -37,8 +37,9 @@ def test_backend_factory_trait_device(): assert cpu_version.executor.build.translation.device_type is core_defs.DeviceType.CPU assert gpu_version.executor.build.translation.device_type is core_defs.DeviceType.CUDA - assert cpu_version.executor.finalize.decoration.keywords["device"] is core_defs.DeviceType.CPU - assert gpu_version.executor.finalize.decoration.keywords["device"] is core_defs.DeviceType.CUDA + # The compilation step now also carries device_type so it can stamp the artifact. + assert cpu_version.executor.build.compilation.device_type is core_defs.DeviceType.CPU + assert gpu_version.executor.build.compilation.device_type is core_defs.DeviceType.CUDA assert custom_layout_allocators.is_field_allocator_for( cpu_version.allocator, core_defs.DeviceType.CPU From b8efa44840f24ab2a7ba7d057371d8bdc926b1d6 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 15:28:14 +0200 Subject: [PATCH 03/35] BuildArtifact can materialize() itself --- src/gt4py/next/otf/compilation/compiler.py | 27 +++++- src/gt4py/next/otf/definitions.py | 17 ++-- src/gt4py/next/otf/recipes.py | 60 ++++++++------ src/gt4py/next/otf/stages.py | 35 ++++++++ .../runners/dace/workflow/compilation.py | 83 ++++++++++--------- .../runners/dace/workflow/factory.py | 5 +- .../next/program_processors/runners/gtfn.py | 24 +----- 7 files changed, 151 insertions(+), 100 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 9d77d50b07..f1b076b55c 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -17,7 +17,7 @@ from gt4py._core import definitions as core_defs, locking from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow -from gt4py.next.otf.compilation import build_data, cache +from gt4py.next.otf.compilation import build_data, cache, importer T = TypeVar("T") @@ -48,9 +48,10 @@ def __call__( class GTFNBuildArtifact: """On-disk result of a GTFN compilation: a Python extension module. - Bindings are baked into the .so via nanobind, so the load step is just an - ``importlib`` import + entry-point symbol lookup. ``device_type`` is - intrinsic to the artifact: a CPU-built .so cannot be loaded as GPU. + Bindings are baked into the .so via nanobind, so materialization is just + an ``importlib`` import + entry-point symbol lookup, plus a wrapping in + gt4py's calling convention. ``device_type`` is intrinsic to the artifact: + a CPU-built .so cannot be loaded as GPU. """ src_dir: pathlib.Path @@ -58,6 +59,24 @@ class GTFNBuildArtifact: entry_point_name: str device_type: core_defs.DeviceType + def materialize(self) -> stages.ExecutableProgram: + """Bring the artifact up as a directly-callable program. + + Must run in the process that will ultimately call the returned + program; the imported module is registered in that process's + ``sys.modules`` under the ``gt4py.__compiled_programs__.`` prefix. + """ + # Imported lazily to avoid a circular module dependency: ``runners.gtfn`` + # imports this module to construct the workflow, while the + # gt4py-shaped argument-conversion lives there. + from gt4py.next.program_processors.runners.gtfn import convert_args + + m = importer.import_from_path( + self.src_dir / self.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return convert_args(getattr(m, self.entry_point_name), device=self.device_type) + @dataclasses.dataclass(frozen=True) class Compiler( diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index c242e02aa2..1fe56a1f11 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Protocol, TypeAlias, TypeVar +from typing import Protocol, TypeAlias, TypeVar from gt4py.next.ffront import stages as ffront_stages from gt4py.next.iterator import ir as itir @@ -56,17 +56,18 @@ def __call__( class CompilationStep( - workflow.Workflow[stages.CompilableProject[CodeSpecT, TargetCodeSpecT], Any], + workflow.Workflow[ + stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact + ], Protocol[CodeSpecT, TargetCodeSpecT], ): - """Run the build system and produce an on-disk, backend-specific build artifact. + """Run the build system and produce a :class:`stages.BuildArtifact`. - The artifact type is intentionally :class:`Any` here — each backend defines - its own concrete dataclass (frozen, picklable). The build/finalize boundary - in :class:`recipes.OTFCompileWorkflow` only requires that whatever - ``CompilationStep`` produces is what the backend's ``finalize`` consumes. + Each backend defines its own concrete artifact dataclass (frozen, + picklable, self-materializing); they all satisfy the + :class:`stages.BuildArtifact` Protocol structurally. """ def __call__( self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT] - ) -> Any: ... + ) -> stages.BuildArtifact: ... diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 0ded8426aa..88188aebbd 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -9,53 +9,63 @@ from __future__ import annotations import dataclasses -from typing import Any from gt4py.next.otf import definitions, stages, workflow +def materialize_artifact(artifact: stages.BuildArtifact) -> stages.ExecutableProgram: + """Default ``finalize`` step for :class:`OTFCompileWorkflow`. + + Universal across backends: dispatches into the artifact's own + :meth:`stages.BuildArtifact.materialize` method. The dispatch happens + through ordinary Python method resolution on the artifact's concrete + type — no separate registry, no backend-specific finalize plumbing. + """ + return artifact.materialize() + + @dataclasses.dataclass(frozen=True) class OTFBuildWorkflow( - workflow.NamedStepSequence[definitions.CompilableProgramDef, Any] + workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.BuildArtifact] ): - """Translation + bindings + build system; ends at an on-disk artifact. + """Translation + bindings + build system; ends at a :class:`stages.BuildArtifact`. - The artifact type is backend-specific (e.g. a ``GTFNBuildArtifact`` or a - ``DaCeBuildArtifact``); a workflow only ever pairs a backend's build with - that same backend's finalize, so no cross-backend artifact protocol is - needed. + The artifact's concrete type is backend-specific (e.g. ``GTFNBuildArtifact`` + or ``DaCeBuildArtifact``); both share only the + :class:`stages.BuildArtifact` Protocol — frozen, picklable, self- + materializing. Grouped as a sub-workflow so the ``cached=True`` backend trait can wrap just this sub-workflow in a :class:`workflow.CachedStep` — caching keys - on :class:`definitions.CompilableProgramDef` and values on a picklable, - backend-specific artifact dataclass. + on :class:`definitions.CompilableProgramDef` and values on a picklable + artifact. """ translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] - compilation: workflow.Workflow[stages.CompilableProject, Any] + compilation: workflow.Workflow[stages.CompilableProject, stages.BuildArtifact] @dataclasses.dataclass(frozen=True) class OTFCompileWorkflow( workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.ExecutableProgram] ): - """Full OTF pipeline: two phases separated by an on-disk artifact boundary. + """Full OTF pipeline: build an artifact, then materialize a callable. - 1. ``build`` — produces a picklable, backend-specific build artifact. - Heavy, idempotent, parallelizable across processes; the natural cache - target. + 1. ``build`` — produces a picklable :class:`stages.BuildArtifact`. Heavy, + idempotent, parallelizable; the natural cache target. 2. ``finalize`` — rehydrates the artifact into a directly-callable - :class:`stages.ExecutableProgram`. Backend-internal; whatever - sequence of "load the .so / wrap with gt4py calling convention / - attach metrics" the backend needs. - - The artifact dataclass is the contract between these two phases. By - convention, artifacts are frozen dataclasses, picklable across process - boundaries, and self-describing (carry every property finalize needs, - e.g. ``device_type``). Each backend defines its own; nothing about that - contract is enforced by this module — it is per-backend convention. + :class:`stages.ExecutableProgram`. Defaults to + :func:`materialize_artifact`, which dispatches through the artifact's + own :meth:`stages.BuildArtifact.materialize` — backend-specific code + lives on the artifact, not in a sibling free function. + + Backends typically only configure ``build``; ``finalize`` falls through + to the artifact's own materialization logic. Override ``finalize`` only + to wrap the entire post-build phase (e.g. add a tracing wrapper). """ - build: workflow.Workflow[definitions.CompilableProgramDef, Any] - finalize: workflow.Workflow[Any, stages.ExecutableProgram] + build: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] + finalize: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] = ( + materialize_artifact + ) diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index b6816b1cc3..8448db9334 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -129,6 +129,41 @@ def build(self) -> None: ... ExecutableProgram: TypeAlias = Callable +class BuildArtifact(Protocol): + """A picklable, self-contained, compiled gt4py program in transit. + + A *build artifact* is the output of the ``build`` phase of + :class:`recipes.OTFCompileWorkflow` — the explicit boundary between the + build phase (heavy, idempotent, parallelizable, picklable output) and the + live-callable phase (cheap, process-bound). + + Each backend defines its own concrete artifact dataclass, carrying + whatever fields it needs to bring up a runnable callable in any process + that has the backend module on the import path. Conventions: + + 1. **Frozen dataclass.** Implementations are + ``@dataclasses.dataclass(frozen=True)`` so they have value semantics + (hashable, structurally equatable) for use as cache keys. + + 2. **Picklable.** Implementations round-trip safely through :mod:`pickle` + so they can cross process boundaries: process-pool / distributed + compilation, AOT pipelines that build now and run later from a + different process, persistent caches keyed on the artifact. Live, + process-bound state (open files, ``ctypes`` handles, imported Python + modules) is therefore not allowed in the artifact — that is what + :meth:`materialize` rehydrates. + + 3. **Self-materializing.** Calling :meth:`materialize` returns a + directly-callable :class:`ExecutableProgram` taking gt4py-shaped + arguments. The method body is the backend's full post-build + sequence (load the .so, wrap with the calling convention, attach + metric hooks, etc.). Receivers don't need to know which backend + produced the artifact — they just call ``artifact.materialize()``. + """ + + def materialize(self) -> ExecutableProgram: ... + + def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: """ Filter out multiple occurrences of the same ``interface.LibraryDependency``. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 3215d44c10..9bbb035997 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -122,8 +122,8 @@ class DaCeBuildArtifact: Carries the ``device_type`` the artifact was built for; a CPU-built .so cannot be loaded as GPU. Also carries the bindings (Python source code - that the loader ``exec``\\ s to materialize the SDFG argument-marshalling - function). + that materialization ``exec``\\ s to bring the SDFG argument-marshalling + function into existence). """ build_folder: pathlib.Path @@ -131,6 +131,46 @@ class DaCeBuildArtifact: bind_func_name: str device_type: core_defs.DeviceType + def materialize(self) -> stages.ExecutableProgram: + """Bring the artifact up as a directly-callable program. + + Re-deserializes the SDFG dump from the build folder, links against + the pre-built .so via ``compiler.use_cache=True`` (no re-codegen), + wraps it in a :class:`CompiledDaceProgram`, and applies gt4py's + calling convention via :func:`decoration.convert_args`. + + Must run in the process that will ultimately call the returned + program; the imported binding code is bound into a per-call namespace + within the produced :class:`CompiledDaceProgram`. + """ + # Imported lazily to avoid a circular module dependency: + # ``decoration`` imports this module. + from gt4py.next.program_processors.runners.dace.workflow import ( + decoration as gtx_wfddecoration, + ) + + for dump_name in ("program.sdfgz", "program.sdfg"): + sdfg_dump = self.build_folder / dump_name + if sdfg_dump.exists(): + break + else: + raise RuntimeError( + f"No SDFG dump (program.sdfgz / program.sdfg) found in '{self.build_folder}'." + ) + + sdfg = dace.SDFG.from_file(str(sdfg_dump)) + sdfg.build_folder = str(self.build_folder) + + with gtx_wfdcommon.dace_context(device_type=self.device_type): + # use_cache=True forces DaCe to load the existing .so without re-codegen. + with dace.config.set_temporary("compiler", "use_cache", value=True): + sdfg_program = sdfg.compile(validate=False) + + program = CompiledDaceProgram( + sdfg_program, self.bind_func_name, self.binding_source_code + ) + return gtx_wfddecoration.convert_args(program, device=self.device_type) + @dataclasses.dataclass(frozen=True) class DaCeCompiler( @@ -176,45 +216,6 @@ def __call__( ) -def dace_finalize(artifact: DaCeBuildArtifact) -> stages.ExecutableProgram: - """Turn a :class:`DaCeBuildArtifact` into a directly-callable program. - - Re-deserializes the SDFG dump from the build folder, links against the - pre-built .so via ``compiler.use_cache=True`` (no re-codegen), wraps it - in a :class:`CompiledDaceProgram`, and applies gt4py's calling convention - via :func:`decoration.convert_args`. Reads the target device from the - artifact. - - Must run in the process that will ultimately call the returned program. - """ - # Local import to avoid a circular reference (decoration imports compilation). - from gt4py.next.program_processors.runners.dace.workflow import ( - decoration as gtx_wfddecoration, - ) - - for dump_name in ("program.sdfgz", "program.sdfg"): - sdfg_dump = artifact.build_folder / dump_name - if sdfg_dump.exists(): - break - else: - raise RuntimeError( - f"No SDFG dump (program.sdfgz / program.sdfg) found in '{artifact.build_folder}'." - ) - - sdfg = dace.SDFG.from_file(str(sdfg_dump)) - sdfg.build_folder = str(artifact.build_folder) - - with gtx_wfdcommon.dace_context(device_type=artifact.device_type): - # use_cache=True forces DaCe to load the existing .so without re-codegen. - with dace.config.set_temporary("compiler", "use_cache", value=True): - sdfg_program = sdfg.compile(validate=False) - - program = CompiledDaceProgram( - sdfg_program, artifact.bind_func_name, artifact.binding_source_code - ) - return gtx_wfddecoration.convert_args(program, device=artifact.device_type) - - class DaCeCompilationStepFactory(factory.Factory): class Meta: model = DaCeCompiler diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 5855ef5cc4..4eeced2341 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -18,7 +18,6 @@ from gt4py.next.otf import recipes, stages, workflow from gt4py.next.program_processors.runners.dace.workflow import ( bindings as bindings_step, - compilation as compilation_step, ) from gt4py.next.program_processors.runners.dace.workflow.compilation import ( DaCeCompilationStepFactory, @@ -92,4 +91,6 @@ class Params: auto_optimize=factory.SelfAttribute("..auto_optimize"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) - finalize = factory.LazyFunction(lambda: compilation_step.dace_finalize) + # ``finalize`` is left at its OTFCompileWorkflow default + # (``stages.materialize_artifact``), which dispatches via the artifact's + # own :meth:`stages.BuildArtifact.materialize` method. diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 01610ec0c4..80c918df34 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -20,7 +20,7 @@ from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import compiler, importer +from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module @@ -106,24 +106,6 @@ def extract_connectivity_args( return args -def gtfn_finalize(artifact: compiler.GTFNBuildArtifact) -> stages.ExecutableProgram: - """Turn a :class:`compiler.GTFNBuildArtifact` into a directly-callable program. - - Imports the .so as a Python extension module and wraps the entry point in - gt4py's calling convention (argument conversion, device-aware connectivity - handling, metric collection). Reads the target device from the artifact. - - Must run in the process that will ultimately call the returned program; - the module is registered in that process's ``sys.modules`` under the - ``gt4py.__compiled_programs__.`` prefix. - """ - m = importer.import_from_path( - artifact.src_dir / artifact.module, - sys_modules_prefix="gt4py.__compiled_programs__.", - ) - return convert_args(getattr(m, artifact.entry_point_name), device=artifact.device_type) - - class GTFNBuildWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFBuildWorkflow @@ -176,7 +158,9 @@ class Params: build = factory.SubFactory( GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) - finalize = factory.LazyFunction(lambda: gtfn_finalize) + # ``finalize`` is left at its OTFCompileWorkflow default + # (``stages.materialize_artifact``), which dispatches via the artifact's + # own :meth:`stages.BuildArtifact.materialize` method. class GTFNBackendFactory(factory.Factory): From 915ce278a2962d48c889724fbbcf4135cabbbf9e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 16:36:14 +0200 Subject: [PATCH 04/35] remove OTFCompileWorkflow --- src/gt4py/next/backend.py | 5 +- src/gt4py/next/otf/recipes.py | 49 +++---------------- src/gt4py/next/otf/stages.py | 8 +-- .../program_processors/formatters/gtfn.py | 2 +- .../runners/dace/program.py | 16 +++--- .../runners/dace/workflow/__init__.py | 2 +- .../runners/dace/workflow/backend.py | 22 +++------ .../runners/dace/workflow/factory.py | 25 +--------- .../next/program_processors/runners/gtfn.py | 34 +++---------- .../test_temporaries_with_sizes.py | 14 +++--- .../iterator_tests/test_builtins.py | 6 +-- .../otf_tests/test_compiled_program.py | 13 +++-- .../gtfn_tests/test_gtfn_module.py | 4 +- .../runners_tests/test_gtfn.py | 18 +++---- 14 files changed, 69 insertions(+), 149 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 6123da97e0..b7ad2b2d2c 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -147,16 +147,17 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]: @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): name: str - executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram] + executor: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] transforms: workflow.Workflow[definitions.ConcreteProgramDef, definitions.CompilableProgramDef] def compile( self, program: definitions.IRDefinitionT, compile_time_args: arguments.CompileTimeArgs ) -> stages.ExecutableProgram: - return self.executor( + artifact = self.executor( self.transforms(definitions.ConcreteProgramDef(data=program, args=compile_time_args)) ) + return artifact.materialize() @property def __gt_allocator__( diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 88188aebbd..c1af3388ae 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -13,17 +13,6 @@ from gt4py.next.otf import definitions, stages, workflow -def materialize_artifact(artifact: stages.BuildArtifact) -> stages.ExecutableProgram: - """Default ``finalize`` step for :class:`OTFCompileWorkflow`. - - Universal across backends: dispatches into the artifact's own - :meth:`stages.BuildArtifact.materialize` method. The dispatch happens - through ordinary Python method resolution on the artifact's concrete - type — no separate registry, no backend-specific finalize plumbing. - """ - return artifact.materialize() - - @dataclasses.dataclass(frozen=True) class OTFBuildWorkflow( workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.BuildArtifact] @@ -33,39 +22,17 @@ class OTFBuildWorkflow( The artifact's concrete type is backend-specific (e.g. ``GTFNBuildArtifact`` or ``DaCeBuildArtifact``); both share only the :class:`stages.BuildArtifact` Protocol — frozen, picklable, self- - materializing. - - Grouped as a sub-workflow so the ``cached=True`` backend trait can wrap - just this sub-workflow in a :class:`workflow.CachedStep` — caching keys - on :class:`definitions.CompilableProgramDef` and values on a picklable + materializing. The whole post-build phase lives on the artifact itself + (``artifact.materialize()`` returns the directly-callable program); this + workflow's job is just to produce the artifact. + + Used directly as :attr:`gt4py.next.backend.Backend.executor`. The + ``cached=True`` backend trait wraps this whole workflow in a + :class:`workflow.CachedStep` — caching keys on + :class:`definitions.CompilableProgramDef` and values on a picklable artifact. """ translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] compilation: workflow.Workflow[stages.CompilableProject, stages.BuildArtifact] - - -@dataclasses.dataclass(frozen=True) -class OTFCompileWorkflow( - workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.ExecutableProgram] -): - """Full OTF pipeline: build an artifact, then materialize a callable. - - 1. ``build`` — produces a picklable :class:`stages.BuildArtifact`. Heavy, - idempotent, parallelizable; the natural cache target. - 2. ``finalize`` — rehydrates the artifact into a directly-callable - :class:`stages.ExecutableProgram`. Defaults to - :func:`materialize_artifact`, which dispatches through the artifact's - own :meth:`stages.BuildArtifact.materialize` — backend-specific code - lives on the artifact, not in a sibling free function. - - Backends typically only configure ``build``; ``finalize`` falls through - to the artifact's own materialization logic. Override ``finalize`` only - to wrap the entire post-build phase (e.g. add a tracing wrapper). - """ - - build: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] - finalize: workflow.Workflow[stages.BuildArtifact, stages.ExecutableProgram] = ( - materialize_artifact - ) diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 8448db9334..35cfe0e425 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -132,10 +132,10 @@ def build(self) -> None: ... class BuildArtifact(Protocol): """A picklable, self-contained, compiled gt4py program in transit. - A *build artifact* is the output of the ``build`` phase of - :class:`recipes.OTFCompileWorkflow` — the explicit boundary between the - build phase (heavy, idempotent, parallelizable, picklable output) and the - live-callable phase (cheap, process-bound). + A *build artifact* is the output of an :class:`recipes.OTFBuildWorkflow` + (the value of :attr:`gt4py.next.backend.Backend.executor`) — the explicit + boundary between the build phase (heavy, idempotent, parallelizable, + picklable output) and the live-callable phase (cheap, process-bound). Each backend defines its own concrete artifact dataclass, carrying whatever fields it needs to bring up a runnable callable in any process diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index c20f7a8555..1d65b8d8d0 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -17,7 +17,7 @@ @program_formatter.program_formatter def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str: # TODO(tehrengruber): This is a little ugly. Revisit. - gtfn_translation = gtfn.GTFNBackendFactory().executor.build.translation # type: ignore[attr-defined] + gtfn_translation = gtfn.GTFNBackendFactory().executor.translation # type: ignore[attr-defined] assert isinstance(gtfn_translation, GTFNTranslationStep) return gtfn_translation.generate_stencil_source( program, diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index c13daa249f..310b9634ba 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -76,17 +76,19 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: gt4py_program_args=[p.type for p in program.params], ) - compile_workflow = typing.cast(recipes.OTFCompileWorkflow, self.backend.executor) - build_workflow = ( - compile_workflow.build.step - if hasattr(compile_workflow.build, "step") - else compile_workflow.build - ) # the `build` phase may be wrapped in a `CachedStep` depending on backend configuration. + # ``backend.executor`` is an :class:`recipes.OTFBuildWorkflow`, optionally wrapped + # in a :class:`workflow.CachedStep` when ``cached=True``. + build_workflow = typing.cast( + recipes.OTFBuildWorkflow, + self.backend.executor.step + if hasattr(self.backend.executor, "step") + else self.backend.executor, + ) compile_workflow_translation = ( build_workflow.translation.step if hasattr(build_workflow.translation, "step") else build_workflow.translation - ) # Same for the translation stage, which could be a `CachedStep` depending on backend configuration. + ) # the translation stage could also be a `CachedStep` depending on backend configuration. # TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with # the other parts of the workaround when possible. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py index 4d825c0c9b..f822709cd2 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py @@ -10,7 +10,7 @@ The main module is `backend`, that exports the backends for CPU and GPU devices. The `backend` module uses `factory` to define a workflow that implements the -`OTFCompileWorkflow` recipe. The different stages are implemeted in separate modules: +`OTFBuildWorkflow` recipe. The different stages are implemeted in separate modules: - `translation` for lowering of GTIR to SDFG and applying SDFG transformations - `compilation` for compiling the SDFG into a program - `decoration` to parse the program arguments and pass them to the program call diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index 935655a422..de6778a750 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -8,7 +8,6 @@ from __future__ import annotations -import dataclasses import warnings from typing import Any, Final @@ -45,12 +44,7 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: dataclasses.replace( - o.otf_workflow, - build=workflow.CachedStep( - o.otf_workflow.build, hash_function=o.hash_function - ), - ) + lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) ), name_cached="_cached", ) @@ -133,13 +127,13 @@ def make_dace_backend( gpu=gpu, cached=cached, auto_optimize=auto_optimize, - otf_workflow__build__cached_translation=cached, - otf_workflow__build__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False), - otf_workflow__build__bare_translation__auto_optimize_args=optimization_args, - otf_workflow__build__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, - otf_workflow__build__bare_translation__use_metrics=use_metrics, - otf_workflow__build__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin, - otf_workflow__build__bare_translation__use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, + otf_workflow__cached_translation=cached, + otf_workflow__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False), + otf_workflow__bare_translation__auto_optimize_args=optimization_args, + otf_workflow__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, + otf_workflow__bare_translation__use_metrics=use_metrics, + otf_workflow__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin, + otf_workflow__bare_translation__use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 4eeced2341..f022cdea64 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -30,7 +30,7 @@ _GT_DACE_BINDING_FUNCTION_NAME: Final[str] = "update_sdfg_args" -class DaCeBuildWorkflowFactory(factory.Factory): +class DaCeWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFBuildWorkflow @@ -71,26 +71,3 @@ class Params: device_type=factory.SelfAttribute("..device_type"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) - - -class DaCeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow - - class Params: - auto_optimize: bool = False - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough - lambda: config.CMAKE_BUILD_TYPE - ) - cached_translation = factory.Trait(build__cached_translation=True) - - build = factory.SubFactory( - DaCeBuildWorkflowFactory, - device_type=factory.SelfAttribute("..device_type"), - auto_optimize=factory.SelfAttribute("..auto_optimize"), - cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) - # ``finalize`` is left at its OTFCompileWorkflow default - # (``stages.materialize_artifact``), which dispatches via the artifact's - # own :meth:`stages.BuildArtifact.materialize` method. diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 80c918df34..f865384271 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import dataclasses from typing import Any import factory @@ -147,22 +146,6 @@ class Params: ) -class GTFNCompileWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cached_translation = factory.Trait(build__cached_translation=True) - - build = factory.SubFactory( - GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") - ) - # ``finalize`` is left at its OTFCompileWorkflow default - # (``stages.materialize_artifact``), which dispatches via the artifact's - # own :meth:`stages.BuildArtifact.materialize` method. - - class GTFNBackendFactory(factory.Factory): class Meta: model = backend.Backend @@ -179,19 +162,14 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: dataclasses.replace( - o.otf_workflow, - build=workflow.CachedStep( - o.otf_workflow.build, hash_function=o.hash_function - ), - ) + lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) ), name_cached="_cached", ) device_type = core_defs.DeviceType.CPU hash_function = stages.compilation_hash otf_workflow = factory.SubFactory( - GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") + GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) name = factory.LazyAttribute( @@ -207,17 +185,17 @@ class Params: run_gtfn_imperative = GTFNBackendFactory( name_postfix="_imperative", - otf_workflow__build__translation__use_imperative_backend=True, + otf_workflow__translation__use_imperative_backend=True, ) -run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__build__cached_translation=True) +run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) run_gtfn_gpu = GTFNBackendFactory(gpu=True) run_gtfn_gpu_cached = GTFNBackendFactory( - gpu=True, cached=True, otf_workflow__build__cached_translation=True + gpu=True, cached=True, otf_workflow__cached_translation=True ) run_gtfn_no_transforms = GTFNBackendFactory( - otf_workflow__build__bare_translation__enable_itir_transforms=False + otf_workflow__bare_translation__enable_itir_transforms=False ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index c3058f33a8..90c0d775f2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -34,14 +34,12 @@ def exec_alloc_descriptor(): name="run_gtfn_with_temporaries_and_sizes", transforms=backend.DEFAULT_TRANSFORMS, executor=run_gtfn.executor.replace( - build=run_gtfn.executor.build.replace( - translation=run_gtfn.executor.build.translation.replace( - symbolic_domain_sizes={ - "Cell": "num_cells", - "Edge": "num_edges", - "Vertex": "num_vertices", - } - ) + translation=run_gtfn.executor.translation.replace( + symbolic_domain_sizes={ + "Cell": "num_cells", + "Edge": "num_edges", + "Vertex": "num_vertices", + } ) ), allocator=run_gtfn.allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 9d39b5d63a..4fff5192aa 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -192,11 +192,7 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): gtfn_without_transforms = dataclasses.replace( run_gtfn, executor=run_gtfn.executor.replace( - build=run_gtfn.executor.build.replace( - translation=run_gtfn.executor.build.translation.replace( - enable_itir_transforms=False - ), - ), + translation=run_gtfn.executor.translation.replace(enable_itir_transforms=False), ), # avoid inlining the function ) diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index def8800c98..233e5a2f6e 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -114,12 +114,19 @@ def test_inlining_of_scalar_works_integration(testee_prog): hijacked_program = None + @dataclasses.dataclass(frozen=True) + class _NoOpArtifact: + """A trivial BuildArtifact that materializes to a no-op callable.""" + + def materialize(self): + return lambda *args, **kwargs: None + def pirate(program: toolchain.ConcreteArtifact): - # Replaces the gtfn otf_workflow: and steals the compilable program, - # then returns a dummy "CompiledProgram" that does nothing. + # Replaces the gtfn otf_workflow: steals the compilable program, then + # returns a dummy artifact whose materialization is a no-op callable. nonlocal hijacked_program hijacked_program = program - return lambda *args, **kwargs: None + return _NoOpArtifact() hacked_gtfn_backend = gtfn.GTFNBackendFactory(name_postfix="_custom", executor=pirate) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 3611abad89..d027c9dcb1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -135,11 +135,11 @@ def test_gtfn_file_cache(program_example): ) cached_gtfn_translation_step = gtfn.GTFNBackendFactory( gpu=False, cached=True, otf_workflow__cached_translation=True - ).executor.build.step.translation + ).executor.step.translation bare_gtfn_translation_step = gtfn.GTFNBackendFactory( gpu=False, cached=True, otf_workflow__cached_translation=False - ).executor.build.step.translation + ).executor.step.translation cache_key = stages.fingerprint_compilable_program(compilable_program) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index cd3bddb19a..ab4697ed73 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -34,12 +34,12 @@ def test_backend_factory_trait_device(): assert cpu_version.name == "run_gtfn_cpu" assert gpu_version.name == "run_gtfn_gpu" - assert cpu_version.executor.build.translation.device_type is core_defs.DeviceType.CPU - assert gpu_version.executor.build.translation.device_type is core_defs.DeviceType.CUDA + assert cpu_version.executor.translation.device_type is core_defs.DeviceType.CPU + assert gpu_version.executor.translation.device_type is core_defs.DeviceType.CUDA # The compilation step now also carries device_type so it can stamp the artifact. - assert cpu_version.executor.build.compilation.device_type is core_defs.DeviceType.CPU - assert gpu_version.executor.build.compilation.device_type is core_defs.DeviceType.CUDA + assert cpu_version.executor.compilation.device_type is core_defs.DeviceType.CPU + assert gpu_version.executor.compilation.device_type is core_defs.DeviceType.CUDA assert custom_layout_allocators.is_field_allocator_for( cpu_version.allocator, core_defs.DeviceType.CPU @@ -51,7 +51,7 @@ def test_backend_factory_trait_device(): def test_backend_factory_trait_cached(): cached_version = gtfn.GTFNBackendFactory(gpu=False, cached=True) - assert isinstance(cached_version.executor.build, workflow.CachedStep) + assert isinstance(cached_version.executor, workflow.CachedStep) assert cached_version.name == "run_gtfn_cpu_cached" @@ -62,11 +62,11 @@ def test_backend_factory_build_cache_config(monkeypatch): persistent_version = gtfn.GTFNBackendFactory() assert ( - session_version.executor.build.compilation.cache_lifetime + session_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.SESSION ) assert ( - persistent_version.executor.build.compilation.cache_lifetime + persistent_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.PERSISTENT ) @@ -78,10 +78,10 @@ def test_backend_factory_build_type_config(monkeypatch): min_size_version = gtfn.GTFNBackendFactory() assert ( - release_version.executor.build.compilation.builder_factory.cmake_build_type + release_version.executor.compilation.builder_factory.cmake_build_type is config.CMakeBuildType.RELEASE ) assert ( - min_size_version.executor.build.compilation.builder_factory.cmake_build_type + min_size_version.executor.compilation.builder_factory.cmake_build_type is config.CMakeBuildType.MIN_SIZE_REL ) From c51ac5592cb182b4537dfcc5e840fb416a716105 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 16:47:24 +0200 Subject: [PATCH 05/35] cleanup docstrings --- src/gt4py/next/otf/compilation/compiler.py | 19 +++++----- src/gt4py/next/otf/recipes.py | 15 ++------ src/gt4py/next/otf/stages.py | 36 ++++--------------- .../runners/dace/program.py | 6 ++-- .../runners/dace/workflow/compilation.py | 23 +++--------- 5 files changed, 25 insertions(+), 74 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index f1b076b55c..66bd74556f 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -48,10 +48,9 @@ def __call__( class GTFNBuildArtifact: """On-disk result of a GTFN compilation: a Python extension module. - Bindings are baked into the .so via nanobind, so materialization is just - an ``importlib`` import + entry-point symbol lookup, plus a wrapping in - gt4py's calling convention. ``device_type`` is intrinsic to the artifact: - a CPU-built .so cannot be loaded as GPU. + Bindings are baked into the .so via nanobind, so :meth:`materialize` is + just an ``importlib`` import + entry-point symbol lookup, plus a wrap in + gt4py's calling convention. """ src_dir: pathlib.Path @@ -60,15 +59,13 @@ class GTFNBuildArtifact: device_type: core_defs.DeviceType def materialize(self) -> stages.ExecutableProgram: - """Bring the artifact up as a directly-callable program. + """Import the module and wrap its entry point in gt4py's calling convention. - Must run in the process that will ultimately call the returned - program; the imported module is registered in that process's - ``sys.modules`` under the ``gt4py.__compiled_programs__.`` prefix. + Must run in the process that will call the returned program: the + module is registered in that process's ``sys.modules`` under the + ``gt4py.__compiled_programs__.`` prefix. """ - # Imported lazily to avoid a circular module dependency: ``runners.gtfn`` - # imports this module to construct the workflow, while the - # gt4py-shaped argument-conversion lives there. + # Lazy import: ``runners.gtfn`` imports this module to construct the workflow. from gt4py.next.program_processors.runners.gtfn import convert_args m = importer.import_from_path( diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index c1af3388ae..f784a20a12 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -19,18 +19,9 @@ class OTFBuildWorkflow( ): """Translation + bindings + build system; ends at a :class:`stages.BuildArtifact`. - The artifact's concrete type is backend-specific (e.g. ``GTFNBuildArtifact`` - or ``DaCeBuildArtifact``); both share only the - :class:`stages.BuildArtifact` Protocol — frozen, picklable, self- - materializing. The whole post-build phase lives on the artifact itself - (``artifact.materialize()`` returns the directly-callable program); this - workflow's job is just to produce the artifact. - - Used directly as :attr:`gt4py.next.backend.Backend.executor`. The - ``cached=True`` backend trait wraps this whole workflow in a - :class:`workflow.CachedStep` — caching keys on - :class:`definitions.CompilableProgramDef` and values on a picklable - artifact. + Used as :attr:`gt4py.next.backend.Backend.executor`. The ``cached=True`` + backend trait wraps it in a :class:`workflow.CachedStep` keyed on + :class:`definitions.CompilableProgramDef`. """ translation: definitions.TranslationStep diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 35cfe0e425..4a735a76aa 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -130,35 +130,13 @@ def build(self) -> None: ... class BuildArtifact(Protocol): - """A picklable, self-contained, compiled gt4py program in transit. - - A *build artifact* is the output of an :class:`recipes.OTFBuildWorkflow` - (the value of :attr:`gt4py.next.backend.Backend.executor`) — the explicit - boundary between the build phase (heavy, idempotent, parallelizable, - picklable output) and the live-callable phase (cheap, process-bound). - - Each backend defines its own concrete artifact dataclass, carrying - whatever fields it needs to bring up a runnable callable in any process - that has the backend module on the import path. Conventions: - - 1. **Frozen dataclass.** Implementations are - ``@dataclasses.dataclass(frozen=True)`` so they have value semantics - (hashable, structurally equatable) for use as cache keys. - - 2. **Picklable.** Implementations round-trip safely through :mod:`pickle` - so they can cross process boundaries: process-pool / distributed - compilation, AOT pipelines that build now and run later from a - different process, persistent caches keyed on the artifact. Live, - process-bound state (open files, ``ctypes`` handles, imported Python - modules) is therefore not allowed in the artifact — that is what - :meth:`materialize` rehydrates. - - 3. **Self-materializing.** Calling :meth:`materialize` returns a - directly-callable :class:`ExecutableProgram` taking gt4py-shaped - arguments. The method body is the backend's full post-build - sequence (load the .so, wrap with the calling convention, attach - metric hooks, etc.). Receivers don't need to know which backend - produced the artifact — they just call ``artifact.materialize()``. + """The output of an :class:`recipes.OTFBuildWorkflow`. + + Each backend defines its own concrete artifact dataclass; all share this + Protocol. Implementations are frozen dataclasses, picklable, and have no + live process-bound state — that is reconstructed by :meth:`materialize`, + which returns a directly-callable :class:`ExecutableProgram` taking + gt4py-shaped arguments. """ def materialize(self) -> ExecutableProgram: ... diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index 310b9634ba..1435080f52 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -76,8 +76,8 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: gt4py_program_args=[p.type for p in program.params], ) - # ``backend.executor`` is an :class:`recipes.OTFBuildWorkflow`, optionally wrapped - # in a :class:`workflow.CachedStep` when ``cached=True``. + # The executor and the translation stage may each be wrapped in a `CachedStep` + # depending on backend configuration; unwrap when so. build_workflow = typing.cast( recipes.OTFBuildWorkflow, self.backend.executor.step @@ -88,7 +88,7 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: build_workflow.translation.step if hasattr(build_workflow.translation, "step") else build_workflow.translation - ) # the translation stage could also be a `CachedStep` depending on backend configuration. + ) # TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with # the other parts of the workaround when possible. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 9bbb035997..d79de03233 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -118,13 +118,7 @@ def __call__(self, **kwargs: Any) -> None: @dataclasses.dataclass(frozen=True) class DaCeBuildArtifact: - """On-disk result of a DaCe compilation. - - Carries the ``device_type`` the artifact was built for; a CPU-built .so - cannot be loaded as GPU. Also carries the bindings (Python source code - that materialization ``exec``\\ s to bring the SDFG argument-marshalling - function into existence). - """ + """On-disk result of a DaCe compilation: a build folder + the SDFG bindings.""" build_folder: pathlib.Path binding_source_code: str @@ -132,19 +126,11 @@ class DaCeBuildArtifact: device_type: core_defs.DeviceType def materialize(self) -> stages.ExecutableProgram: - """Bring the artifact up as a directly-callable program. - - Re-deserializes the SDFG dump from the build folder, links against - the pre-built .so via ``compiler.use_cache=True`` (no re-codegen), - wraps it in a :class:`CompiledDaceProgram`, and applies gt4py's - calling convention via :func:`decoration.convert_args`. + """Re-deserialize the SDFG, link the .so, and wrap in gt4py's calling convention. - Must run in the process that will ultimately call the returned - program; the imported binding code is bound into a per-call namespace - within the produced :class:`CompiledDaceProgram`. + Must run in the process that will call the returned program. """ - # Imported lazily to avoid a circular module dependency: - # ``decoration`` imports this module. + # Lazy import: ``decoration`` imports this module. from gt4py.next.program_processors.runners.dace.workflow import ( decoration as gtx_wfddecoration, ) @@ -162,7 +148,6 @@ def materialize(self) -> stages.ExecutableProgram: sdfg.build_folder = str(self.build_folder) with gtx_wfdcommon.dace_context(device_type=self.device_type): - # use_cache=True forces DaCe to load the existing .so without re-codegen. with dace.config.set_temporary("compiler", "use_cache", value=True): sdfg_program = sdfg.compile(validate=False) From 9b842e2ab79ed8f053fb9dbe60f578e1fc1b5739 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 17:09:58 +0200 Subject: [PATCH 06/35] refactor: remove lazy import --- src/gt4py/next/otf/compilation/compiler.py | 8 +- .../runners/dace/workflow/compilation.py | 110 ++---------------- .../runners/dace/workflow/compiled_program.py | 110 ++++++++++++++++++ .../runners/dace/workflow/decoration.py | 8 +- .../next/program_processors/runners/gtfn.py | 88 +------------- .../runners/gtfn_decoration.py | 105 +++++++++++++++++ .../runners_tests/dace_tests/test_dace.py | 8 +- 7 files changed, 236 insertions(+), 201 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py create mode 100644 src/gt4py/next/program_processors/runners/gtfn_decoration.py diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 66bd74556f..a858a0b02d 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -18,6 +18,7 @@ from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import build_data, cache, importer +from gt4py.next.program_processors.runners import gtfn_decoration T = TypeVar("T") @@ -65,14 +66,13 @@ def materialize(self) -> stages.ExecutableProgram: module is registered in that process's ``sys.modules`` under the ``gt4py.__compiled_programs__.`` prefix. """ - # Lazy import: ``runners.gtfn`` imports this module to construct the workflow. - from gt4py.next.program_processors.runners.gtfn import convert_args - m = importer.import_from_path( self.src_dir / self.module, sys_modules_prefix="gt4py.__compiled_programs__.", ) - return convert_args(getattr(m, self.entry_point_name), device=self.device_type) + return gtfn_decoration.convert_args( + getattr(m, self.entry_point_name), device=self.device_type + ) @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index d79de03233..a7c628daea 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -9,111 +9,22 @@ from __future__ import annotations import dataclasses -import os import pathlib -import warnings -from collections.abc import Callable, MutableSequence, Sequence -from typing import Any import dace import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import common, config +from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache -from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon - - -class CompiledDaceProgram: - sdfg_program: dace.CompiledSDFG - - # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; - # scalar arguments that are not used in the SDFG will not be present. - sdfg_argtypes: list[dace.dtypes.Data] - - # The compiled program contains a callable object to update the SDFG arguments list. - update_sdfg_ctype_arglist: Callable[ - [ - core_defs.DeviceType, - Sequence[dace.dtypes.Data], - Sequence[Any], - MutableSequence[Any], - common.OffsetProvider, - ], - None, - ] - - # Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None` - # means that it has not been initialized, i.e. no call was ever performed. - # - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated. - # - csdfg_init_argv: Arguments used for initialization; used only the first time and - # never updated. - csdfg_argv: MutableSequence[Any] | None - csdfg_init_argv: Sequence[Any] | None - - def __init__( - self, - program: dace.CompiledSDFG, - bind_func_name: str, - binding_source_code: str, - ): - self.sdfg_program = program - - # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument - # name to its data type, in the same order as arguments appear in the program ABI. - # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. - self.sdfg_argtypes = list(program.sdfg.arglist().values()) - - # The binding source code is Python tailored to this specific SDFG. - # We dynamically compile that function and add it to the compiled program. - global_namespace: dict[str, Any] = {} - exec(binding_source_code, global_namespace) - self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] - # For debug purpose, we set a unique module name on the compiled function. - self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) - - # Since the SDFG hasn't been called yet. - self.csdfg_argv = None - self.csdfg_init_argv = None - - def construct_arguments(self, **kwargs: Any) -> None: - """ - This function will process the arguments and store the processed argument - vectors in `self.csdfg_args`, to call them use `self.fast_call()`. - """ - with dace.config.set_temporary("compiler", "allow_view_arguments", value=True): - csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs) - # Note we only care about `csdfg_argv` (normal call), since we have to update it, - # we ensure that it is a `list`. - self.csdfg_argv = [*csdfg_argv] - self.csdfg_init_argv = csdfg_init_argv - - def fast_call(self) -> None: - """ - Perform a call to the compiled SDFG using the previously generated argument - vectors, see `self.construct_arguments()`. - """ - assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, ( - "Argument vector was not set properly." - ) - self.sdfg_program.fast_call( - self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.DEBUG - ) - - def __call__(self, **kwargs: Any) -> None: - """Call the compiled SDFG with the given arguments. - - Note that this function will not update the argument vectors stored inside - `self`. Furthermore, it is not recommended to use this function as it is - very slow. - """ - warnings.warn( - "Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.", - stacklevel=1, - ) - result = self.sdfg_program(**kwargs) - assert result is None +from gt4py.next.program_processors.runners.dace.workflow import ( + common as gtx_wfdcommon, + decoration as gtx_wfddecoration, +) +from gt4py.next.program_processors.runners.dace.workflow.compiled_program import ( + CompiledDaceProgram, +) @dataclasses.dataclass(frozen=True) @@ -130,11 +41,6 @@ def materialize(self) -> stages.ExecutableProgram: Must run in the process that will call the returned program. """ - # Lazy import: ``decoration`` imports this module. - from gt4py.next.program_processors.runners.dace.workflow import ( - decoration as gtx_wfddecoration, - ) - for dump_name in ("program.sdfgz", "program.sdfg"): sdfg_dump = self.build_folder / dump_name if sdfg_dump.exists(): diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py b/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py new file mode 100644 index 0000000000..5e28853902 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py @@ -0,0 +1,110 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import os +import warnings +from collections.abc import Callable, MutableSequence, Sequence +from typing import Any + +import dace + +from gt4py._core import definitions as core_defs +from gt4py.next import common, config + + +class CompiledDaceProgram: + sdfg_program: dace.CompiledSDFG + + # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; + # scalar arguments that are not used in the SDFG will not be present. + sdfg_argtypes: list[dace.dtypes.Data] + + # The compiled program contains a callable object to update the SDFG arguments list. + update_sdfg_ctype_arglist: Callable[ + [ + core_defs.DeviceType, + Sequence[dace.dtypes.Data], + Sequence[Any], + MutableSequence[Any], + common.OffsetProvider, + ], + None, + ] + + # Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None` + # means that it has not been initialized, i.e. no call was ever performed. + # - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated. + # - csdfg_init_argv: Arguments used for initialization; used only the first time and + # never updated. + csdfg_argv: MutableSequence[Any] | None + csdfg_init_argv: Sequence[Any] | None + + def __init__( + self, + program: dace.CompiledSDFG, + bind_func_name: str, + binding_source_code: str, + ): + self.sdfg_program = program + + # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument + # name to its data type, in the same order as arguments appear in the program ABI. + # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. + self.sdfg_argtypes = list(program.sdfg.arglist().values()) + + # The binding source code is Python tailored to this specific SDFG. + # We dynamically compile that function and add it to the compiled program. + global_namespace: dict[str, Any] = {} + exec(binding_source_code, global_namespace) + self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] + # For debug purpose, we set a unique module name on the compiled function. + self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) + + # Since the SDFG hasn't been called yet. + self.csdfg_argv = None + self.csdfg_init_argv = None + + def construct_arguments(self, **kwargs: Any) -> None: + """ + This function will process the arguments and store the processed argument + vectors in `self.csdfg_args`, to call them use `self.fast_call()`. + """ + with dace.config.set_temporary("compiler", "allow_view_arguments", value=True): + csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs) + # Note we only care about `csdfg_argv` (normal call), since we have to update it, + # we ensure that it is a `list`. + self.csdfg_argv = [*csdfg_argv] + self.csdfg_init_argv = csdfg_init_argv + + def fast_call(self) -> None: + """ + Perform a call to the compiled SDFG using the previously generated argument + vectors, see `self.construct_arguments()`. + """ + assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, ( + "Argument vector was not set properly." + ) + self.sdfg_program.fast_call( + self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.DEBUG + ) + + def __call__(self, **kwargs: Any) -> None: + """Call the compiled SDFG with the given arguments. + + Note that this function will not update the argument vectors stored inside + `self`. Furthermore, it is not recommended to use this function as it is + very slow. + """ + warnings.warn( + "Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.", + stacklevel=1, + ) + result = self.sdfg_program(**kwargs) + assert result is None diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 103e7af33b..9681785206 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -18,14 +18,14 @@ from gt4py.next.instrumentation import metrics from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable -from gt4py.next.program_processors.runners.dace.workflow import ( - common as gtx_wfdcommon, - compilation as gtx_wfdcompilation, +from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon +from gt4py.next.program_processors.runners.dace.workflow.compiled_program import ( + CompiledDaceProgram, ) def convert_args( - fun: gtx_wfdcompilation.CompiledDaceProgram, + fun: CompiledDaceProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU, ) -> stages.ExecutableProgram: # Retieve metrics level from GT4Py environment variable. diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index f865384271..bddf06a1f3 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,17 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any - import factory -import numpy as np import gt4py._core.definitions as core_defs import gt4py.next.custom_layout_allocators as next_allocators from gt4py._core import filecache -from gt4py.next import backend, common, config, field_utils -from gt4py.next.embedded import nd_array_field -from gt4py.next.instrumentation import metrics +from gt4py.next import backend, config from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -24,87 +19,6 @@ from gt4py.next.program_processors.codegens.gtfn import gtfn_module -def convert_arg(arg: Any) -> Any: - # Note: this function is on the hot path and needs to have minimal overhead. - if (origin := getattr(arg, "__gt_origin__", None)) is not None: - # `Field` is the most likely case, we use `__gt_origin__` as the property is needed anyway - # and (currently) uniquely identifies a `NDArrayField` (which is the only supported `Field`) - assert isinstance(arg, nd_array_field.NdArrayField) - return arg.ndarray, origin - if isinstance(arg, tuple): - return tuple(convert_arg(a) for a in arg) - if isinstance(arg, np.bool_): - # nanobind does not support implicit conversion of `np.bool` to `bool` - return bool(arg) - # TODO(havogt): if this function still appears in profiles, - # we should avoid going through the previous isinstance checks for detecting a scalar. - # E.g. functools.cache on the arg type, returning a function that does the conversion - return arg - - -def convert_args( - inp: stages.ExecutableProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU -) -> stages.ExecutableProgram: - def decorated_program( - *args: Any, - offset_provider: dict[str, common.Connectivity | common.Dimension], - out: Any = None, - ) -> None: - # Note: this function is on the hot path and needs to have minimal overhead. - if out is not None: - args = (*args, out) - converted_args = (convert_arg(arg) for arg in args) - conn_args = extract_connectivity_args(offset_provider, device) - - opt_kwargs: dict[str, Any] = {} - if collect_metrics := metrics.is_level_enabled(metrics.PERFORMANCE): - # If we are collecting metrics, we need to add the `exec_info` argument - # to the `inp` call, which will be used to collect performance metrics. - exec_info: dict[str, float] = {} - opt_kwargs["exec_info"] = exec_info - - # generate implicit domain size arguments only if necessary, using `iter_size_args()` - inp( - *converted_args, - *conn_args, - **opt_kwargs, - ) - - if collect_metrics: - metrics.add_sample_to_current_source( - metrics.COMPUTE_METRIC, exec_info["run_cpp_duration"] - ) - - return decorated_program - - -def extract_connectivity_args( - offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType -) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: - # Note: this function is on the hot path and needs to have minimal overhead. - zero_origin = (0, 0) - assert all( - hasattr(conn, "ndarray") or isinstance(conn, common.Dimension) - for conn in offset_provider.values() - ) - # Note: the order here needs to agree with the order of the generated bindings. - # This is currently true only because when hashing offset provider dicts, - # the keys' order is taken into account. Any modification to the hashing - # of offset providers may break this assumption here. - args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [ - (ndarray, zero_origin) - for conn in offset_provider.values() - if (ndarray := getattr(conn, "ndarray", None)) is not None - ] - assert all( - common.is_neighbor_table(conn) and field_utils.verify_device_field_type(conn, device) - for conn in offset_provider.values() - if hasattr(conn, "ndarray") - ) - - return args - - class GTFNBuildWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFBuildWorkflow diff --git a/src/gt4py/next/program_processors/runners/gtfn_decoration.py b/src/gt4py/next/program_processors/runners/gtfn_decoration.py new file mode 100644 index 0000000000..1ea2b222ca --- /dev/null +++ b/src/gt4py/next/program_processors/runners/gtfn_decoration.py @@ -0,0 +1,105 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Calling-convention adapter for GTFN-compiled programs. + +Wraps a freshly-imported GTFN entry point with gt4py's user-facing +argument convention: unpacks fields, splits offset_provider into +connectivity args, attaches metric collection. +""" + +from typing import Any + +import numpy as np + +import gt4py._core.definitions as core_defs +from gt4py.next import common, field_utils +from gt4py.next.embedded import nd_array_field +from gt4py.next.instrumentation import metrics +from gt4py.next.otf import stages + + +def convert_arg(arg: Any) -> Any: + # Note: this function is on the hot path and needs to have minimal overhead. + if (origin := getattr(arg, "__gt_origin__", None)) is not None: + # `Field` is the most likely case, we use `__gt_origin__` as the property is needed anyway + # and (currently) uniquely identifies a `NDArrayField` (which is the only supported `Field`) + assert isinstance(arg, nd_array_field.NdArrayField) + return arg.ndarray, origin + if isinstance(arg, tuple): + return tuple(convert_arg(a) for a in arg) + if isinstance(arg, np.bool_): + # nanobind does not support implicit conversion of `np.bool` to `bool` + return bool(arg) + # TODO(havogt): if this function still appears in profiles, + # we should avoid going through the previous isinstance checks for detecting a scalar. + # E.g. functools.cache on the arg type, returning a function that does the conversion + return arg + + +def convert_args( + inp: stages.ExecutableProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU +) -> stages.ExecutableProgram: + def decorated_program( + *args: Any, + offset_provider: dict[str, common.Connectivity | common.Dimension], + out: Any = None, + ) -> None: + # Note: this function is on the hot path and needs to have minimal overhead. + if out is not None: + args = (*args, out) + converted_args = (convert_arg(arg) for arg in args) + conn_args = extract_connectivity_args(offset_provider, device) + + opt_kwargs: dict[str, Any] = {} + if collect_metrics := metrics.is_level_enabled(metrics.PERFORMANCE): + # If we are collecting metrics, we need to add the `exec_info` argument + # to the `inp` call, which will be used to collect performance metrics. + exec_info: dict[str, float] = {} + opt_kwargs["exec_info"] = exec_info + + # generate implicit domain size arguments only if necessary, using `iter_size_args()` + inp( + *converted_args, + *conn_args, + **opt_kwargs, + ) + + if collect_metrics: + metrics.add_sample_to_current_source( + metrics.COMPUTE_METRIC, exec_info["run_cpp_duration"] + ) + + return decorated_program + + +def extract_connectivity_args( + offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType +) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: + # Note: this function is on the hot path and needs to have minimal overhead. + zero_origin = (0, 0) + assert all( + hasattr(conn, "ndarray") or isinstance(conn, common.Dimension) + for conn in offset_provider.values() + ) + # Note: the order here needs to agree with the order of the generated bindings. + # This is currently true only because when hashing offset provider dicts, + # the keys' order is taken into account. Any modification to the hashing + # of offset providers may break this assumption here. + args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [ + (ndarray, zero_origin) + for conn in offset_provider.values() + if (ndarray := getattr(conn, "ndarray", None)) is not None + ] + assert all( + common.is_neighbor_table(conn) and field_utils.verify_device_field_type(conn, device) + for conn in offset_provider.values() + if hasattr(conn, "ndarray") + ) + + return args diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index a204886690..06b3d428bb 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -82,7 +82,7 @@ def make_mocks(monkeypatch): # Wrap `compiled_sdfg.CompiledSDFG.fast_call` with mock object mock_fast_call = unittest.mock.MagicMock() gt4py_fast_call = ( - gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.fast_call + gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram.fast_call ) def mocked_fast_call(self): @@ -99,21 +99,21 @@ def mocked_fast_call(self): return fast_call_result monkeypatch.setattr( - gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram, + gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram, "fast_call", mocked_fast_call, ) # Wrap `compiled_sdfg.CompiledSDFG.construct_arguments` with mock object mock_construct_arguments = unittest.mock.MagicMock() - gt4py_construct_arguments = gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.construct_arguments + gt4py_construct_arguments = gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram.construct_arguments def mocked_construct_arguments(self, *args, **kwargs): mock_construct_arguments.__call__(*args, **kwargs) return gt4py_construct_arguments(self, *args, **kwargs) monkeypatch.setattr( - gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram, + gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram, "construct_arguments", mocked_construct_arguments, ) From c022bde47ee8e54a4d72750385bd3a9328cb489e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 17:24:52 +0200 Subject: [PATCH 07/35] apply pre-commit --- src/gt4py/next/otf/definitions.py | 4 +--- .../runners/dace/workflow/compilation.py | 8 ++------ .../runners/dace/workflow/decoration.py | 4 +--- .../program_processors/runners/dace/workflow/factory.py | 4 +--- .../program_processor_tests/runners_tests/test_gtfn.py | 5 +---- 5 files changed, 6 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index 1fe56a1f11..b5d6a0ecfa 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -56,9 +56,7 @@ def __call__( class CompilationStep( - workflow.Workflow[ - stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact - ], + workflow.Workflow[stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact], Protocol[CodeSpecT, TargetCodeSpecT], ): """Run the build system and produce a :class:`stages.BuildArtifact`. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index a7c628daea..71b7934fc1 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -22,9 +22,7 @@ common as gtx_wfdcommon, decoration as gtx_wfddecoration, ) -from gt4py.next.program_processors.runners.dace.workflow.compiled_program import ( - CompiledDaceProgram, -) +from gt4py.next.program_processors.runners.dace.workflow.compiled_program import CompiledDaceProgram @dataclasses.dataclass(frozen=True) @@ -57,9 +55,7 @@ def materialize(self) -> stages.ExecutableProgram: with dace.config.set_temporary("compiler", "use_cache", value=True): sdfg_program = sdfg.compile(validate=False) - program = CompiledDaceProgram( - sdfg_program, self.bind_func_name, self.binding_source_code - ) + program = CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) return gtx_wfddecoration.convert_args(program, device=self.device_type) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 9681785206..6b828d5a97 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -19,9 +19,7 @@ from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon -from gt4py.next.program_processors.runners.dace.workflow.compiled_program import ( - CompiledDaceProgram, -) +from gt4py.next.program_processors.runners.dace.workflow.compiled_program import CompiledDaceProgram def convert_args( diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index f022cdea64..9f6c80fd07 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -16,9 +16,7 @@ from gt4py._core import definitions as core_defs, filecache from gt4py.next import config from gt4py.next.otf import recipes, stages, workflow -from gt4py.next.program_processors.runners.dace.workflow import ( - bindings as bindings_step, -) +from gt4py.next.program_processors.runners.dace.workflow import bindings as bindings_step from gt4py.next.program_processors.runners.dace.workflow.compilation import ( DaCeCompilationStepFactory, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index ab4697ed73..712e0500f5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -61,10 +61,7 @@ def test_backend_factory_build_cache_config(monkeypatch): monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.PERSISTENT) persistent_version = gtfn.GTFNBackendFactory() - assert ( - session_version.executor.compilation.cache_lifetime - is config.BuildCacheLifetime.SESSION - ) + assert session_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.SESSION assert ( persistent_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.PERSISTENT From a107dd61b03aa3fc0aef5817318f542d8ae66bbf Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 17:31:08 +0200 Subject: [PATCH 08/35] update roundtrip --- .../program_processors/runners/roundtrip.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 5ee0a67f25..d97b3ab238 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -213,13 +213,28 @@ def fencil_generator( @dataclasses.dataclass(frozen=True) -class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram]): +class RoundtripArtifact: + """In-memory artifact for the roundtrip backend. + + Roundtrip generates and ``exec``\\ s a Python module per program, so its + output is a live callable rather than something on disk. Not picklable — + roundtrip is in-process only. + """ + + program: stages.ExecutableProgram + + def materialize(self) -> stages.ExecutableProgram: + return self.program + + +@dataclasses.dataclass(frozen=True) +class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, RoundtripArtifact]): debug: Optional[bool] = None use_embedded: bool = True dispatch_backend: Optional[next_backend.Backend] = None transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` - def __call__(self, inp: definitions.CompilableProgramDef) -> stages.ExecutableProgram: + def __call__(self, inp: definitions.CompilableProgramDef) -> RoundtripArtifact: debug = config.DEBUG if self.debug is None else self.debug fencil = fencil_generator( @@ -249,7 +264,7 @@ def decorated_fencil( **kwargs, ) - return decorated_fencil + return RoundtripArtifact(program=decorated_fencil) # TODO(tehrengruber): introduce factory From efecdf5b26da61897b4d5ad8146d324865f5c4e1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 18:01:01 +0200 Subject: [PATCH 09/35] separate gtfn from generic compiler --- .../otf/compilation/build_orchestrator.py | 81 +++++++++++++++++++ .../next/otf/compilation/build_system.py | 33 ++++++++ .../otf/compilation/build_systems/cmake.py | 4 +- .../compilation/build_systems/compiledb.py | 4 +- .../next/program_processors/runners/gtfn.py | 7 +- .../runners/gtfn_compiler.py} | 62 +++----------- .../otf_tests/test_nanobind_build.py | 34 ++++++-- 7 files changed, 163 insertions(+), 62 deletions(-) create mode 100644 src/gt4py/next/otf/compilation/build_orchestrator.py create mode 100644 src/gt4py/next/otf/compilation/build_system.py rename src/gt4py/next/{otf/compilation/compiler.py => program_processors/runners/gtfn_compiler.py} (54%) diff --git a/src/gt4py/next/otf/compilation/build_orchestrator.py b/src/gt4py/next/otf/compilation/build_orchestrator.py new file mode 100644 index 0000000000..4aea7b4a0d --- /dev/null +++ b/src/gt4py/next/otf/compilation/build_orchestrator.py @@ -0,0 +1,81 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Generic build orchestration for backends that produce a Python extension module. + +Wraps the lock + build-data tracking + builder-factory invocation + post-build +validation into a single :func:`run_build` call. Returns a :class:`BuildResult` +descriptor (paths + entry-point name) that backends wrap into their own +artifact type. +""" + +from __future__ import annotations + +import dataclasses +import pathlib +from typing import TypeVar + +from gt4py._core import locking +from gt4py.next import config +from gt4py.next.otf import code_specs, stages +from gt4py.next.otf.compilation import build_data, build_system, cache + + +CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) +TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) + + +@dataclasses.dataclass(frozen=True) +class BuildResult: + """On-disk descriptor of a successful build.""" + + src_dir: pathlib.Path + module: pathlib.Path + entry_point_name: str + + +class CompilationError(RuntimeError): ... + + +def is_compiled(data: build_data.BuildData) -> bool: + return data.status >= build_data.BuildStatus.COMPILED + + +def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: + return (src_dir / data.module).exists() + + +def run_build( + inp: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], + cache_lifetime: config.BuildCacheLifetime, + builder_factory: build_system.BuildSystemProjectGenerator[CodeSpecT, TargetCodeSpecT], + force_recompile: bool = False, +) -> BuildResult: + """Drive ``builder_factory`` to produce a Python extension module on disk.""" + src_dir = cache.get_cache_folder(inp, cache_lifetime) + + # If we are compiling the same program at the same time (e.g. multiple MPI ranks), + # we need to make sure that only one of them accesses the same build directory for compilation. + with locking.lock(src_dir): + data = build_data.read_data(src_dir) + + if not data or not is_compiled(data) or force_recompile: + builder_factory(inp, cache_lifetime).build() + + new_data = build_data.read_data(src_dir) + + if not new_data or not is_compiled(new_data) or not module_exists(new_data, src_dir): + raise CompilationError( + f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." + ) + + return BuildResult( + src_dir=src_dir, + module=new_data.module, + entry_point_name=new_data.entry_point_name, + ) diff --git a/src/gt4py/next/otf/compilation/build_system.py b/src/gt4py/next/otf/compilation/build_system.py new file mode 100644 index 0000000000..a7ce6d957e --- /dev/null +++ b/src/gt4py/next/otf/compilation/build_system.py @@ -0,0 +1,33 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from typing import Protocol, TypeVar + +from gt4py.next import config +from gt4py.next.otf import code_specs, stages + + +CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) +TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) + + +class BuildSystemProjectGenerator(Protocol[CodeSpecT, TargetCodeSpecT]): + """Factory protocol for build-system implementations. + + Given a :class:`stages.CompilableProject` and a cache lifetime, returns a + :class:`stages.BuildSystemProject` that drives the actual build (e.g. + cmake, compiledb). + """ + + def __call__( + self, + source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], + cache_lifetime: config.BuildCacheLifetime, + ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index 1b79cad6e4..dd9158d8c4 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -18,7 +18,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import config, errors from gt4py.next.otf import code_specs, stages -from gt4py.next.otf.compilation import build_data, cache, common, compiler +from gt4py.next.otf.compilation import build_data, build_system, cache, common from gt4py.next.otf.compilation.build_systems import cmake_lists @@ -64,7 +64,7 @@ def get_cmake_device_arch_option() -> str: @dataclasses.dataclass class CMakeFactory( - compiler.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + build_system.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] ): """Create a CMakeProject from a ``CompilableSource`` stage object with given CMake settings.""" diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 347b0e25e9..756a24ee38 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -20,7 +20,7 @@ from gt4py.next import config, errors from gt4py.next.otf import code_specs, stages from gt4py.next.otf.binding import interface -from gt4py.next.otf.compilation import build_data, cache, compiler +from gt4py.next.otf.compilation import build_data, build_system, cache from gt4py.next.otf.compilation.build_systems import cmake @@ -29,7 +29,7 @@ @dataclasses.dataclass class CompiledbFactory( - compiler.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + build_system.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] ): """ Create a CompiledbProject from a ``CompilableSource`` stage object with given CMake settings. diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index bddf06a1f3..95bacf6dba 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -14,9 +14,10 @@ from gt4py.next import backend, config from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import compiler +from gt4py.next.otf.compilation import build_system from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module +from gt4py.next.program_processors.runners import gtfn_compiler class GTFNBuildWorkflowFactory(factory.Factory): @@ -28,7 +29,7 @@ class Params: cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough lambda: config.CMAKE_BUILD_TYPE ) - builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing not precise enough + builder_factory: build_system.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing not precise enough lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) ) @@ -53,7 +54,7 @@ class Params: nanobind.bind_source ) compilation = factory.SubFactory( - compiler.CompilerFactory, + gtfn_compiler.CompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), device_type=factory.SelfAttribute("..device_type"), diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/program_processors/runners/gtfn_compiler.py similarity index 54% rename from src/gt4py/next/otf/compilation/compiler.py rename to src/gt4py/next/program_processors/runners/gtfn_compiler.py index a858a0b02d..73e683a7fb 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/program_processors/runners/gtfn_compiler.py @@ -10,41 +10,20 @@ import dataclasses import pathlib -from typing import Protocol, TypeVar +from typing import TypeVar import factory -from gt4py._core import definitions as core_defs, locking +from gt4py._core import definitions as core_defs from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow -from gt4py.next.otf.compilation import build_data, cache, importer +from gt4py.next.otf.compilation import build_orchestrator, build_system, importer from gt4py.next.program_processors.runners import gtfn_decoration -T = TypeVar("T") - - -def is_compiled(data: build_data.BuildData) -> bool: - return data.status >= build_data.BuildStatus.COMPILED - - -def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: - return (src_dir / data.module).exists() - - -CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) -TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) CPPLikeCodeSpecT = TypeVar("CPPLikeCodeSpecT", bound=code_specs.CPPLikeCodeSpec) -class BuildSystemProjectGenerator(Protocol[CodeSpecT, TargetCodeSpecT]): - def __call__( - self, - source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], - cache_lifetime: config.BuildCacheLifetime, - ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... - - @dataclasses.dataclass(frozen=True) class GTFNBuildArtifact: """On-disk result of a GTFN compilation: a Python extension module. @@ -87,10 +66,12 @@ class Compiler( ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Use any build system (via configured factory) to compile a GT4Py program into a :class:`GTFNBuildArtifact`.""" + """Drive a build system and wrap the result in a :class:`GTFNBuildArtifact`.""" cache_lifetime: config.BuildCacheLifetime - builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + builder_factory: build_system.BuildSystemProjectGenerator[ + CPPLikeCodeSpecT, code_specs.PythonCodeSpec + ] device_type: core_defs.DeviceType force_recompile: bool = False @@ -98,27 +79,13 @@ def __call__( self, inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ) -> GTFNBuildArtifact: - src_dir = cache.get_cache_folder(inp, self.cache_lifetime) - - # If we are compiling the same program at the same time (e.g. multiple MPI ranks), - # we need to make sure that only one of them accesses the same build directory for compilation. - with locking.lock(src_dir): - data = build_data.read_data(src_dir) - - if not data or not is_compiled(data) or self.force_recompile: - self.builder_factory(inp, self.cache_lifetime).build() - - new_data = build_data.read_data(src_dir) - - if not new_data or not is_compiled(new_data) or not module_exists(new_data, src_dir): - raise CompilationError( - f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." - ) - + result = build_orchestrator.run_build( + inp, self.cache_lifetime, self.builder_factory, self.force_recompile + ) return GTFNBuildArtifact( - src_dir=src_dir, - module=new_data.module, - entry_point_name=new_data.entry_point_name, + src_dir=result.src_dir, + module=result.module, + entry_point_name=result.entry_point_name, device_type=self.device_type, ) @@ -126,6 +93,3 @@ def __call__( class CompilerFactory(factory.Factory): class Meta: model = Compiler - - -class CompilationError(RuntimeError): ... diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 49bd7b8f87..5967c75544 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -10,25 +10,45 @@ import numpy as np +from gt4py._core import definitions as core_defs from gt4py.next import config from gt4py.next.otf import workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import compiler +from gt4py.next.otf.compilation import importer from gt4py.next.otf.compilation.build_systems import cmake, compiledb +from gt4py.next.program_processors.runners import gtfn_compiler from next_tests.unit_tests.otf_tests.compilation_tests.build_systems_tests.conftest import ( program_source_with_name, ) +def _import_artifact_entry_point(artifact: gtfn_compiler.GTFNBuildArtifact): + """Import the .so directly and return the raw entry point. + + Bypasses :meth:`GTFNBuildArtifact.materialize` so the test can call the + nanobind-bound function with raw arguments rather than gt4py-shaped ones — + this is a build-system / binding integration test, not an end-to-end + program test. + """ + m = importer.import_from_path( + artifact.src_dir / artifact.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return getattr(m, artifact.entry_point_name) + + def test_gtfn_cpp_with_cmake(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_cmake") build_the_program = workflow.make_step(nanobind.bind_source).chain( - compiler.Compiler( - cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=cmake.CMakeFactory() + gtfn_compiler.Compiler( + cache_lifetime=config.BuildCacheLifetime.SESSION, + builder_factory=cmake.CMakeFactory(), + device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source) + artifact = build_the_program(example_program_source) + compiled_program = _import_artifact_entry_point(artifact) buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), @@ -42,12 +62,14 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): def test_gtfn_cpp_with_compiledb(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_compiledb") build_the_program = workflow.make_step(nanobind.bind_source).chain( - compiler.Compiler( + gtfn_compiler.Compiler( cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=compiledb.CompiledbFactory(), + device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source) + artifact = build_the_program(example_program_source) + compiled_program = _import_artifact_entry_point(artifact) buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), From 68ead697c2ff57cbd84cf7de4144f8a142a589e4 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 18:39:28 +0200 Subject: [PATCH 10/35] don't serialize/deserialize dace in the same process --- .../runners/dace/workflow/compilation.py | 51 ++++++++++++++++--- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 71b7934fc1..1a374491e7 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -10,12 +10,13 @@ import dataclasses import pathlib +from typing import Optional import dace import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import config +from gt4py.next import config, utils as gtx_utils from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache from gt4py.next.program_processors.runners.dace.workflow import ( @@ -26,7 +27,7 @@ @dataclasses.dataclass(frozen=True) -class DaCeBuildArtifact: +class DaCeBuildArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a DaCe compilation: a build folder + the SDFG bindings.""" build_folder: pathlib.Path @@ -34,11 +35,35 @@ class DaCeBuildArtifact: bind_func_name: str device_type: core_defs.DeviceType + # Process-local cache of the live :class:`CompiledDaceProgram`. Populated by + # ``DaCeCompiler`` after a fresh compile so :meth:`materialize` can skip the + # SDFG re-deserialize + .so re-link round-trip in the same process. Marked + # ``pickle=False`` via :func:`gtx_utils.gt4py_metadata` so a receiver of the + # artifact in a different process sees ``None`` and falls back to the + # disk-based path. + _live_program: Optional[CompiledDaceProgram] = dataclasses.field( + init=False, + default=None, + compare=False, + repr=False, + metadata=gtx_utils.gt4py_metadata(pickle=False), + ) + def materialize(self) -> stages.ExecutableProgram: - """Re-deserialize the SDFG, link the .so, and wrap in gt4py's calling convention. + """Wrap the compiled program in gt4py's calling convention. - Must run in the process that will call the returned program. + Uses the live program cached on the artifact when available; otherwise + re-deserializes the SDFG, re-links the .so via ``compiler.use_cache``, + and caches the result for subsequent calls. Must run in the process + that will call the returned program. """ + program = self._live_program + if program is None: + program = self._load_compiled_program() + object.__setattr__(self, "_live_program", program) + return gtx_wfddecoration.convert_args(program, device=self.device_type) + + def _load_compiled_program(self) -> CompiledDaceProgram: for dump_name in ("program.sdfgz", "program.sdfg"): sdfg_dump = self.build_folder / dump_name if sdfg_dump.exists(): @@ -55,8 +80,7 @@ def materialize(self) -> stages.ExecutableProgram: with dace.config.set_temporary("compiler", "use_cache", value=True): sdfg_program = sdfg.compile(validate=False) - program = CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) - return gtx_wfddecoration.convert_args(program, device=self.device_type) + return CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) @dataclasses.dataclass(frozen=True) @@ -92,15 +116,26 @@ def __call__( sdfg = dace.SDFG.from_json(inp.program_source.source_code) sdfg.build_folder = sdfg_build_folder with locking.lock(sdfg_build_folder): - sdfg.compile(validate=False, return_program_handle=False) + # Keep the program handle so the artifact's materialize() can + # skip the SDFG re-deserialize + .so re-link round-trip when + # used in this same process. + sdfg_program = sdfg.compile(validate=False) assert inp.binding_source is not None - return DaCeBuildArtifact( + artifact = DaCeBuildArtifact( build_folder=pathlib.Path(sdfg_build_folder), binding_source_code=inp.binding_source.source_code, bind_func_name=self.bind_func_name, device_type=self.device_type, ) + object.__setattr__( + artifact, + "_live_program", + CompiledDaceProgram( + sdfg_program, artifact.bind_func_name, artifact.binding_source_code + ), + ) + return artifact class DaCeCompilationStepFactory(factory.Factory): From 4702660ba9889bb8224e38d70b0551def93ebd3d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 19:21:46 +0200 Subject: [PATCH 11/35] don't search for sdfg in materialize --- .../runners/dace/workflow/compilation.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 1a374491e7..2f8365e289 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -31,6 +31,7 @@ class DaCeBuildArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a DaCe compilation: a build folder + the SDFG bindings.""" build_folder: pathlib.Path + sdfg_dump: pathlib.Path binding_source_code: str bind_func_name: str device_type: core_defs.DeviceType @@ -64,16 +65,7 @@ def materialize(self) -> stages.ExecutableProgram: return gtx_wfddecoration.convert_args(program, device=self.device_type) def _load_compiled_program(self) -> CompiledDaceProgram: - for dump_name in ("program.sdfgz", "program.sdfg"): - sdfg_dump = self.build_folder / dump_name - if sdfg_dump.exists(): - break - else: - raise RuntimeError( - f"No SDFG dump (program.sdfgz / program.sdfg) found in '{self.build_folder}'." - ) - - sdfg = dace.SDFG.from_file(str(sdfg_dump)) + sdfg = dace.SDFG.from_file(str(self.sdfg_dump)) sdfg.build_folder = str(self.build_folder) with gtx_wfdcommon.dace_context(device_type=self.device_type): @@ -110,20 +102,30 @@ def __call__( device_type=self.device_type, cmake_build_type=self.cmake_build_type, ): - sdfg_build_folder = gtx_cache.get_cache_folder(inp, self.cache_lifetime) + sdfg_build_folder = pathlib.Path(gtx_cache.get_cache_folder(inp, self.cache_lifetime)) sdfg_build_folder.mkdir(parents=True, exist_ok=True) sdfg = dace.SDFG.from_json(inp.program_source.source_code) - sdfg.build_folder = sdfg_build_folder + sdfg.build_folder = str(sdfg_build_folder) with locking.lock(sdfg_build_folder): # Keep the program handle so the artifact's materialize() can # skip the SDFG re-deserialize + .so re-link round-trip when # used in this same process. sdfg_program = sdfg.compile(validate=False) + for dump_name in ("program.sdfgz", "program.sdfg"): + sdfg_dump = sdfg_build_folder / dump_name + if sdfg_dump.exists(): + break + else: + raise RuntimeError( + f"No SDFG dump (program.sdfgz / program.sdfg) found in '{sdfg_build_folder}'." + ) + assert inp.binding_source is not None artifact = DaCeBuildArtifact( - build_folder=pathlib.Path(sdfg_build_folder), + build_folder=sdfg_build_folder, + sdfg_dump=sdfg_dump, binding_source_code=inp.binding_source.source_code, bind_func_name=self.bind_func_name, device_type=self.device_type, From 77b6235c01fb07987d5f9386b9760c48c947ddd4 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:21:30 +0200 Subject: [PATCH 12/35] restore generic compiler --- .../otf/compilation/build_orchestrator.py | 81 ----------- .../next/otf/compilation/build_system.py | 33 ----- .../otf/compilation/build_systems/cmake.py | 4 +- .../compilation/build_systems/compiledb.py | 4 +- src/gt4py/next/otf/compilation/compiler.py | 135 ++++++++++++++++++ .../next/program_processors/runners/gtfn.py | 9 +- .../runners/gtfn_compiler.py | 95 ------------ .../otf_tests/test_nanobind_build.py | 17 ++- 8 files changed, 155 insertions(+), 223 deletions(-) delete mode 100644 src/gt4py/next/otf/compilation/build_orchestrator.py delete mode 100644 src/gt4py/next/otf/compilation/build_system.py create mode 100644 src/gt4py/next/otf/compilation/compiler.py delete mode 100644 src/gt4py/next/program_processors/runners/gtfn_compiler.py diff --git a/src/gt4py/next/otf/compilation/build_orchestrator.py b/src/gt4py/next/otf/compilation/build_orchestrator.py deleted file mode 100644 index 4aea7b4a0d..0000000000 --- a/src/gt4py/next/otf/compilation/build_orchestrator.py +++ /dev/null @@ -1,81 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -"""Generic build orchestration for backends that produce a Python extension module. - -Wraps the lock + build-data tracking + builder-factory invocation + post-build -validation into a single :func:`run_build` call. Returns a :class:`BuildResult` -descriptor (paths + entry-point name) that backends wrap into their own -artifact type. -""" - -from __future__ import annotations - -import dataclasses -import pathlib -from typing import TypeVar - -from gt4py._core import locking -from gt4py.next import config -from gt4py.next.otf import code_specs, stages -from gt4py.next.otf.compilation import build_data, build_system, cache - - -CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) -TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) - - -@dataclasses.dataclass(frozen=True) -class BuildResult: - """On-disk descriptor of a successful build.""" - - src_dir: pathlib.Path - module: pathlib.Path - entry_point_name: str - - -class CompilationError(RuntimeError): ... - - -def is_compiled(data: build_data.BuildData) -> bool: - return data.status >= build_data.BuildStatus.COMPILED - - -def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: - return (src_dir / data.module).exists() - - -def run_build( - inp: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], - cache_lifetime: config.BuildCacheLifetime, - builder_factory: build_system.BuildSystemProjectGenerator[CodeSpecT, TargetCodeSpecT], - force_recompile: bool = False, -) -> BuildResult: - """Drive ``builder_factory`` to produce a Python extension module on disk.""" - src_dir = cache.get_cache_folder(inp, cache_lifetime) - - # If we are compiling the same program at the same time (e.g. multiple MPI ranks), - # we need to make sure that only one of them accesses the same build directory for compilation. - with locking.lock(src_dir): - data = build_data.read_data(src_dir) - - if not data or not is_compiled(data) or force_recompile: - builder_factory(inp, cache_lifetime).build() - - new_data = build_data.read_data(src_dir) - - if not new_data or not is_compiled(new_data) or not module_exists(new_data, src_dir): - raise CompilationError( - f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." - ) - - return BuildResult( - src_dir=src_dir, - module=new_data.module, - entry_point_name=new_data.entry_point_name, - ) diff --git a/src/gt4py/next/otf/compilation/build_system.py b/src/gt4py/next/otf/compilation/build_system.py deleted file mode 100644 index a7ce6d957e..0000000000 --- a/src/gt4py/next/otf/compilation/build_system.py +++ /dev/null @@ -1,33 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -from typing import Protocol, TypeVar - -from gt4py.next import config -from gt4py.next.otf import code_specs, stages - - -CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) -TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) - - -class BuildSystemProjectGenerator(Protocol[CodeSpecT, TargetCodeSpecT]): - """Factory protocol for build-system implementations. - - Given a :class:`stages.CompilableProject` and a cache lifetime, returns a - :class:`stages.BuildSystemProject` that drives the actual build (e.g. - cmake, compiledb). - """ - - def __call__( - self, - source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], - cache_lifetime: config.BuildCacheLifetime, - ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index dd9158d8c4..1b79cad6e4 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -18,7 +18,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import config, errors from gt4py.next.otf import code_specs, stages -from gt4py.next.otf.compilation import build_data, build_system, cache, common +from gt4py.next.otf.compilation import build_data, cache, common, compiler from gt4py.next.otf.compilation.build_systems import cmake_lists @@ -64,7 +64,7 @@ def get_cmake_device_arch_option() -> str: @dataclasses.dataclass class CMakeFactory( - build_system.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + compiler.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] ): """Create a CMakeProject from a ``CompilableSource`` stage object with given CMake settings.""" diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 756a24ee38..347b0e25e9 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -20,7 +20,7 @@ from gt4py.next import config, errors from gt4py.next.otf import code_specs, stages from gt4py.next.otf.binding import interface -from gt4py.next.otf.compilation import build_data, build_system, cache +from gt4py.next.otf.compilation import build_data, cache, compiler from gt4py.next.otf.compilation.build_systems import cmake @@ -29,7 +29,7 @@ @dataclasses.dataclass class CompiledbFactory( - build_system.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + compiler.BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] ): """ Create a CompiledbProject from a ``CompilableSource`` stage object with given CMake settings. diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py new file mode 100644 index 0000000000..b128416a48 --- /dev/null +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -0,0 +1,135 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import dataclasses +import pathlib +from typing import Callable, Protocol, TypeVar + +import factory + +from gt4py._core import definitions as core_defs, locking +from gt4py.next import config, utils as gtx_utils +from gt4py.next.otf import code_specs, definitions, stages, workflow +from gt4py.next.otf.compilation import build_data, cache, importer + + +CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) +TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) +CPPLikeCodeSpecT = TypeVar("CPPLikeCodeSpecT", bound=code_specs.CPPLikeCodeSpec) + + +class BuildSystemProjectGenerator(Protocol[CodeSpecT, TargetCodeSpecT]): + def __call__( + self, + source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], + cache_lifetime: config.BuildCacheLifetime, + ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... + + +def is_compiled(data: build_data.BuildData) -> bool: + return data.status >= build_data.BuildStatus.COMPILED + + +def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: + return (src_dir / data.module).exists() + + +class CompilationError(RuntimeError): ... + + +# Signature of the per-backend wrapping applied to a freshly imported entry point. +ProgramDecorator = Callable[ + [stages.ExecutableProgram, core_defs.DeviceType], stages.ExecutableProgram +] + + +@dataclasses.dataclass(frozen=True) +class CPPBuildArtifact(gtx_utils.MetadataBasedPickling): + """On-disk result of a CPP-style compilation: a Python extension module. + + Bindings are baked into the .so (e.g. via nanobind), so :meth:`materialize` + is just an ``importlib`` import + entry-point lookup, plus a per-backend + :attr:`decorator` that adapts the raw callable to the backend's calling + convention. + """ + + src_dir: pathlib.Path + module: pathlib.Path + entry_point_name: str + device_type: core_defs.DeviceType + decorator: ProgramDecorator + + def materialize(self) -> stages.ExecutableProgram: + """Import the module and apply the configured per-backend decorator. + + Must run in the process that will call the returned program: the + module is registered in that process's ``sys.modules`` under the + ``gt4py.__compiled_programs__.`` prefix. + """ + m = importer.import_from_path( + self.src_dir / self.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return self.decorator(getattr(m, self.entry_point_name), self.device_type) + + +@dataclasses.dataclass(frozen=True) +class CPPCompiler( + workflow.ChainableWorkflowMixin[ + stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], + CPPBuildArtifact, + ], + workflow.ReplaceEnabledWorkflowMixin[ + stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], + CPPBuildArtifact, + ], + definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], +): + """Drive a CPP-style build system and wrap the result in a :class:`CPPBuildArtifact`.""" + + cache_lifetime: config.BuildCacheLifetime + builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + device_type: core_defs.DeviceType + decorator: ProgramDecorator + force_recompile: bool = False + + def __call__( + self, + inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], + ) -> CPPBuildArtifact: + src_dir = cache.get_cache_folder(inp, self.cache_lifetime) + + # If we are compiling the same program at the same time (e.g. multiple MPI ranks), + # we need to make sure that only one of them accesses the same build directory for compilation. + with locking.lock(src_dir): + data = build_data.read_data(src_dir) + + if not data or not is_compiled(data) or self.force_recompile: + self.builder_factory(inp, self.cache_lifetime).build() + + new_data = build_data.read_data(src_dir) + + if not new_data or not is_compiled(new_data) or not module_exists(new_data, src_dir): + raise CompilationError( + f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." + ) + + return CPPBuildArtifact( + src_dir=src_dir, + module=new_data.module, + entry_point_name=new_data.entry_point_name, + device_type=self.device_type, + decorator=self.decorator, + ) + + +class CompilerFactory(factory.Factory): + class Meta: + model = CPPCompiler diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 95bacf6dba..b7d277a383 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -14,10 +14,10 @@ from gt4py.next import backend, config from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import build_system +from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module -from gt4py.next.program_processors.runners import gtfn_compiler +from gt4py.next.program_processors.runners import gtfn_decoration class GTFNBuildWorkflowFactory(factory.Factory): @@ -29,7 +29,7 @@ class Params: cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough lambda: config.CMAKE_BUILD_TYPE ) - builder_factory: build_system.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing not precise enough + builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing not precise enough lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) ) @@ -54,10 +54,11 @@ class Params: nanobind.bind_source ) compilation = factory.SubFactory( - gtfn_compiler.CompilerFactory, + compiler.CompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), device_type=factory.SelfAttribute("..device_type"), + decorator=gtfn_decoration.convert_args, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn_compiler.py b/src/gt4py/next/program_processors/runners/gtfn_compiler.py deleted file mode 100644 index 73e683a7fb..0000000000 --- a/src/gt4py/next/program_processors/runners/gtfn_compiler.py +++ /dev/null @@ -1,95 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import dataclasses -import pathlib -from typing import TypeVar - -import factory - -from gt4py._core import definitions as core_defs -from gt4py.next import config -from gt4py.next.otf import code_specs, definitions, stages, workflow -from gt4py.next.otf.compilation import build_orchestrator, build_system, importer -from gt4py.next.program_processors.runners import gtfn_decoration - - -CPPLikeCodeSpecT = TypeVar("CPPLikeCodeSpecT", bound=code_specs.CPPLikeCodeSpec) - - -@dataclasses.dataclass(frozen=True) -class GTFNBuildArtifact: - """On-disk result of a GTFN compilation: a Python extension module. - - Bindings are baked into the .so via nanobind, so :meth:`materialize` is - just an ``importlib`` import + entry-point symbol lookup, plus a wrap in - gt4py's calling convention. - """ - - src_dir: pathlib.Path - module: pathlib.Path - entry_point_name: str - device_type: core_defs.DeviceType - - def materialize(self) -> stages.ExecutableProgram: - """Import the module and wrap its entry point in gt4py's calling convention. - - Must run in the process that will call the returned program: the - module is registered in that process's ``sys.modules`` under the - ``gt4py.__compiled_programs__.`` prefix. - """ - m = importer.import_from_path( - self.src_dir / self.module, - sys_modules_prefix="gt4py.__compiled_programs__.", - ) - return gtfn_decoration.convert_args( - getattr(m, self.entry_point_name), device=self.device_type - ) - - -@dataclasses.dataclass(frozen=True) -class Compiler( - workflow.ChainableWorkflowMixin[ - stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - GTFNBuildArtifact, - ], - workflow.ReplaceEnabledWorkflowMixin[ - stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - GTFNBuildArtifact, - ], - definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], -): - """Drive a build system and wrap the result in a :class:`GTFNBuildArtifact`.""" - - cache_lifetime: config.BuildCacheLifetime - builder_factory: build_system.BuildSystemProjectGenerator[ - CPPLikeCodeSpecT, code_specs.PythonCodeSpec - ] - device_type: core_defs.DeviceType - force_recompile: bool = False - - def __call__( - self, - inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - ) -> GTFNBuildArtifact: - result = build_orchestrator.run_build( - inp, self.cache_lifetime, self.builder_factory, self.force_recompile - ) - return GTFNBuildArtifact( - src_dir=result.src_dir, - module=result.module, - entry_point_name=result.entry_point_name, - device_type=self.device_type, - ) - - -class CompilerFactory(factory.Factory): - class Meta: - model = Compiler diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 5967c75544..b4b864a8d2 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -14,19 +14,18 @@ from gt4py.next import config from gt4py.next.otf import workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import importer +from gt4py.next.otf.compilation import compiler, importer from gt4py.next.otf.compilation.build_systems import cmake, compiledb -from gt4py.next.program_processors.runners import gtfn_compiler from next_tests.unit_tests.otf_tests.compilation_tests.build_systems_tests.conftest import ( program_source_with_name, ) -def _import_artifact_entry_point(artifact: gtfn_compiler.GTFNBuildArtifact): +def _import_artifact_entry_point(artifact: compiler.CPPBuildArtifact): """Import the .so directly and return the raw entry point. - Bypasses :meth:`GTFNBuildArtifact.materialize` so the test can call the + Bypasses :meth:`CPPBuildArtifact.materialize` so the test can call the nanobind-bound function with raw arguments rather than gt4py-shaped ones — this is a build-system / binding integration test, not an end-to-end program test. @@ -38,13 +37,18 @@ def _import_artifact_entry_point(artifact: gtfn_compiler.GTFNBuildArtifact): return getattr(m, artifact.entry_point_name) +def _identity(raw, _device): + return raw + + def test_gtfn_cpp_with_cmake(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_cmake") build_the_program = workflow.make_step(nanobind.bind_source).chain( - gtfn_compiler.Compiler( + compiler.CPPCompiler( cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=cmake.CMakeFactory(), device_type=core_defs.DeviceType.CPU, + decorator=_identity, ) ) artifact = build_the_program(example_program_source) @@ -62,10 +66,11 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): def test_gtfn_cpp_with_compiledb(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_compiledb") build_the_program = workflow.make_step(nanobind.bind_source).chain( - gtfn_compiler.Compiler( + compiler.CPPCompiler( cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=compiledb.CompiledbFactory(), device_type=core_defs.DeviceType.CPU, + decorator=_identity, ) ) artifact = build_the_program(example_program_source) From 3869d67443bf4afc7156255e27118f62e60b5563 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:27:58 +0200 Subject: [PATCH 13/35] cleanup --- src/gt4py/next/otf/compilation/compiler.py | 22 ++++++++--------- .../next/program_processors/runners/gtfn.py | 3 +-- .../otf_tests/test_nanobind_build.py | 24 ++++--------------- 3 files changed, 16 insertions(+), 33 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index b128416a48..5063fb1c1f 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -20,6 +20,14 @@ from gt4py.next.otf.compilation import build_data, cache, importer +def is_compiled(data: build_data.BuildData) -> bool: + return data.status >= build_data.BuildStatus.COMPILED + + +def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: + return (src_dir / data.module).exists() + + CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) CPPLikeCodeSpecT = TypeVar("CPPLikeCodeSpecT", bound=code_specs.CPPLikeCodeSpec) @@ -33,17 +41,6 @@ def __call__( ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... -def is_compiled(data: build_data.BuildData) -> bool: - return data.status >= build_data.BuildStatus.COMPILED - - -def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: - return (src_dir / data.module).exists() - - -class CompilationError(RuntimeError): ... - - # Signature of the per-backend wrapping applied to a freshly imported entry point. ProgramDecorator = Callable[ [stages.ExecutableProgram, core_defs.DeviceType], stages.ExecutableProgram @@ -133,3 +130,6 @@ def __call__( class CompilerFactory(factory.Factory): class Meta: model = CPPCompiler + + +class CompilationError(RuntimeError): ... diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index b7d277a383..dc009c376b 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -100,8 +100,7 @@ class Params: run_gtfn = GTFNBackendFactory() run_gtfn_imperative = GTFNBackendFactory( - name_postfix="_imperative", - otf_workflow__translation__use_imperative_backend=True, + name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True ) run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index b4b864a8d2..27ca7c16f1 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -14,7 +14,7 @@ from gt4py.next import config from gt4py.next.otf import workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import compiler, importer +from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import cmake, compiledb from next_tests.unit_tests.otf_tests.compilation_tests.build_systems_tests.conftest import ( @@ -22,22 +22,8 @@ ) -def _import_artifact_entry_point(artifact: compiler.CPPBuildArtifact): - """Import the .so directly and return the raw entry point. - - Bypasses :meth:`CPPBuildArtifact.materialize` so the test can call the - nanobind-bound function with raw arguments rather than gt4py-shaped ones — - this is a build-system / binding integration test, not an end-to-end - program test. - """ - m = importer.import_from_path( - artifact.src_dir / artifact.module, - sys_modules_prefix="gt4py.__compiled_programs__.", - ) - return getattr(m, artifact.entry_point_name) - - def _identity(raw, _device): + """Pass-through decorator: this test calls the nanobind-bound function with raw args.""" return raw @@ -51,8 +37,7 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): decorator=_identity, ) ) - artifact = build_the_program(example_program_source) - compiled_program = _import_artifact_entry_point(artifact) + compiled_program = build_the_program(example_program_source).materialize() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), @@ -73,8 +58,7 @@ def test_gtfn_cpp_with_compiledb(program_source_with_name): decorator=_identity, ) ) - artifact = build_the_program(example_program_source) - compiled_program = _import_artifact_entry_point(artifact) + compiled_program = build_the_program(example_program_source).materialize() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), From 9017b33dcd9431629da712244c5a218e0ffd8d24 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:30:06 +0200 Subject: [PATCH 14/35] more cleanup --- .../runners/dace/program.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index 1435080f52..e1ddbee455 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -76,19 +76,17 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: gt4py_program_args=[p.type for p in program.params], ) - # The executor and the translation stage may each be wrapped in a `CachedStep` - # depending on backend configuration; unwrap when so. - build_workflow = typing.cast( + compile_workflow = typing.cast( recipes.OTFBuildWorkflow, - self.backend.executor.step - if hasattr(self.backend.executor, "step") - else self.backend.executor, - ) + self.backend.executor + if not hasattr(self.backend.executor, "step") + else self.backend.executor.step, + ) # We know which backend we are using, but we don't know if the compile workflow is cached. compile_workflow_translation = ( - build_workflow.translation.step - if hasattr(build_workflow.translation, "step") - else build_workflow.translation - ) + compile_workflow.translation + if not hasattr(compile_workflow.translation, "step") + else compile_workflow.translation.step + ) # Same for the translation stage, which could be a `CachedStep` depending on backend configuration. # TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with # the other parts of the workaround when possible. From 5bffec05e154497216d073a5a2d9f36b1534e860 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:32:39 +0200 Subject: [PATCH 15/35] add tests --- .../compilation_tests/test_compiler.py | 32 +++++++++++++++++ .../dace_tests/test_dace_compilation.py | 36 +++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py new file mode 100644 index 0000000000..250c026920 --- /dev/null +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py @@ -0,0 +1,32 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Minimal contract tests for :class:`compiler.CPPBuildArtifact`.""" + +import pathlib +import pickle + +from gt4py._core import definitions as core_defs +from gt4py.next.otf.compilation import compiler + + +def _identity_decorator(raw, _device): + return raw + + +def test_cpp_build_artifact_pickle_round_trip(): + artifact = compiler.CPPBuildArtifact( + src_dir=pathlib.Path("/tmp/build"), + module=pathlib.Path("entry.so"), + entry_point_name="entry", + device_type=core_defs.DeviceType.CPU, + decorator=_identity_decorator, + ) + restored = pickle.loads(pickle.dumps(artifact)) + assert restored == artifact + assert restored.decorator is _identity_decorator diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py new file mode 100644 index 0000000000..1149f3e131 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -0,0 +1,36 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Minimal contract tests for :class:`compilation.DaCeBuildArtifact`.""" + +import pathlib +import pickle + +import pytest + +pytest.importorskip("dace") + +from gt4py._core import definitions as core_defs # noqa: E402 +from gt4py.next.program_processors.runners.dace.workflow import compilation # noqa: E402 + + +def test_dace_build_artifact_pickle_round_trip_drops_live_program(): + artifact = compilation.DaCeBuildArtifact( + build_folder=pathlib.Path("/tmp/build"), + sdfg_dump=pathlib.Path("/tmp/build/program.sdfgz"), + binding_source_code="def update_sdfg_args(*a, **k): ...", + bind_func_name="update_sdfg_args", + device_type=core_defs.DeviceType.CPU, + ) + object.__setattr__(artifact, "_live_program", "") + + restored = pickle.loads(pickle.dumps(artifact)) + + # The data fields round-trip, the live in-process handle does not. + assert restored == artifact + assert restored._live_program is None From 7f219d7ea6bb449936ec4b94261e5fe3efa57635 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:34:04 +0200 Subject: [PATCH 16/35] cleanup --- src/gt4py/next/program_processors/runners/roundtrip.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index d97b3ab238..7d7075157e 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -216,9 +216,9 @@ def fencil_generator( class RoundtripArtifact: """In-memory artifact for the roundtrip backend. - Roundtrip generates and ``exec``\\ s a Python module per program, so its - output is a live callable rather than something on disk. Not picklable — - roundtrip is in-process only. + Roundtrip generates a Python module per program and executes it directly, + so its output is a live callable rather than something on disk. Not + picklable — roundtrip is in-process only. """ program: stages.ExecutableProgram From 69c5ae86d6a5cfbc10bfd854c64f3a03574f63c5 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:49:26 +0200 Subject: [PATCH 17/35] avoid decoration needs picklable --- src/gt4py/next/otf/compilation/compiler.py | 39 +++--- .../next/program_processors/runners/gtfn.py | 118 +++++++++++++++++- .../runners/gtfn_decoration.py | 105 ---------------- .../otf_tests/test_nanobind_build.py | 7 -- .../compilation_tests/test_compiler.py | 6 - 5 files changed, 134 insertions(+), 141 deletions(-) delete mode 100644 src/gt4py/next/program_processors/runners/gtfn_decoration.py diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 5063fb1c1f..25ce0aba09 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -10,7 +10,7 @@ import dataclasses import pathlib -from typing import Callable, Protocol, TypeVar +from typing import Protocol, TypeVar import factory @@ -41,30 +41,24 @@ def __call__( ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... -# Signature of the per-backend wrapping applied to a freshly imported entry point. -ProgramDecorator = Callable[ - [stages.ExecutableProgram, core_defs.DeviceType], stages.ExecutableProgram -] - - @dataclasses.dataclass(frozen=True) class CPPBuildArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a CPP-style compilation: a Python extension module. - Bindings are baked into the .so (e.g. via nanobind), so :meth:`materialize` - is just an ``importlib`` import + entry-point lookup, plus a per-backend - :attr:`decorator` that adapts the raw callable to the backend's calling - convention. + Bindings are baked into the .so (e.g. via nanobind), so the default + :meth:`materialize` is just an ``importlib`` import + entry-point lookup, + returning the raw imported callable. Backends that need to wrap the + callable in a calling convention (e.g. GTFN's gt4py-shaped argument + conversion) subclass and override :meth:`materialize`. """ src_dir: pathlib.Path module: pathlib.Path entry_point_name: str device_type: core_defs.DeviceType - decorator: ProgramDecorator def materialize(self) -> stages.ExecutableProgram: - """Import the module and apply the configured per-backend decorator. + """Import the .so and return the raw entry point. Must run in the process that will call the returned program: the module is registered in that process's ``sys.modules`` under the @@ -74,7 +68,7 @@ def materialize(self) -> stages.ExecutableProgram: self.src_dir / self.module, sys_modules_prefix="gt4py.__compiled_programs__.", ) - return self.decorator(getattr(m, self.entry_point_name), self.device_type) + return getattr(m, self.entry_point_name) @dataclasses.dataclass(frozen=True) @@ -89,12 +83,15 @@ class CPPCompiler( ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Drive a CPP-style build system and wrap the result in a :class:`CPPBuildArtifact`.""" + """Drive a CPP-style build system and wrap the result in a :class:`CPPBuildArtifact`. + + Backends that need a different artifact subclass (e.g. with a wrapped + ``materialize``) subclass and override :meth:`_make_artifact`. + """ cache_lifetime: config.BuildCacheLifetime builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] device_type: core_defs.DeviceType - decorator: ProgramDecorator force_recompile: bool = False def __call__( @@ -118,12 +115,16 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) + return self._make_artifact(src_dir, new_data.module, new_data.entry_point_name) + + def _make_artifact( + self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str + ) -> CPPBuildArtifact: return CPPBuildArtifact( src_dir=src_dir, - module=new_data.module, - entry_point_name=new_data.entry_point_name, + module=module, + entry_point_name=entry_point_name, device_type=self.device_type, - decorator=self.decorator, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index dc009c376b..fd652c1ee8 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,18 +6,129 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import dataclasses +import pathlib +from typing import Any + import factory +import numpy as np import gt4py._core.definitions as core_defs import gt4py.next.custom_layout_allocators as next_allocators from gt4py._core import filecache -from gt4py.next import backend, config +from gt4py.next import backend, common, config, field_utils +from gt4py.next.embedded import nd_array_field +from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module -from gt4py.next.program_processors.runners import gtfn_decoration + + +def convert_arg(arg: Any) -> Any: + # Note: this function is on the hot path and needs to have minimal overhead. + if (origin := getattr(arg, "__gt_origin__", None)) is not None: + # `Field` is the most likely case, we use `__gt_origin__` as the property is needed anyway + # and (currently) uniquely identifies a `NDArrayField` (which is the only supported `Field`) + assert isinstance(arg, nd_array_field.NdArrayField) + return arg.ndarray, origin + if isinstance(arg, tuple): + return tuple(convert_arg(a) for a in arg) + if isinstance(arg, np.bool_): + # nanobind does not support implicit conversion of `np.bool` to `bool` + return bool(arg) + # TODO(havogt): if this function still appears in profiles, + # we should avoid going through the previous isinstance checks for detecting a scalar. + # E.g. functools.cache on the arg type, returning a function that does the conversion + return arg + + +def convert_args( + inp: stages.ExecutableProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU +) -> stages.ExecutableProgram: + def decorated_program( + *args: Any, + offset_provider: dict[str, common.Connectivity | common.Dimension], + out: Any = None, + ) -> None: + # Note: this function is on the hot path and needs to have minimal overhead. + if out is not None: + args = (*args, out) + converted_args = (convert_arg(arg) for arg in args) + conn_args = extract_connectivity_args(offset_provider, device) + + opt_kwargs: dict[str, Any] = {} + if collect_metrics := metrics.is_level_enabled(metrics.PERFORMANCE): + # If we are collecting metrics, we need to add the `exec_info` argument + # to the `inp` call, which will be used to collect performance metrics. + exec_info: dict[str, float] = {} + opt_kwargs["exec_info"] = exec_info + + # generate implicit domain size arguments only if necessary, using `iter_size_args()` + inp( + *converted_args, + *conn_args, + **opt_kwargs, + ) + + if collect_metrics: + metrics.add_sample_to_current_source( + metrics.COMPUTE_METRIC, exec_info["run_cpp_duration"] + ) + + return decorated_program + + +def extract_connectivity_args( + offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType +) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: + # Note: this function is on the hot path and needs to have minimal overhead. + zero_origin = (0, 0) + assert all( + hasattr(conn, "ndarray") or isinstance(conn, common.Dimension) + for conn in offset_provider.values() + ) + # Note: the order here needs to agree with the order of the generated bindings. + # This is currently true only because when hashing offset provider dicts, + # the keys' order is taken into account. Any modification to the hashing + # of offset providers may break this assumption here. + args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [ + (ndarray, zero_origin) + for conn in offset_provider.values() + if (ndarray := getattr(conn, "ndarray", None)) is not None + ] + assert all( + common.is_neighbor_table(conn) and field_utils.verify_device_field_type(conn, device) + for conn in offset_provider.values() + if hasattr(conn, "ndarray") + ) + + return args + + +@dataclasses.dataclass(frozen=True) +class GTFNBuildArtifact(compiler.CPPBuildArtifact): + def materialize(self) -> stages.ExecutableProgram: + return convert_args(super().materialize(), device=self.device_type) + + +@dataclasses.dataclass(frozen=True) +class GTFNCompiler(compiler.CPPCompiler): + def _make_artifact( + self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str + ) -> GTFNBuildArtifact: + return GTFNBuildArtifact( + src_dir=src_dir, + module=module, + entry_point_name=entry_point_name, + device_type=self.device_type, + ) + + +class GTFNCompilerFactory(factory.Factory): + class Meta: + model = GTFNCompiler class GTFNBuildWorkflowFactory(factory.Factory): @@ -54,11 +165,10 @@ class Params: nanobind.bind_source ) compilation = factory.SubFactory( - compiler.CompilerFactory, + GTFNCompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), device_type=factory.SelfAttribute("..device_type"), - decorator=gtfn_decoration.convert_args, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn_decoration.py b/src/gt4py/next/program_processors/runners/gtfn_decoration.py deleted file mode 100644 index 1ea2b222ca..0000000000 --- a/src/gt4py/next/program_processors/runners/gtfn_decoration.py +++ /dev/null @@ -1,105 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -"""Calling-convention adapter for GTFN-compiled programs. - -Wraps a freshly-imported GTFN entry point with gt4py's user-facing -argument convention: unpacks fields, splits offset_provider into -connectivity args, attaches metric collection. -""" - -from typing import Any - -import numpy as np - -import gt4py._core.definitions as core_defs -from gt4py.next import common, field_utils -from gt4py.next.embedded import nd_array_field -from gt4py.next.instrumentation import metrics -from gt4py.next.otf import stages - - -def convert_arg(arg: Any) -> Any: - # Note: this function is on the hot path and needs to have minimal overhead. - if (origin := getattr(arg, "__gt_origin__", None)) is not None: - # `Field` is the most likely case, we use `__gt_origin__` as the property is needed anyway - # and (currently) uniquely identifies a `NDArrayField` (which is the only supported `Field`) - assert isinstance(arg, nd_array_field.NdArrayField) - return arg.ndarray, origin - if isinstance(arg, tuple): - return tuple(convert_arg(a) for a in arg) - if isinstance(arg, np.bool_): - # nanobind does not support implicit conversion of `np.bool` to `bool` - return bool(arg) - # TODO(havogt): if this function still appears in profiles, - # we should avoid going through the previous isinstance checks for detecting a scalar. - # E.g. functools.cache on the arg type, returning a function that does the conversion - return arg - - -def convert_args( - inp: stages.ExecutableProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU -) -> stages.ExecutableProgram: - def decorated_program( - *args: Any, - offset_provider: dict[str, common.Connectivity | common.Dimension], - out: Any = None, - ) -> None: - # Note: this function is on the hot path and needs to have minimal overhead. - if out is not None: - args = (*args, out) - converted_args = (convert_arg(arg) for arg in args) - conn_args = extract_connectivity_args(offset_provider, device) - - opt_kwargs: dict[str, Any] = {} - if collect_metrics := metrics.is_level_enabled(metrics.PERFORMANCE): - # If we are collecting metrics, we need to add the `exec_info` argument - # to the `inp` call, which will be used to collect performance metrics. - exec_info: dict[str, float] = {} - opt_kwargs["exec_info"] = exec_info - - # generate implicit domain size arguments only if necessary, using `iter_size_args()` - inp( - *converted_args, - *conn_args, - **opt_kwargs, - ) - - if collect_metrics: - metrics.add_sample_to_current_source( - metrics.COMPUTE_METRIC, exec_info["run_cpp_duration"] - ) - - return decorated_program - - -def extract_connectivity_args( - offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType -) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: - # Note: this function is on the hot path and needs to have minimal overhead. - zero_origin = (0, 0) - assert all( - hasattr(conn, "ndarray") or isinstance(conn, common.Dimension) - for conn in offset_provider.values() - ) - # Note: the order here needs to agree with the order of the generated bindings. - # This is currently true only because when hashing offset provider dicts, - # the keys' order is taken into account. Any modification to the hashing - # of offset providers may break this assumption here. - args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [ - (ndarray, zero_origin) - for conn in offset_provider.values() - if (ndarray := getattr(conn, "ndarray", None)) is not None - ] - assert all( - common.is_neighbor_table(conn) and field_utils.verify_device_field_type(conn, device) - for conn in offset_provider.values() - if hasattr(conn, "ndarray") - ) - - return args diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 27ca7c16f1..77db222a11 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -22,11 +22,6 @@ ) -def _identity(raw, _device): - """Pass-through decorator: this test calls the nanobind-bound function with raw args.""" - return raw - - def test_gtfn_cpp_with_cmake(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_cmake") build_the_program = workflow.make_step(nanobind.bind_source).chain( @@ -34,7 +29,6 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=cmake.CMakeFactory(), device_type=core_defs.DeviceType.CPU, - decorator=_identity, ) ) compiled_program = build_the_program(example_program_source).materialize() @@ -55,7 +49,6 @@ def test_gtfn_cpp_with_compiledb(program_source_with_name): cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=compiledb.CompiledbFactory(), device_type=core_defs.DeviceType.CPU, - decorator=_identity, ) ) compiled_program = build_the_program(example_program_source).materialize() diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py index 250c026920..806ea94c93 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py @@ -15,18 +15,12 @@ from gt4py.next.otf.compilation import compiler -def _identity_decorator(raw, _device): - return raw - - def test_cpp_build_artifact_pickle_round_trip(): artifact = compiler.CPPBuildArtifact( src_dir=pathlib.Path("/tmp/build"), module=pathlib.Path("entry.so"), entry_point_name="entry", device_type=core_defs.DeviceType.CPU, - decorator=_identity_decorator, ) restored = pickle.loads(pickle.dumps(artifact)) assert restored == artifact - assert restored.decorator is _identity_decorator From e05a33624ba8fe9fabeac18bb667701f74b44ee0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:52:03 +0200 Subject: [PATCH 18/35] cleanup --- .../runners/dace/workflow/compilation.py | 99 +++++++++++++++- .../runners/dace/workflow/compiled_program.py | 110 ------------------ .../runners/dace/workflow/decoration.py | 11 +- .../runners_tests/dace_tests/test_dace.py | 8 +- 4 files changed, 109 insertions(+), 119 deletions(-) delete mode 100644 src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 2f8365e289..2a70628ae8 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -9,21 +9,114 @@ from __future__ import annotations import dataclasses +import os import pathlib -from typing import Optional +import warnings +from collections.abc import Callable, MutableSequence, Sequence +from typing import Any, Optional import dace import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import config, utils as gtx_utils +from gt4py.next import common, config, utils as gtx_utils from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache from gt4py.next.program_processors.runners.dace.workflow import ( common as gtx_wfdcommon, decoration as gtx_wfddecoration, ) -from gt4py.next.program_processors.runners.dace.workflow.compiled_program import CompiledDaceProgram + + +class CompiledDaceProgram: + sdfg_program: dace.CompiledSDFG + + # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; + # scalar arguments that are not used in the SDFG will not be present. + sdfg_argtypes: list[dace.dtypes.Data] + + # The compiled program contains a callable object to update the SDFG arguments list. + update_sdfg_ctype_arglist: Callable[ + [ + core_defs.DeviceType, + Sequence[dace.dtypes.Data], + Sequence[Any], + MutableSequence[Any], + common.OffsetProvider, + ], + None, + ] + + # Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None` + # means that it has not been initialized, i.e. no call was ever performed. + # - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated. + # - csdfg_init_argv: Arguments used for initialization; used only the first time and + # never updated. + csdfg_argv: MutableSequence[Any] | None + csdfg_init_argv: Sequence[Any] | None + + def __init__( + self, + program: dace.CompiledSDFG, + bind_func_name: str, + binding_source_code: str, + ): + self.sdfg_program = program + + # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument + # name to its data type, in the same order as arguments appear in the program ABI. + # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. + self.sdfg_argtypes = list(program.sdfg.arglist().values()) + + # The binding source code is Python tailored to this specific SDFG. + # We dynamically compile that function and add it to the compiled program. + global_namespace: dict[str, Any] = {} + exec(binding_source_code, global_namespace) + self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] + # For debug purpose, we set a unique module name on the compiled function. + self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) + + # Since the SDFG hasn't been called yet. + self.csdfg_argv = None + self.csdfg_init_argv = None + + def construct_arguments(self, **kwargs: Any) -> None: + """ + This function will process the arguments and store the processed argument + vectors in `self.csdfg_args`, to call them use `self.fast_call()`. + """ + with dace.config.set_temporary("compiler", "allow_view_arguments", value=True): + csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs) + # Note we only care about `csdfg_argv` (normal call), since we have to update it, + # we ensure that it is a `list`. + self.csdfg_argv = [*csdfg_argv] + self.csdfg_init_argv = csdfg_init_argv + + def fast_call(self) -> None: + """ + Perform a call to the compiled SDFG using the previously generated argument + vectors, see `self.construct_arguments()`. + """ + assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, ( + "Argument vector was not set properly." + ) + self.sdfg_program.fast_call( + self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.DEBUG + ) + + def __call__(self, **kwargs: Any) -> None: + """Call the compiled SDFG with the given arguments. + + Note that this function will not update the argument vectors stored inside + `self`. Furthermore, it is not recommended to use this function as it is + very slow. + """ + warnings.warn( + "Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.", + stacklevel=1, + ) + result = self.sdfg_program(**kwargs) + assert result is None @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py b/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py deleted file mode 100644 index 5e28853902..0000000000 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compiled_program.py +++ /dev/null @@ -1,110 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import os -import warnings -from collections.abc import Callable, MutableSequence, Sequence -from typing import Any - -import dace - -from gt4py._core import definitions as core_defs -from gt4py.next import common, config - - -class CompiledDaceProgram: - sdfg_program: dace.CompiledSDFG - - # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; - # scalar arguments that are not used in the SDFG will not be present. - sdfg_argtypes: list[dace.dtypes.Data] - - # The compiled program contains a callable object to update the SDFG arguments list. - update_sdfg_ctype_arglist: Callable[ - [ - core_defs.DeviceType, - Sequence[dace.dtypes.Data], - Sequence[Any], - MutableSequence[Any], - common.OffsetProvider, - ], - None, - ] - - # Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None` - # means that it has not been initialized, i.e. no call was ever performed. - # - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated. - # - csdfg_init_argv: Arguments used for initialization; used only the first time and - # never updated. - csdfg_argv: MutableSequence[Any] | None - csdfg_init_argv: Sequence[Any] | None - - def __init__( - self, - program: dace.CompiledSDFG, - bind_func_name: str, - binding_source_code: str, - ): - self.sdfg_program = program - - # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument - # name to its data type, in the same order as arguments appear in the program ABI. - # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. - self.sdfg_argtypes = list(program.sdfg.arglist().values()) - - # The binding source code is Python tailored to this specific SDFG. - # We dynamically compile that function and add it to the compiled program. - global_namespace: dict[str, Any] = {} - exec(binding_source_code, global_namespace) - self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] - # For debug purpose, we set a unique module name on the compiled function. - self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) - - # Since the SDFG hasn't been called yet. - self.csdfg_argv = None - self.csdfg_init_argv = None - - def construct_arguments(self, **kwargs: Any) -> None: - """ - This function will process the arguments and store the processed argument - vectors in `self.csdfg_args`, to call them use `self.fast_call()`. - """ - with dace.config.set_temporary("compiler", "allow_view_arguments", value=True): - csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs) - # Note we only care about `csdfg_argv` (normal call), since we have to update it, - # we ensure that it is a `list`. - self.csdfg_argv = [*csdfg_argv] - self.csdfg_init_argv = csdfg_init_argv - - def fast_call(self) -> None: - """ - Perform a call to the compiled SDFG using the previously generated argument - vectors, see `self.construct_arguments()`. - """ - assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, ( - "Argument vector was not set properly." - ) - self.sdfg_program.fast_call( - self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.DEBUG - ) - - def __call__(self, **kwargs: Any) -> None: - """Call the compiled SDFG with the given arguments. - - Note that this function will not update the argument vectors stored inside - `self`. Furthermore, it is not recommended to use this function as it is - very slow. - """ - warnings.warn( - "Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.", - stacklevel=1, - ) - result = self.sdfg_program(**kwargs) - assert result is None diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 6b828d5a97..27eb57a82b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -9,7 +9,7 @@ from __future__ import annotations import functools -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any, Sequence import numpy as np @@ -19,7 +19,14 @@ from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon -from gt4py.next.program_processors.runners.dace.workflow.compiled_program import CompiledDaceProgram + + +if TYPE_CHECKING: + # Type-only: evaluating ``compilation`` at module load would create a cycle + # (compilation imports this module for the materialize body). + from gt4py.next.program_processors.runners.dace.workflow.compilation import ( + CompiledDaceProgram, + ) def convert_args( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index 06b3d428bb..a204886690 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -82,7 +82,7 @@ def make_mocks(monkeypatch): # Wrap `compiled_sdfg.CompiledSDFG.fast_call` with mock object mock_fast_call = unittest.mock.MagicMock() gt4py_fast_call = ( - gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram.fast_call + gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.fast_call ) def mocked_fast_call(self): @@ -99,21 +99,21 @@ def mocked_fast_call(self): return fast_call_result monkeypatch.setattr( - gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram, + gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram, "fast_call", mocked_fast_call, ) # Wrap `compiled_sdfg.CompiledSDFG.construct_arguments` with mock object mock_construct_arguments = unittest.mock.MagicMock() - gt4py_construct_arguments = gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram.construct_arguments + gt4py_construct_arguments = gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.construct_arguments def mocked_construct_arguments(self, *args, **kwargs): mock_construct_arguments.__call__(*args, **kwargs) return gt4py_construct_arguments(self, *args, **kwargs) monkeypatch.setattr( - gtx.program_processors.runners.dace.workflow.compiled_program.CompiledDaceProgram, + gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram, "construct_arguments", mocked_construct_arguments, ) From 7e9fa08e2d38f49827476a424e51fa0f824406f1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 21:54:25 +0200 Subject: [PATCH 19/35] cleanup --- .../codegens_tests/gtfn_tests/test_gtfn_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index d027c9dcb1..3d50fbaf52 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -178,7 +178,7 @@ def testee(a: cases.IJKField) -> cases.IJKField: # first call: this generates the cache file cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) - # clearing the OTFCompileWorkflow cache such that the OTFCompileWorkflow step is executed again + # clearing the OTFBuildWorkflow cache such that the OTFBuildWorkflow step is executed again object.__setattr__(cartesian_case.backend.executor, "cache", {}) # second call: the cache file is used cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) From 8f35cb7794813ea154a2f4fc6135873453c400dd Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 22:03:13 +0200 Subject: [PATCH 20/35] Build->Compile, materialize->load --- src/gt4py/next/backend.py | 4 ++-- src/gt4py/next/otf/compilation/compiler.py | 22 +++++++++---------- src/gt4py/next/otf/definitions.py | 12 +++++----- src/gt4py/next/otf/recipes.py | 8 +++---- src/gt4py/next/otf/stages.py | 8 +++---- .../runners/dace/program.py | 2 +- .../runners/dace/workflow/__init__.py | 2 +- .../runners/dace/workflow/compilation.py | 20 ++++++++--------- .../runners/dace/workflow/decoration.py | 2 +- .../runners/dace/workflow/factory.py | 2 +- .../next/program_processors/runners/gtfn.py | 16 +++++++------- .../program_processors/runners/roundtrip.py | 2 +- .../otf_tests/test_nanobind_build.py | 4 ++-- .../compilation_tests/test_compiler.py | 4 ++-- .../otf_tests/test_compiled_program.py | 4 ++-- .../gtfn_tests/test_gtfn_module.py | 2 +- .../dace_tests/test_dace_compilation.py | 4 ++-- 17 files changed, 60 insertions(+), 58 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index b7ad2b2d2c..ae599ece6d 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -147,7 +147,7 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]: @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): name: str - executor: workflow.Workflow[definitions.CompilableProgramDef, stages.BuildArtifact] + executor: workflow.Workflow[definitions.CompilableProgramDef, stages.CompilationArtifact] allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] transforms: workflow.Workflow[definitions.ConcreteProgramDef, definitions.CompilableProgramDef] @@ -157,7 +157,7 @@ def compile( artifact = self.executor( self.transforms(definitions.ConcreteProgramDef(data=program, args=compile_time_args)) ) - return artifact.materialize() + return artifact.load() @property def __gt_allocator__( diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 25ce0aba09..4c0b2681aa 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -42,14 +42,14 @@ def __call__( @dataclasses.dataclass(frozen=True) -class CPPBuildArtifact(gtx_utils.MetadataBasedPickling): +class CPPCompilationArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a CPP-style compilation: a Python extension module. Bindings are baked into the .so (e.g. via nanobind), so the default - :meth:`materialize` is just an ``importlib`` import + entry-point lookup, + :meth:`load` is just an ``importlib`` import + entry-point lookup, returning the raw imported callable. Backends that need to wrap the callable in a calling convention (e.g. GTFN's gt4py-shaped argument - conversion) subclass and override :meth:`materialize`. + conversion) subclass and override :meth:`load`. """ src_dir: pathlib.Path @@ -57,7 +57,7 @@ class CPPBuildArtifact(gtx_utils.MetadataBasedPickling): entry_point_name: str device_type: core_defs.DeviceType - def materialize(self) -> stages.ExecutableProgram: + def load(self) -> stages.ExecutableProgram: """Import the .so and return the raw entry point. Must run in the process that will call the returned program: the @@ -75,18 +75,18 @@ def materialize(self) -> stages.ExecutableProgram: class CPPCompiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - CPPBuildArtifact, + CPPCompilationArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - CPPBuildArtifact, + CPPCompilationArtifact, ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Drive a CPP-style build system and wrap the result in a :class:`CPPBuildArtifact`. + """Drive a CPP-style build system and wrap the result in a :class:`CPPCompilationArtifact`. Backends that need a different artifact subclass (e.g. with a wrapped - ``materialize``) subclass and override :meth:`_make_artifact`. + ``load``) subclass and override :meth:`_make_artifact`. """ cache_lifetime: config.BuildCacheLifetime @@ -97,7 +97,7 @@ class CPPCompiler( def __call__( self, inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - ) -> CPPBuildArtifact: + ) -> CPPCompilationArtifact: src_dir = cache.get_cache_folder(inp, self.cache_lifetime) # If we are compiling the same program at the same time (e.g. multiple MPI ranks), @@ -119,8 +119,8 @@ def __call__( def _make_artifact( self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str - ) -> CPPBuildArtifact: - return CPPBuildArtifact( + ) -> CPPCompilationArtifact: + return CPPCompilationArtifact( src_dir=src_dir, module=module, entry_point_name=entry_point_name, diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index b5d6a0ecfa..6b33465949 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -56,16 +56,18 @@ def __call__( class CompilationStep( - workflow.Workflow[stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.BuildArtifact], + workflow.Workflow[ + stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.CompilationArtifact + ], Protocol[CodeSpecT, TargetCodeSpecT], ): - """Run the build system and produce a :class:`stages.BuildArtifact`. + """Run the build system and produce a :class:`stages.CompilationArtifact`. Each backend defines its own concrete artifact dataclass (frozen, - picklable, self-materializing); they all satisfy the - :class:`stages.BuildArtifact` Protocol structurally. + picklable, with a :meth:`stages.CompilationArtifact.load` method); they all + satisfy the :class:`stages.CompilationArtifact` Protocol structurally. """ def __call__( self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT] - ) -> stages.BuildArtifact: ... + ) -> stages.CompilationArtifact: ... diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index f784a20a12..13a626926d 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -14,10 +14,10 @@ @dataclasses.dataclass(frozen=True) -class OTFBuildWorkflow( - workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.BuildArtifact] +class OTFCompileWorkflow( + workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.CompilationArtifact] ): - """Translation + bindings + build system; ends at a :class:`stages.BuildArtifact`. + """Translation + bindings + build system; ends at a :class:`stages.CompilationArtifact`. Used as :attr:`gt4py.next.backend.Backend.executor`. The ``cached=True`` backend trait wraps it in a :class:`workflow.CachedStep` keyed on @@ -26,4 +26,4 @@ class OTFBuildWorkflow( translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] - compilation: workflow.Workflow[stages.CompilableProject, stages.BuildArtifact] + compilation: workflow.Workflow[stages.CompilableProject, stages.CompilationArtifact] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 4a735a76aa..27ee8b45a6 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -129,17 +129,17 @@ def build(self) -> None: ... ExecutableProgram: TypeAlias = Callable -class BuildArtifact(Protocol): - """The output of an :class:`recipes.OTFBuildWorkflow`. +class CompilationArtifact(Protocol): + """The output of an :class:`recipes.OTFCompileWorkflow`. Each backend defines its own concrete artifact dataclass; all share this Protocol. Implementations are frozen dataclasses, picklable, and have no - live process-bound state — that is reconstructed by :meth:`materialize`, + live process-bound state — that is reconstructed by :meth:`load`, which returns a directly-callable :class:`ExecutableProgram` taking gt4py-shaped arguments. """ - def materialize(self) -> ExecutableProgram: ... + def load(self) -> ExecutableProgram: ... def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index e1ddbee455..f8c8fd84a3 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -77,7 +77,7 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: ) compile_workflow = typing.cast( - recipes.OTFBuildWorkflow, + recipes.OTFCompileWorkflow, self.backend.executor if not hasattr(self.backend.executor, "step") else self.backend.executor.step, diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py index f822709cd2..4d825c0c9b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py @@ -10,7 +10,7 @@ The main module is `backend`, that exports the backends for CPU and GPU devices. The `backend` module uses `factory` to define a workflow that implements the -`OTFBuildWorkflow` recipe. The different stages are implemeted in separate modules: +`OTFCompileWorkflow` recipe. The different stages are implemeted in separate modules: - `translation` for lowering of GTIR to SDFG and applying SDFG transformations - `compilation` for compiling the SDFG into a program - `decoration` to parse the program arguments and pass them to the program call diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 2a70628ae8..ae3fee1540 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -120,7 +120,7 @@ def __call__(self, **kwargs: Any) -> None: @dataclasses.dataclass(frozen=True) -class DaCeBuildArtifact(gtx_utils.MetadataBasedPickling): +class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a DaCe compilation: a build folder + the SDFG bindings.""" build_folder: pathlib.Path @@ -130,8 +130,8 @@ class DaCeBuildArtifact(gtx_utils.MetadataBasedPickling): device_type: core_defs.DeviceType # Process-local cache of the live :class:`CompiledDaceProgram`. Populated by - # ``DaCeCompiler`` after a fresh compile so :meth:`materialize` can skip the - # SDFG re-deserialize + .so re-link round-trip in the same process. Marked + # ``DaCeCompiler`` after a fresh compile so :meth:`load` can skip the SDFG + # re-deserialize + .so re-link round-trip in the same process. Marked # ``pickle=False`` via :func:`gtx_utils.gt4py_metadata` so a receiver of the # artifact in a different process sees ``None`` and falls back to the # disk-based path. @@ -143,7 +143,7 @@ class DaCeBuildArtifact(gtx_utils.MetadataBasedPickling): metadata=gtx_utils.gt4py_metadata(pickle=False), ) - def materialize(self) -> stages.ExecutableProgram: + def load(self) -> stages.ExecutableProgram: """Wrap the compiled program in gt4py's calling convention. Uses the live program cached on the artifact when available; otherwise @@ -172,15 +172,15 @@ def _load_compiled_program(self) -> CompiledDaceProgram: class DaCeCompiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - DaCeBuildArtifact, + DaCeCompilationArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - DaCeBuildArtifact, + DaCeCompilationArtifact, ], definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], ): - """Run the DaCe build system and produce an on-disk :class:`DaCeBuildArtifact`.""" + """Run the DaCe build system and produce an on-disk :class:`DaCeCompilationArtifact`.""" bind_func_name: str cache_lifetime: config.BuildCacheLifetime @@ -190,7 +190,7 @@ class DaCeCompiler( def __call__( self, inp: stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - ) -> DaCeBuildArtifact: + ) -> DaCeCompilationArtifact: with gtx_wfdcommon.dace_context( device_type=self.device_type, cmake_build_type=self.cmake_build_type, @@ -201,7 +201,7 @@ def __call__( sdfg = dace.SDFG.from_json(inp.program_source.source_code) sdfg.build_folder = str(sdfg_build_folder) with locking.lock(sdfg_build_folder): - # Keep the program handle so the artifact's materialize() can + # Keep the program handle so the artifact's load() can # skip the SDFG re-deserialize + .so re-link round-trip when # used in this same process. sdfg_program = sdfg.compile(validate=False) @@ -216,7 +216,7 @@ def __call__( ) assert inp.binding_source is not None - artifact = DaCeBuildArtifact( + artifact = DaCeCompilationArtifact( build_folder=sdfg_build_folder, sdfg_dump=sdfg_dump, binding_source_code=inp.binding_source.source_code, diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 27eb57a82b..8e520856ac 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: # Type-only: evaluating ``compilation`` at module load would create a cycle - # (compilation imports this module for the materialize body). + # (compilation imports this module for the load body). from gt4py.next.program_processors.runners.dace.workflow.compilation import ( CompiledDaceProgram, ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 9f6c80fd07..069854a586 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -30,7 +30,7 @@ class DaCeWorkflowFactory(factory.Factory): class Meta: - model = recipes.OTFBuildWorkflow + model = recipes.OTFCompileWorkflow class Params: auto_optimize: bool = False diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index fd652c1ee8..6a8ff1fc69 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -108,17 +108,17 @@ def extract_connectivity_args( @dataclasses.dataclass(frozen=True) -class GTFNBuildArtifact(compiler.CPPBuildArtifact): - def materialize(self) -> stages.ExecutableProgram: - return convert_args(super().materialize(), device=self.device_type) +class GTFNCompilationArtifact(compiler.CPPCompilationArtifact): + def load(self) -> stages.ExecutableProgram: + return convert_args(super().load(), device=self.device_type) @dataclasses.dataclass(frozen=True) class GTFNCompiler(compiler.CPPCompiler): def _make_artifact( self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str - ) -> GTFNBuildArtifact: - return GTFNBuildArtifact( + ) -> GTFNCompilationArtifact: + return GTFNCompilationArtifact( src_dir=src_dir, module=module, entry_point_name=entry_point_name, @@ -131,9 +131,9 @@ class Meta: model = GTFNCompiler -class GTFNBuildWorkflowFactory(factory.Factory): +class GTFNCompileWorkflowFactory(factory.Factory): class Meta: - model = recipes.OTFBuildWorkflow + model = recipes.OTFCompileWorkflow class Params: device_type: core_defs.DeviceType = core_defs.DeviceType.CPU @@ -195,7 +195,7 @@ class Params: device_type = core_defs.DeviceType.CPU hash_function = stages.compilation_hash otf_workflow = factory.SubFactory( - GTFNBuildWorkflowFactory, device_type=factory.SelfAttribute("..device_type") + GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) name = factory.LazyAttribute( diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 7d7075157e..f076018571 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -223,7 +223,7 @@ class RoundtripArtifact: program: stages.ExecutableProgram - def materialize(self) -> stages.ExecutableProgram: + def load(self) -> stages.ExecutableProgram: return self.program diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 77db222a11..84226c4e03 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -31,7 +31,7 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source).materialize() + compiled_program = build_the_program(example_program_source).load() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), @@ -51,7 +51,7 @@ def test_gtfn_cpp_with_compiledb(program_source_with_name): device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source).materialize() + compiled_program = build_the_program(example_program_source).load() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py index 806ea94c93..42a6687699 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -"""Minimal contract tests for :class:`compiler.CPPBuildArtifact`.""" +"""Minimal contract tests for :class:`compiler.CPPCompilationArtifact`.""" import pathlib import pickle @@ -16,7 +16,7 @@ def test_cpp_build_artifact_pickle_round_trip(): - artifact = compiler.CPPBuildArtifact( + artifact = compiler.CPPCompilationArtifact( src_dir=pathlib.Path("/tmp/build"), module=pathlib.Path("entry.so"), entry_point_name="entry", diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 233e5a2f6e..ed881c9495 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -116,9 +116,9 @@ def test_inlining_of_scalar_works_integration(testee_prog): @dataclasses.dataclass(frozen=True) class _NoOpArtifact: - """A trivial BuildArtifact that materializes to a no-op callable.""" + """A trivial CompilationArtifact that loads to a no-op callable.""" - def materialize(self): + def load(self): return lambda *args, **kwargs: None def pirate(program: toolchain.ConcreteArtifact): diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 3d50fbaf52..d027c9dcb1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -178,7 +178,7 @@ def testee(a: cases.IJKField) -> cases.IJKField: # first call: this generates the cache file cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) - # clearing the OTFBuildWorkflow cache such that the OTFBuildWorkflow step is executed again + # clearing the OTFCompileWorkflow cache such that the OTFCompileWorkflow step is executed again object.__setattr__(cartesian_case.backend.executor, "cache", {}) # second call: the cache file is used cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index 1149f3e131..ec1a926e4a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -"""Minimal contract tests for :class:`compilation.DaCeBuildArtifact`.""" +"""Minimal contract tests for :class:`compilation.DaCeCompilationArtifact`.""" import pathlib import pickle @@ -20,7 +20,7 @@ def test_dace_build_artifact_pickle_round_trip_drops_live_program(): - artifact = compilation.DaCeBuildArtifact( + artifact = compilation.DaCeCompilationArtifact( build_folder=pathlib.Path("/tmp/build"), sdfg_dump=pathlib.Path("/tmp/build/program.sdfgz"), binding_source_code="def update_sdfg_args(*a, **k): ...", From 29508bf421dfd799f1e9174681358b6198bc5139 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 27 Apr 2026 22:09:18 +0200 Subject: [PATCH 21/35] cleanup --- src/gt4py/next/otf/compilation/compiler.py | 12 +++------ src/gt4py/next/otf/recipes.py | 7 +----- .../runners/dace/workflow/compilation.py | 25 ++++++++----------- .../runners/dace/workflow/decoration.py | 3 +-- 4 files changed, 17 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 4c0b2681aa..8f5da88b77 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -45,11 +45,8 @@ def __call__( class CPPCompilationArtifact(gtx_utils.MetadataBasedPickling): """On-disk result of a CPP-style compilation: a Python extension module. - Bindings are baked into the .so (e.g. via nanobind), so the default - :meth:`load` is just an ``importlib`` import + entry-point lookup, - returning the raw imported callable. Backends that need to wrap the - callable in a calling convention (e.g. GTFN's gt4py-shaped argument - conversion) subclass and override :meth:`load`. + The default :meth:`load` is an ``importlib`` import + entry-point lookup; + backends override to apply their own calling convention. """ src_dir: pathlib.Path @@ -83,10 +80,9 @@ class CPPCompiler( ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Drive a CPP-style build system and wrap the result in a :class:`CPPCompilationArtifact`. + """Drive a CPP-style build system into a :class:`CPPCompilationArtifact`. - Backends that need a different artifact subclass (e.g. with a wrapped - ``load``) subclass and override :meth:`_make_artifact`. + Backends override :meth:`_make_artifact` to use their own artifact subclass. """ cache_lifetime: config.BuildCacheLifetime diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 13a626926d..0b809e4731 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -17,12 +17,7 @@ class OTFCompileWorkflow( workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.CompilationArtifact] ): - """Translation + bindings + build system; ends at a :class:`stages.CompilationArtifact`. - - Used as :attr:`gt4py.next.backend.Backend.executor`. The ``cached=True`` - backend trait wraps it in a :class:`workflow.CachedStep` keyed on - :class:`definitions.CompilableProgramDef`. - """ + """The typical compiled backend steps composed into a workflow.""" translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index ae3fee1540..2734a67161 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -13,7 +13,7 @@ import pathlib import warnings from collections.abc import Callable, MutableSequence, Sequence -from typing import Any, Optional +from typing import Any import dace import factory @@ -130,12 +130,11 @@ class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): device_type: core_defs.DeviceType # Process-local cache of the live :class:`CompiledDaceProgram`. Populated by - # ``DaCeCompiler`` after a fresh compile so :meth:`load` can skip the SDFG - # re-deserialize + .so re-link round-trip in the same process. Marked - # ``pickle=False`` via :func:`gtx_utils.gt4py_metadata` so a receiver of the - # artifact in a different process sees ``None`` and falls back to the - # disk-based path. - _live_program: Optional[CompiledDaceProgram] = dataclasses.field( + # ``DaCeCompiler`` to skip the disk round-trip when the artifact stays in + # the same process. Excluded from pickle (``pickle=False`` metadata) so + # receivers in other processes see ``None`` and fall through to the + # disk-based load. + _live_program: CompiledDaceProgram | None = dataclasses.field( init=False, default=None, compare=False, @@ -146,10 +145,9 @@ class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): def load(self) -> stages.ExecutableProgram: """Wrap the compiled program in gt4py's calling convention. - Uses the live program cached on the artifact when available; otherwise - re-deserializes the SDFG, re-links the .so via ``compiler.use_cache``, - and caches the result for subsequent calls. Must run in the process - that will call the returned program. + On a miss, re-deserializes the SDFG and re-links the .so via + ``compiler.use_cache``. Must run in the process that will call the + returned program. """ program = self._live_program if program is None: @@ -201,9 +199,8 @@ def __call__( sdfg = dace.SDFG.from_json(inp.program_source.source_code) sdfg.build_folder = str(sdfg_build_folder) with locking.lock(sdfg_build_folder): - # Keep the program handle so the artifact's load() can - # skip the SDFG re-deserialize + .so re-link round-trip when - # used in this same process. + # Keep the handle so the artifact's load() can skip the disk + # round-trip in the same process. sdfg_program = sdfg.compile(validate=False) for dump_name in ("program.sdfgz", "program.sdfg"): diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 8e520856ac..07707b1f1a 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -22,8 +22,7 @@ if TYPE_CHECKING: - # Type-only: evaluating ``compilation`` at module load would create a cycle - # (compilation imports this module for the load body). + # Type-only: a top-level import would cycle with ``compilation``. from gt4py.next.program_processors.runners.dace.workflow.compilation import ( CompiledDaceProgram, ) From 9f4e776e6f1bfa5dd733e8fe3fdf6715a29a4e1e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 28 Apr 2026 08:23:17 +0200 Subject: [PATCH 22/35] use tmp_path fixture --- .../program_processors/runners/dace/workflow/decoration.py | 4 +--- .../unit_tests/otf_tests/compilation_tests/test_compiler.py | 4 ++-- .../runners_tests/dace_tests/test_dace_compilation.py | 6 +++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 07707b1f1a..f9e9f7181b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -23,9 +23,7 @@ if TYPE_CHECKING: # Type-only: a top-level import would cycle with ``compilation``. - from gt4py.next.program_processors.runners.dace.workflow.compilation import ( - CompiledDaceProgram, - ) + from gt4py.next.program_processors.runners.dace.workflow.compilation import CompiledDaceProgram def convert_args( diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py index 42a6687699..7dbaeaf719 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py @@ -15,9 +15,9 @@ from gt4py.next.otf.compilation import compiler -def test_cpp_build_artifact_pickle_round_trip(): +def test_cpp_compilation_artifact_pickle_round_trip(tmp_path: pathlib.Path): artifact = compiler.CPPCompilationArtifact( - src_dir=pathlib.Path("/tmp/build"), + src_dir=tmp_path, module=pathlib.Path("entry.so"), entry_point_name="entry", device_type=core_defs.DeviceType.CPU, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index ec1a926e4a..acb4ea24a8 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -19,10 +19,10 @@ from gt4py.next.program_processors.runners.dace.workflow import compilation # noqa: E402 -def test_dace_build_artifact_pickle_round_trip_drops_live_program(): +def test_dace_compilation_artifact_pickle_round_trip_drops_live_program(tmp_path: pathlib.Path): artifact = compilation.DaCeCompilationArtifact( - build_folder=pathlib.Path("/tmp/build"), - sdfg_dump=pathlib.Path("/tmp/build/program.sdfgz"), + build_folder=tmp_path, + sdfg_dump=tmp_path / "program.sdfgz", binding_source_code="def update_sdfg_args(*a, **k): ...", bind_func_name="update_sdfg_args", device_type=core_defs.DeviceType.CPU, From 9c9234d82eb453dd6aa9b0144046195f5cd1ac41 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 28 Apr 2026 08:36:12 +0200 Subject: [PATCH 23/35] refactor roundtrip to resepect picklability --- .../program_processors/runners/roundtrip.py | 161 +++++++++--------- 1 file changed, 85 insertions(+), 76 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index f076018571..396eecc173 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -11,11 +11,10 @@ import dataclasses import functools import importlib.util -import pathlib import tempfile import textwrap -import typing -from collections.abc import Callable, Iterable +import types +from collections.abc import Iterable from typing import Any, Optional from gt4py.eve import codegen @@ -106,28 +105,20 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: return f"{node.id} = {_create_tmp(axes, origin, shape, node.dtype)}" -_FENCIL_CACHE: dict[int, Callable] = {} +# Caches the generated source by IR hash so re-codegen is skipped within a process. +_SOURCE_CACHE: dict[int, tuple[str, str]] = {} +# Caches the loaded module by source string so re-exec is skipped within a process. +_MODULE_CACHE: dict[str, types.ModuleType] = {} -def fencil_generator( +def _generate_source( ir: itir.Program, debug: bool, use_embedded: bool, offset_provider: common.OffsetProvider, transforms: itir_transforms.GTIRTransform, -) -> stages.ExecutableProgram: - """ - Generate a directly executable fencil from an ITIR node. - - Arguments: - ir: The iterator IR (ITIR) node. - debug: Keep module source containing fencil implementation. - extract_temporaries: Extract intermediate field values into temporaries. - use_embedded: Directly use builtins from embedded backend instead of - generic dispatcher. Gives faster performance and is easier - to debug. - offset_provider: A mapping from offset names to offset providers. - """ +) -> tuple[str, str]: + """Generate the Python source for an ITIR program. Returns ``(source_code, entry_point_name)``.""" # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism cache_key = hash( @@ -139,10 +130,10 @@ def fencil_generator( tuple(common.offset_provider_to_type(offset_provider).items()), ) ) - if cache_key in _FENCIL_CACHE: + if cache_key in _SOURCE_CACHE: if debug: - print(f"Using cached fencil for key {cache_key}") - return _FENCIL_CACHE[cache_key] # A CompiledProgram is just a Callable + print(f"Using cached source for key {cache_key}") + return _SOURCE_CACHE[cache_key] ir = transforms(ir, offset_provider=offset_provider) @@ -178,53 +169,84 @@ def fencil_generator( """ ) - with tempfile.NamedTemporaryFile( - mode="w", suffix=".py", encoding="utf-8", delete=False - ) as source_file: - source_file_name = source_file.name - if debug: - print(source_file_name) - offset_literals = [f'{o} = offset("{o}")' for o in offset_literals] - axis_literals = [ - f'{o.value} = gtx.Dimension("{o.value}", kind=gtx.DimensionKind("{o.kind}"))' - for o in axis_literals_set - ] - source_file.write(header) - source_file.write("\n".join(offset_literals)) - source_file.write("\n") - source_file.write("\n".join(axis_literals)) - source_file.write("\n") - source_file.write(program) - try: - spec = importlib.util.spec_from_file_location("module.name", source_file_name) - mod = importlib.util.module_from_spec(spec) # type: ignore - spec.loader.exec_module(mod) # type: ignore - finally: - if not debug: - pathlib.Path(source_file_name).unlink(missing_ok=True) + offset_literals_src = "\n".join(f'{o} = offset("{o}")' for o in offset_literals) + axis_literals_src = "\n".join( + f'{o.value} = gtx.Dimension("{o.value}", kind=gtx.DimensionKind("{o.kind}"))' + for o in axis_literals_set + ) + source_code = f"{header}{offset_literals_src}\n{axis_literals_src}\n{program}" assert isinstance(ir, itir.Program) - fencil_name = ir.id - fencil = getattr(mod, fencil_name) + entry_point_name = ir.id + + _SOURCE_CACHE[cache_key] = (source_code, entry_point_name) + return source_code, entry_point_name - _FENCIL_CACHE[cache_key] = fencil - return typing.cast(stages.ExecutableProgram, fencil) +def _load_module(source_code: str, debug: bool) -> types.ModuleType: + if source_code in _MODULE_CACHE: + return _MODULE_CACHE[source_code] + + if debug: + # Write to a real .py so debuggers/tracebacks have file/line info. + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", encoding="utf-8", delete=False + ) as source_file: + source_file.write(source_code) + source_file_name = source_file.name + print(source_file_name) + spec = importlib.util.spec_from_file_location("module.name", source_file_name) + mod = importlib.util.module_from_spec(spec) # type: ignore[arg-type] + spec.loader.exec_module(mod) # type: ignore[union-attr] + else: + mod = types.ModuleType("roundtrip_module") + exec(compile(source_code, "", "exec"), mod.__dict__) + + _MODULE_CACHE[source_code] = mod + return mod @dataclasses.dataclass(frozen=True) class RoundtripArtifact: - """In-memory artifact for the roundtrip backend. + """Source-string artifact for the roundtrip backend. - Roundtrip generates a Python module per program and executes it directly, - so its output is a live callable rather than something on disk. Not - picklable — roundtrip is in-process only. + The generated Python source is the artifact: picklable, re-execed on + :meth:`load`. When ``debug`` is true, ``load`` writes a temporary ``.py`` + so debuggers/tracebacks resolve to source lines. """ - program: stages.ExecutableProgram + source_code: str + entry_point_name: str + column_axis: common.Dimension | None + dispatch_backend: next_backend.Backend | None + debug: bool def load(self) -> stages.ExecutableProgram: - return self.program + mod = _load_module(self.source_code, self.debug) + fencil = getattr(mod, self.entry_point_name) + captured_column_axis = self.column_axis + dispatch_backend = self.dispatch_backend + + def decorated_fencil( + *args: Any, + offset_provider: dict[str, common.Connectivity | common.Dimension], + out: Any = None, + column_axis: Optional[ + common.Dimension + ] = None, # TODO(tehrengruber): unused, kept for signature compat + **kwargs: Any, + ) -> None: + if out is not None: + args = (*args, out) + fencil( + *args, + offset_provider=offset_provider, + backend=dispatch_backend, + column_axis=captured_column_axis, + **kwargs, + ) + + return decorated_fencil @dataclasses.dataclass(frozen=True) @@ -237,7 +259,7 @@ class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, RoundtripArt def __call__(self, inp: definitions.CompilableProgramDef) -> RoundtripArtifact: debug = config.DEBUG if self.debug is None else self.debug - fencil = fencil_generator( + source_code, entry_point_name = _generate_source( inp.data, offset_provider=inp.args.offset_provider, debug=debug, @@ -245,26 +267,13 @@ def __call__(self, inp: definitions.CompilableProgramDef) -> RoundtripArtifact: transforms=self.transforms, ) - def decorated_fencil( - *args: Any, - offset_provider: dict[str, common.Connectivity | common.Dimension], - out: Any = None, - column_axis: Optional[common.Dimension] = None, - **kwargs: Any, - ) -> None: - if out is not None: - args = (*args, out) - if not column_axis: # TODO(tehrengruber): This variable is never used. Bug? - column_axis = inp.args.column_axis - fencil( - *args, - offset_provider=offset_provider, - backend=self.dispatch_backend, - column_axis=inp.args.column_axis, - **kwargs, - ) - - return RoundtripArtifact(program=decorated_fencil) + return RoundtripArtifact( + source_code=source_code, + entry_point_name=entry_point_name, + column_axis=inp.args.column_axis, + dispatch_backend=self.dispatch_backend, + debug=debug, + ) # TODO(tehrengruber): introduce factory From edd9e9372ff90f6fd13f7017f72b35167c1787db Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 28 Apr 2026 12:38:49 +0200 Subject: [PATCH 24/35] sdfg as part of artifact (because of upcoming change) --- .../runners/dace/workflow/compilation.py | 46 ++++++++++--------- .../dace_tests/test_dace_compilation.py | 2 +- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 2734a67161..1f69f1ad71 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -9,6 +9,7 @@ from __future__ import annotations import dataclasses +import json import os import pathlib import warnings @@ -16,6 +17,7 @@ from typing import Any import dace +import dace.codegen.compiler as dace_compiler import factory from gt4py._core import definitions as core_defs, locking @@ -121,10 +123,16 @@ def __call__(self, **kwargs: Any) -> None: @dataclasses.dataclass(frozen=True) class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): - """On-disk result of a DaCe compilation: a build folder + the SDFG bindings.""" + """Result of a DaCe compilation: build folder + SDFG bindings + the SDFG itself. + + The SDFG is carried inline as JSON because dace's load path + (:func:`get_program_handle`) needs an SDFG instance to wrap into the + returned :class:`CompiledSDFG`, and the build folder may not contain a + ``program.sdfg(z)`` dump under the upcoming minimal-build-dir mode. + """ build_folder: pathlib.Path - sdfg_dump: pathlib.Path + sdfg_json: str binding_source_code: str bind_func_name: str device_type: core_defs.DeviceType @@ -145,9 +153,10 @@ class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): def load(self) -> stages.ExecutableProgram: """Wrap the compiled program in gt4py's calling convention. - On a miss, re-deserializes the SDFG and re-links the .so via - ``compiler.use_cache``. Must run in the process that will call the - returned program. + On a miss, loads the precompiled .so directly via + :func:`dace.codegen.compiler.get_program_handle` — no recompilation, + no ``dace.config`` re-entry. Must run in the process that will call + the returned program. """ program = self._live_program if program is None: @@ -156,13 +165,15 @@ def load(self) -> stages.ExecutableProgram: return gtx_wfddecoration.convert_args(program, device=self.device_type) def _load_compiled_program(self) -> CompiledDaceProgram: - sdfg = dace.SDFG.from_file(str(self.sdfg_dump)) - sdfg.build_folder = str(self.build_folder) - - with gtx_wfdcommon.dace_context(device_type=self.device_type): - with dace.config.set_temporary("compiler", "use_cache", value=True): - sdfg_program = sdfg.compile(validate=False) - + # TODO(phimuell): Drop ``sdfg_json`` from the artifact once dace + # exposes a load path that doesn't require an SDFG instance to wrap + # into the returned ``CompiledSDFG``. + sdfg = dace.SDFG.from_json(json.loads(self.sdfg_json)) + folder_version = dace_compiler.get_folder_version(self.build_folder) + library_path = dace_compiler.get_binary_name( + self.build_folder, sdfg_name=sdfg.name, folder_version=folder_version + ) + sdfg_program = dace_compiler.get_program_handle(library_path, sdfg) return CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) @@ -203,19 +214,10 @@ def __call__( # round-trip in the same process. sdfg_program = sdfg.compile(validate=False) - for dump_name in ("program.sdfgz", "program.sdfg"): - sdfg_dump = sdfg_build_folder / dump_name - if sdfg_dump.exists(): - break - else: - raise RuntimeError( - f"No SDFG dump (program.sdfgz / program.sdfg) found in '{sdfg_build_folder}'." - ) - assert inp.binding_source is not None artifact = DaCeCompilationArtifact( build_folder=sdfg_build_folder, - sdfg_dump=sdfg_dump, + sdfg_json=json.dumps(inp.program_source.source_code), binding_source_code=inp.binding_source.source_code, bind_func_name=self.bind_func_name, device_type=self.device_type, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index acb4ea24a8..29d0ded9e1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -22,7 +22,7 @@ def test_dace_compilation_artifact_pickle_round_trip_drops_live_program(tmp_path: pathlib.Path): artifact = compilation.DaCeCompilationArtifact( build_folder=tmp_path, - sdfg_dump=tmp_path / "program.sdfgz", + sdfg_json="{}", binding_source_code="def update_sdfg_args(*a, **k): ...", bind_func_name="update_sdfg_args", device_type=core_defs.DeviceType.CPU, From 57c55ba8cc9502ecfe48f877446ae28a2bcabc5b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 19 Jun 2026 13:13:50 +0200 Subject: [PATCH 25/35] refactor[next]: lift dace live-program cache out of CompilationArtifact Move the in-process CompiledDaceProgram cache from a `pickle=False` field on `DaCeCompilationArtifact` to a module-level dict keyed by build folder. The artifact becomes a plain frozen dataclass with no live process-bound state, restoring the strong `CompilationArtifact` Protocol contract; receivers in another process see an empty cache and fall through to the disk-based `get_program_handle` load path as before. `_live_program` was the only `pickle=False` field in the tree, so the `MetadataBasedPickling` mixin + `_get_metadata_based_state_getstate` helper become dead code and are deleted (re-doing what #2648 had removed; the merge re-introduced them only to serve this one user). `CPPCompilationArtifact` drops the now-no-op `MetadataBasedPickling` parent. `gt4py_metadata` itself stays, still used for the orthogonal `fingerprint=False` axis introduced in #2648. The Protocol docstring is restored to the strong form and explicitly flags `RoundtripArtifact.dispatch_backend` as the one remaining deviation, to be lifted into the runner / load-time seam alongside the staged-compilation API end-state from egparedes' layered- architecture proposal. Co-Authored-By: Claude Opus 4.7 --- src/gt4py/next/otf/compilation/compiler.py | 4 +- src/gt4py/next/otf/stages.py | 14 ++- .../runners/dace/workflow/compilation.py | 34 ++---- src/gt4py/next/utils.py | 106 +----------------- .../dace_tests/test_dace_compilation.py | 66 ++++++++++- 5 files changed, 83 insertions(+), 141 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 8f5da88b77..63b3de42bd 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -15,7 +15,7 @@ import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import config, utils as gtx_utils +from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import build_data, cache, importer @@ -42,7 +42,7 @@ def __call__( @dataclasses.dataclass(frozen=True) -class CPPCompilationArtifact(gtx_utils.MetadataBasedPickling): +class CPPCompilationArtifact: """On-disk result of a CPP-style compilation: a Python extension module. The default :meth:`load` is an ``importlib`` import + entry-point lookup; diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 90c3de48fa..2f8d8f80da 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -133,10 +133,16 @@ class CompilationArtifact(Protocol): """The output of an :class:`recipes.OTFCompileWorkflow`. Each backend defines its own concrete artifact dataclass; all share this - Protocol. Implementations are frozen dataclasses, picklable, and have no - live process-bound state — that is reconstructed by :meth:`load`, - which returns a directly-callable :class:`ExecutableProgram` taking - gt4py-shaped arguments. + Protocol. Implementations are frozen dataclasses, picklable, and carry no + live process-bound state — that is reconstructed by :meth:`load`, which + returns a directly-callable :class:`ExecutableProgram` taking gt4py-shaped + arguments. + + The one current exception is + :class:`gt4py.next.program_processors.runners.roundtrip.RoundtripArtifact` + when it is configured with a ``dispatch_backend``: that field holds a + :class:`gt4py.next.backend.Backend` reference whose role belongs at the + runner / load-time seam, not in the artifact itself. """ def load(self) -> ExecutableProgram: ... diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 992cc5d9c4..37c3785204 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -21,7 +21,7 @@ import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import common, config, utils as gtx_utils +from gt4py.next import common, config from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache from gt4py.next.program_processors.runners.dace.workflow import ( @@ -135,8 +135,11 @@ def __call__(self, **kwargs: Any) -> None: assert result is None +_live_program_cache: dict[pathlib.Path, CompiledDaceProgram] = {} + + @dataclasses.dataclass(frozen=True) -class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): +class DaCeCompilationArtifact: """Result of a DaCe compilation: build folder + SDFG bindings + the SDFG itself. The SDFG is carried inline as JSON because dace's load path @@ -151,19 +154,6 @@ class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): bind_func_name: str device_type: core_defs.DeviceType - # Process-local cache of the live :class:`CompiledDaceProgram`. Populated by - # ``DaCeCompiler`` to skip the disk round-trip when the artifact stays in - # the same process. Excluded from pickle (``pickle=False`` metadata) so - # receivers in other processes see ``None`` and fall through to the - # disk-based load. - _live_program: CompiledDaceProgram | None = dataclasses.field( - init=False, - default=None, - compare=False, - repr=False, - metadata=gtx_utils.gt4py_metadata(pickle=False), - ) - def load(self) -> stages.ExecutableProgram: """Wrap the compiled program in gt4py's calling convention. @@ -172,10 +162,10 @@ def load(self) -> stages.ExecutableProgram: no ``dace.config`` re-entry. Must run in the process that will call the returned program. """ - program = self._live_program + program = _live_program_cache.get(self.build_folder) if program is None: program = self._load_compiled_program() - object.__setattr__(self, "_live_program", program) + _live_program_cache[self.build_folder] = program return gtx_wfddecoration.convert_args(program, device=self.device_type) def _load_compiled_program(self) -> CompiledDaceProgram: @@ -234,8 +224,6 @@ def __call__( sdfg.build_folder = str(sdfg_build_folder) with locking.lock(sdfg_build_folder): - # Keep the handle so the artifact's load() can skip the disk - # round-trip in the same process. sdfg_program = sdfg.compile(validate=False) assert inp.binding_source is not None @@ -246,12 +234,8 @@ def __call__( bind_func_name=self.bind_func_name, device_type=self.device_type, ) - object.__setattr__( - artifact, - "_live_program", - CompiledDaceProgram( - sdfg_program, artifact.bind_func_name, artifact.binding_source_code - ), + _live_program_cache[artifact.build_folder] = CompiledDaceProgram( + sdfg_program, artifact.bind_func_name, artifact.binding_source_code ) return artifact diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 529737b7a8..9dcc911d10 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -8,8 +8,6 @@ from __future__ import annotations -import copyreg -import dataclasses import functools import inspect import itertools @@ -22,14 +20,13 @@ Optional, ParamSpec, Sequence, - TypeAlias, TypeGuard, TypeVar, cast, overload, ) -from gt4py.eve import datamodels, utils as eve_utils +from gt4py.eve import utils as eve_utils _T = TypeVar("_T") @@ -45,110 +42,11 @@ def gt4py_metadata(**kwargs: Any) -> dict[str, dict[str, Any]]: Helper function to store dataclass/datamodel field metadata within a GT4Py namespace. Individual fields can opt out of fingerprinting with - `foo = field(..., metadata=gt4py_metadata(fingerprint=False))`, or out of - pickling with `gt4py_metadata(pickle=False)` (see `MetadataBasedPickling`). + `foo = field(..., metadata=gt4py_metadata(fingerprint=False))`. """ return {GT4PY_CLASS_METADATA_NS: kwargs} -_StandardPickleState = None | dict[str, Any] | tuple[dict[str, Any] | None, dict[str, Any]] -_StandardGetStateMethod: TypeAlias = Callable[[object], _StandardPickleState] -_StandardSetStateMethod: TypeAlias = Callable[[object, _StandardPickleState], None] - - -@functools.cache -def _get_metadata_based_state_getstate(cls: type) -> _StandardGetStateMethod: - """ - Helper function to make class-specific `__getstate__` method following the standard implementation. - """ - if not isinstance(cls, type) or not ( - (is_dataclass := dataclasses.is_dataclass(cls)) or datamodels.is_datamodel(cls) - ): - raise TypeError(f"Expected a dataclass or datamodel type, got '{cls}'") - - if is_dataclass: - field_metadata = { - field.name: field.metadata.get(GT4PY_CLASS_METADATA_NS, None) - for field in dataclasses.fields(cls) - } - else: - field_metadata = { - field.name: field.metadata.get(GT4PY_CLASS_METADATA_NS, None) - for field in datamodels.get_fields(cls).values() - } - - # To gather all slots we need to traverse the whole MRO not just the class itself. - # We reuse the implementation of `copyreg._slotnames` to avoid code duplication and - # potential bugs, even if it is not (unfortunately) part of the module's public API. - class_slots = copyreg._slotnames(cls) # type: ignore[attr-defined] # copyreg._slotnames is not recognized - - has_slots = len(class_slots) > 0 - has_dict = any("__dict__" in c.__dict__ for c in cls.__mro__) - - dict_names = [] - slot_names = [] - for name, metadata in field_metadata.items(): - if metadata and not metadata.get("pickle", True): - continue - if name in class_slots: - slot_names.append(name) - else: - dict_names.append(name) - - # Comply with the default implementation of object.__getstate__() / object.__setstate__(state) - if not (has_slots or has_dict): - # for class instances without __dict__ nor __slots__: state = None. - def __getstate__(self: object) -> _StandardPickleState: - return None - - elif not has_slots: - # for class instances with __dict__ and no __slots__: state = self.__dict__. - def __getstate__(self: object) -> _StandardPickleState: - return {name: getattr(self, name) for name in dict_names} - - elif not has_dict: - # for class instances with __slots__ and no __dict__: state = (None, {slot.name: slot.value for all slots}). - def __getstate__(self: object) -> _StandardPickleState: - return (None, {name: getattr(self, name) for name in slot_names}) - - else: - # for class instances with __dict__ and __slots__: state = (self.__dict__, {slot.name: slot.value for all slots}). - def __getstate__(self: object) -> _StandardPickleState: - return ( - {name: getattr(self, name) for name in dict_names}, - {name: getattr(self, name) for name in slot_names}, - ) - - return __getstate__ - - -class MetadataBasedPickling: - """ - Mixin for adding metadata-based pickling to dataclass-like objects. - - It uses the class field information to select only instance fields which - are not marked with `pickle=False` in the 'GT4PY_META' metadata namespace. - Individual fields can therefore opt out of pickling. - For example: `foo = field(..., metadata=gt4py_metadata(pickle=False))`. - """ - - __slots__ = () # to avoid creation of __dict__ when not needed - - def __getstate__(self) -> _StandardPickleState: - """ - Get the state of the object for pickling. - - It returns the same kind of arguments as the default `__getstate__` - implementation, used by `pickle` (as documented in `object.__getstate__`, - check: https://devdocs.io/python~3.14/library/pickle#object.__getstate__) - """ - return _get_metadata_based_state_getstate(type(self))(self) # type: ignore[arg-type] # type(self) should be hashable - - # Note: we don't implement `__setstate__` as the output of our custom - # `__getstate__` implementation should be compatible with the default - # implementation. - - class RecursionGuard: """ Context manager to guard against inifinite recursion. diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index a216a9b056..c9b549df10 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -8,8 +8,9 @@ """Tests for the compilation stage of the dace backend workflow. -Covers the GPU TX-marker instrumentation and the picklability of -:class:`compilation.DaCeCompilationArtifact`. +Covers the GPU TX-marker instrumentation, the picklability of +:class:`compilation.DaCeCompilationArtifact`, and the process-local +live-program cache that backs its :meth:`load`. """ import contextlib @@ -155,7 +156,7 @@ def test_compiler_skips_tx_markers_for_non_gpu_device(tmp_path): assert compiled_sdfg.instrument == _NONE -def test_dace_compilation_artifact_pickle_round_trip_drops_live_program(tmp_path: pathlib.Path): +def test_dace_compilation_artifact_pickle_round_trip(tmp_path: pathlib.Path): artifact = dace_wf_compilation.DaCeCompilationArtifact( build_folder=tmp_path, sdfg_json="{}", @@ -163,10 +164,63 @@ def test_dace_compilation_artifact_pickle_round_trip_drops_live_program(tmp_path bind_func_name="update_sdfg_args", device_type=core_defs.DeviceType.CPU, ) - object.__setattr__(artifact, "_live_program", "") restored = pickle.loads(pickle.dumps(artifact)) - # The data fields round-trip, the live in-process handle does not. assert restored == artifact - assert restored._live_program is None + + +def test_dace_compilation_artifact_load_uses_live_program_cache(tmp_path: pathlib.Path): + """``load()`` returns the cached live program without touching the disk.""" + artifact = dace_wf_compilation.DaCeCompilationArtifact( + build_folder=tmp_path, + sdfg_json="{}", + binding_source_code="def update_sdfg_args(*a, **k): ...", + bind_func_name="update_sdfg_args", + device_type=core_defs.DeviceType.CPU, + ) + sentinel = mock.MagicMock(name="CompiledDaceProgram") + + with ( + mock.patch.dict( + dace_wf_compilation._live_program_cache, {artifact.build_folder: sentinel}, clear=True + ), + mock.patch.object( + dace_wf_compilation.DaCeCompilationArtifact, "_load_compiled_program" + ) as load_mock, + mock.patch.object(dace_wf_compilation.gtx_wfddecoration, "convert_args") as convert_mock, + ): + result = artifact.load() + + load_mock.assert_not_called() + convert_mock.assert_called_once_with(sentinel, device=core_defs.DeviceType.CPU) + assert result is convert_mock.return_value + + +def test_dace_compilation_artifact_load_falls_back_to_disk_on_cache_miss( + tmp_path: pathlib.Path, +): + """On a cold cache, ``load()`` reloads from disk and warms the cache.""" + artifact = dace_wf_compilation.DaCeCompilationArtifact( + build_folder=tmp_path, + sdfg_json="{}", + binding_source_code="def update_sdfg_args(*a, **k): ...", + bind_func_name="update_sdfg_args", + device_type=core_defs.DeviceType.CPU, + ) + reloaded = mock.MagicMock(name="CompiledDaceProgram") + + with ( + mock.patch.dict(dace_wf_compilation._live_program_cache, {}, clear=True), + mock.patch.object( + dace_wf_compilation.DaCeCompilationArtifact, + "_load_compiled_program", + return_value=reloaded, + ) as load_mock, + mock.patch.object(dace_wf_compilation.gtx_wfddecoration, "convert_args"), + ): + artifact.load() + assert dace_wf_compilation._live_program_cache[artifact.build_folder] is reloaded + # Second call must come from the cache. + artifact.load() + load_mock.assert_called_once() From a3e223468977988a81b76107a928e5d5d61bb069 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 19 Jun 2026 13:18:02 +0200 Subject: [PATCH 26/35] docs: drop reST roles from docstrings; emphasize Google style in AGENTS.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace `:class:` / `:meth:` Sphinx cross-reference roles introduced in the previous commit with plain backtick literals, per CODING_GUIDELINES.md §3.8 (Google style + Sphinx-napoleon; reST roles forbidden). Add an explicit Do entry in AGENTS.md so future agents do not re-introduce them. Co-Authored-By: Claude Opus 4.7 --- AGENTS.md | 7 +++++++ src/gt4py/next/otf/stages.py | 15 +++++++-------- .../dace_tests/test_dace_compilation.py | 4 ++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 7083185fbb..85321d3e5d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -65,6 +65,13 @@ If a command above is wrong for your environment, fix `pyproject.toml`, ## Do - Run `uv run pre-commit run` on staged files before claiming a task done. +- Write docstrings in **Google style** (rendered via Sphinx-napoleon). Use + `Args:` / `Returns:` / `Raises:` / `Examples:` sections — not the reST + `:param:` / `:returns:` / `:raises:` variants. **Do not** use Sphinx + cross-reference roles like `` :class:`Foo` ``, `` :meth:`bar` ``, + `` :func:`baz` `` inside docstrings; the only allowed inline markup is + plain `` `literal` `` backticks and bulleted lists (see + [`CODING_GUIDELINES.md`](CODING_GUIDELINES.md) §3.8). - When touching `gt4py.cartesian`, `gt4py.next`, `gt4py.eve`, `gt4py.storage`, or `gt4py._core`, run the matching `nox -s test_` session before opening the PR. diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 2f8d8f80da..c40a2eaa0f 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -130,19 +130,18 @@ def build(self) -> None: ... class CompilationArtifact(Protocol): - """The output of an :class:`recipes.OTFCompileWorkflow`. + """The output of an ``OTFCompileWorkflow``. Each backend defines its own concrete artifact dataclass; all share this Protocol. Implementations are frozen dataclasses, picklable, and carry no - live process-bound state — that is reconstructed by :meth:`load`, which - returns a directly-callable :class:`ExecutableProgram` taking gt4py-shaped + live process-bound state — that is reconstructed by ``load``, which + returns a directly-callable ``ExecutableProgram`` taking gt4py-shaped arguments. - The one current exception is - :class:`gt4py.next.program_processors.runners.roundtrip.RoundtripArtifact` - when it is configured with a ``dispatch_backend``: that field holds a - :class:`gt4py.next.backend.Backend` reference whose role belongs at the - runner / load-time seam, not in the artifact itself. + The one current exception is ``RoundtripArtifact`` when it is configured + with a ``dispatch_backend``: that field holds a ``Backend`` reference + whose role belongs at the runner / load-time seam, not in the artifact + itself. """ def load(self) -> ExecutableProgram: ... diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index c9b549df10..20cc1c350f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -9,8 +9,8 @@ """Tests for the compilation stage of the dace backend workflow. Covers the GPU TX-marker instrumentation, the picklability of -:class:`compilation.DaCeCompilationArtifact`, and the process-local -live-program cache that backs its :meth:`load`. +``DaCeCompilationArtifact``, and the process-local live-program cache +that backs its ``load`` method. """ import contextlib From 21de820c89a96559502945c378adfc1f3b5c842c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 19 Jun 2026 13:26:13 +0200 Subject: [PATCH 27/35] docs: drop remaining reST roles from PR-introduced docstrings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sweep the docstrings this PR added or modified relative to upstream/main and replace `:class:` / `:meth:` / `:func:` Sphinx roles with plain backtick literals, per CODING_GUIDELINES.md §3.8. Co-Authored-By: Claude Opus 4.7 --- src/gt4py/next/otf/compilation/compiler.py | 6 +++--- src/gt4py/next/otf/definitions.py | 6 +++--- .../runners/dace/workflow/compilation.py | 8 ++++---- src/gt4py/next/program_processors/runners/roundtrip.py | 2 +- .../otf_tests/compilation_tests/test_compiler.py | 2 +- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 63b3de42bd..6048cf82c3 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -45,7 +45,7 @@ def __call__( class CPPCompilationArtifact: """On-disk result of a CPP-style compilation: a Python extension module. - The default :meth:`load` is an ``importlib`` import + entry-point lookup; + The default ``load`` is an ``importlib`` import + entry-point lookup; backends override to apply their own calling convention. """ @@ -80,9 +80,9 @@ class CPPCompiler( ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Drive a CPP-style build system into a :class:`CPPCompilationArtifact`. + """Drive a CPP-style build system into a ``CPPCompilationArtifact``. - Backends override :meth:`_make_artifact` to use their own artifact subclass. + Backends override ``_make_artifact`` to use their own artifact subclass. """ cache_lifetime: config.BuildCacheLifetime diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index 6b33465949..e85e372ebd 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -61,11 +61,11 @@ class CompilationStep( ], Protocol[CodeSpecT, TargetCodeSpecT], ): - """Run the build system and produce a :class:`stages.CompilationArtifact`. + """Run the build system and produce a ``stages.CompilationArtifact``. Each backend defines its own concrete artifact dataclass (frozen, - picklable, with a :meth:`stages.CompilationArtifact.load` method); they all - satisfy the :class:`stages.CompilationArtifact` Protocol structurally. + picklable, with a ``load`` method); they all satisfy the + ``stages.CompilationArtifact`` Protocol structurally. """ def __call__( diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 37c3785204..1bad7e0f0b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -143,8 +143,8 @@ class DaCeCompilationArtifact: """Result of a DaCe compilation: build folder + SDFG bindings + the SDFG itself. The SDFG is carried inline as JSON because dace's load path - (:func:`get_program_handle`) needs an SDFG instance to wrap into the - returned :class:`CompiledSDFG`, and the build folder may not contain a + (``get_program_handle``) needs an SDFG instance to wrap into the + returned ``CompiledSDFG``, and the build folder may not contain a ``program.sdfg(z)`` dump under the upcoming minimal-build-dir mode. """ @@ -158,7 +158,7 @@ def load(self) -> stages.ExecutableProgram: """Wrap the compiled program in gt4py's calling convention. On a miss, loads the precompiled .so directly via - :func:`dace.codegen.compiler.get_program_handle` — no recompilation, + ``dace.codegen.compiler.get_program_handle`` — no recompilation, no ``dace.config`` re-entry. Must run in the process that will call the returned program. """ @@ -193,7 +193,7 @@ class DaCeCompiler( ], definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], ): - """Run the DaCe build system and produce an on-disk :class:`DaCeCompilationArtifact`.""" + """Run the DaCe build system and produce an on-disk ``DaCeCompilationArtifact``.""" bind_func_name: str cache_lifetime: config.BuildCacheLifetime diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 78128c81c3..3b7161c3bd 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -215,7 +215,7 @@ class RoundtripArtifact: """Source-string artifact for the roundtrip backend. The generated Python source is the artifact: picklable, re-execed on - :meth:`load`. When ``debug`` is true, ``load`` writes a temporary ``.py`` + ``load``. When ``debug`` is true, ``load`` writes a temporary ``.py`` so debuggers/tracebacks resolve to source lines. """ diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py index 7dbaeaf719..45abaf86c3 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -"""Minimal contract tests for :class:`compiler.CPPCompilationArtifact`.""" +"""Minimal contract tests for ``CPPCompilationArtifact``.""" import pathlib import pickle From c125a2ae4a042ce548a425c9499d4d72f98ab128 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 25 Jun 2026 13:31:50 +0200 Subject: [PATCH 28/35] refactor[next]: drop Path/str round-trip in DaCeCompiler Keep `sdfg_build_folder` as the str returned by `get_cache_folder` and convert to `pathlib.Path` only where consumers require it. Co-Authored-By: Claude Opus 4.7 --- .../runners/dace/workflow/compilation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 1bad7e0f0b..e549043f2c 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -213,8 +213,8 @@ def __call__( device_type=self.device_type, cmake_build_type=self.cmake_build_type, ): - sdfg_build_folder = pathlib.Path(gtx_cache.get_cache_folder(inp, self.cache_lifetime)) - sdfg_build_folder.mkdir(parents=True, exist_ok=True) + sdfg_build_folder = gtx_cache.get_cache_folder(inp, self.cache_lifetime) + pathlib.Path(sdfg_build_folder).mkdir(parents=True, exist_ok=True) sdfg = dace.SDFG.from_json(inp.program_source.source_code) @@ -222,13 +222,13 @@ def __call__( if self.add_gpu_trace_markers and self.device_type == core_defs.CUPY_DEVICE_TYPE: _add_tx_markers(sdfg) - sdfg.build_folder = str(sdfg_build_folder) + sdfg.build_folder = sdfg_build_folder with locking.lock(sdfg_build_folder): sdfg_program = sdfg.compile(validate=False) assert inp.binding_source is not None artifact = DaCeCompilationArtifact( - build_folder=sdfg_build_folder, + build_folder=pathlib.Path(sdfg_build_folder), sdfg_json=json.dumps(inp.program_source.source_code), binding_source_code=inp.binding_source.source_code, bind_func_name=self.bind_func_name, From 3c656967f768d87c92414687431ffb5e04a60c5b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 25 Jun 2026 13:35:02 +0200 Subject: [PATCH 29/35] refactor[next]: store library_path on DaCeCompilationArtifact MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `load()` previously recomputed the .so path by calling two dace-internal helpers (`get_folder_version`, `get_binary_name`) against `build_folder` and `sdfg.name`. That coupled the load path to dace internals — `get_folder_version` was removed in dace 2.0.0a3, silently breaking artifact loading on that release. Capture `sdfg_program.filename` after `sdfg.compile()` and store it on the artifact. `_load_compiled_program` now goes straight to `get_program_handle(self.library_path, sdfg)`; the two helper calls disappear. Co-Authored-By: Claude Opus 4.7 --- .../runners/dace/workflow/compilation.py | 10 ++++------ .../runners_tests/dace_tests/test_dace_compilation.py | 3 +++ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index e549043f2c..773a2eb496 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -140,7 +140,7 @@ def __call__(self, **kwargs: Any) -> None: @dataclasses.dataclass(frozen=True) class DaCeCompilationArtifact: - """Result of a DaCe compilation: build folder + SDFG bindings + the SDFG itself. + """Result of a DaCe compilation: build folder + library path + SDFG bindings + the SDFG itself. The SDFG is carried inline as JSON because dace's load path (``get_program_handle``) needs an SDFG instance to wrap into the @@ -149,6 +149,7 @@ class DaCeCompilationArtifact: """ build_folder: pathlib.Path + library_path: pathlib.Path sdfg_json: str binding_source_code: str bind_func_name: str @@ -173,11 +174,7 @@ def _load_compiled_program(self) -> CompiledDaceProgram: # exposes a load path that doesn't require an SDFG instance to wrap # into the returned ``CompiledSDFG``. sdfg = dace.SDFG.from_json(json.loads(self.sdfg_json)) - folder_version = dace_compiler.get_folder_version(self.build_folder) - library_path = dace_compiler.get_binary_name( - self.build_folder, sdfg_name=sdfg.name, folder_version=folder_version - ) - sdfg_program = dace_compiler.get_program_handle(library_path, sdfg) + sdfg_program = dace_compiler.get_program_handle(self.library_path, sdfg) return CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) @@ -229,6 +226,7 @@ def __call__( assert inp.binding_source is not None artifact = DaCeCompilationArtifact( build_folder=pathlib.Path(sdfg_build_folder), + library_path=pathlib.Path(sdfg_program.filename), sdfg_json=json.dumps(inp.program_source.source_code), binding_source_code=inp.binding_source.source_code, bind_func_name=self.bind_func_name, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index 20cc1c350f..28cf001090 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -159,6 +159,7 @@ def test_compiler_skips_tx_markers_for_non_gpu_device(tmp_path): def test_dace_compilation_artifact_pickle_round_trip(tmp_path: pathlib.Path): artifact = dace_wf_compilation.DaCeCompilationArtifact( build_folder=tmp_path, + library_path=tmp_path / "build" / "libprogram.so", sdfg_json="{}", binding_source_code="def update_sdfg_args(*a, **k): ...", bind_func_name="update_sdfg_args", @@ -174,6 +175,7 @@ def test_dace_compilation_artifact_load_uses_live_program_cache(tmp_path: pathli """``load()`` returns the cached live program without touching the disk.""" artifact = dace_wf_compilation.DaCeCompilationArtifact( build_folder=tmp_path, + library_path=tmp_path / "build" / "libprogram.so", sdfg_json="{}", binding_source_code="def update_sdfg_args(*a, **k): ...", bind_func_name="update_sdfg_args", @@ -203,6 +205,7 @@ def test_dace_compilation_artifact_load_falls_back_to_disk_on_cache_miss( """On a cold cache, ``load()`` reloads from disk and warms the cache.""" artifact = dace_wf_compilation.DaCeCompilationArtifact( build_folder=tmp_path, + library_path=tmp_path / "build" / "libprogram.so", sdfg_json="{}", binding_source_code="def update_sdfg_args(*a, **k): ...", bind_func_name="update_sdfg_args", From 1a13a63c3ef898a57ae32b52296bdded80b2a4ea Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 25 Jun 2026 13:46:51 +0200 Subject: [PATCH 30/35] docs[next]: explain why DaCeCompilationArtifact has _live_program_cache MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Without the in-process hand-off, the back-to-back `sdfg.compile()` → `artifact.load()` sequence in thread mode hits dace's "library already loaded, renaming file" path: dace renames the .so on disk and dlopens the renamed copy, silently invalidating `library_path` for any later load. Co-Authored-By: Claude Opus 4.7 --- .../program_processors/runners/dace/workflow/compilation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 773a2eb496..5e515771d2 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -135,6 +135,12 @@ def __call__(self, **kwargs: Any) -> None: assert result is None +# Hand off the live `CompiledDaceProgram` from `DaCeCompiler.__call__` to the +# subsequent `DaCeCompilationArtifact.load()` in the same process. Required for +# correctness in thread mode: `sdfg.compile()` dlopens the .so internally, so a +# second `get_program_handle(library_path, ...)` triggers dace's +# "library already loaded, renaming file" path — which renames the .so on disk +# and would invalidate `library_path` for any later load. _live_program_cache: dict[pathlib.Path, CompiledDaceProgram] = {} From c9dc7bead32e92e58260baa3c7a654af0644b791 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 25 Jun 2026 13:50:44 +0200 Subject: [PATCH 31/35] docs[next]: TODO to drop _live_program_cache if dace stops renaming Co-Authored-By: Claude Opus 4.7 --- .../program_processors/runners/dace/workflow/compilation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 5e515771d2..9522db2b6b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -141,6 +141,9 @@ def __call__(self, **kwargs: Any) -> None: # second `get_program_handle(library_path, ...)` triggers dace's # "library already loaded, renaming file" path — which renames the .so on disk # and would invalidate `library_path` for any later load. +# TODO(havogt): drop this hand-off if dace stops renaming the .so on the +# second dlopen of an already-loaded library. The cache would then become +# a pure (modest) optimization and could be reconsidered on its own merits. _live_program_cache: dict[pathlib.Path, CompiledDaceProgram] = {} From 46a25784072cd297f8279153c0db091f6c75411a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 26 Jun 2026 11:20:12 +0200 Subject: [PATCH 32/35] feat[next]: make CompilationArtifact Protocol runtime-checkable Allows ``isinstance(x, CompilationArtifact)`` in tests / asserts. ``runtime_checkable`` on a single-method Protocol only checks for the presence of ``load``, not its signature. Co-Authored-By: Claude Opus 4.7 --- src/gt4py/next/otf/stages.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index c40a2eaa0f..cc631da4a8 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -10,7 +10,16 @@ import dataclasses from collections.abc import Callable -from typing import TYPE_CHECKING, Final, Generic, Optional, Protocol, TypeAlias, TypeVar +from typing import ( + TYPE_CHECKING, + Final, + Generic, + Optional, + Protocol, + TypeAlias, + TypeVar, + runtime_checkable, +) from gt4py.next import common, fingerprinting from gt4py.next.iterator import ir as itir @@ -129,6 +138,7 @@ def build(self) -> None: ... ExecutableProgram: TypeAlias = Callable +@runtime_checkable class CompilationArtifact(Protocol): """The output of an ``OTFCompileWorkflow``. From bce71221640a2ca26ca34f8b88c7b8f53b05727f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 26 Jun 2026 11:20:33 +0200 Subject: [PATCH 33/35] refactor[next]: drop unused CompilerFactory Zero call sites anywhere in src/ or tests/. ``GTFNCompiler`` is built via its own ``GTFNCompilerFactory`` and ``DaCeCompiler`` via ``DaCeCompilationStepFactory``; the base ``CompilerFactory`` was never referenced. Co-Authored-By: Claude Opus 4.7 --- src/gt4py/next/otf/compilation/compiler.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 6048cf82c3..98f247fb72 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -12,8 +12,6 @@ import pathlib from typing import Protocol, TypeVar -import factory - from gt4py._core import definitions as core_defs, locking from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow @@ -124,9 +122,4 @@ def _make_artifact( ) -class CompilerFactory(factory.Factory): - class Meta: - model = CPPCompiler - - class CompilationError(RuntimeError): ... From d8c17896a0c00deb590ba6b6a20b281c74f975d3 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 26 Jun 2026 14:42:59 +0200 Subject: [PATCH 34/35] refactor[next]: skip dlopen at dace compile time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``sdfg.compile(return_program_handle=False)`` builds the .so without loading it; the subsequent ``DaCeCompilationArtifact.load()`` is then a clean first dlopen via ``get_program_handle``. The "library already loaded, renaming file" warning the compile→load sequence used to trigger is gone, and the artifact no longer needs to pre-populate ``_live_program_cache`` to dodge that rename. The cache stays, but for a narrower reason: hand off the live ``CompiledDaceProgram`` (carrying mutable ``csdfg_argv`` state) across ``Backend.compile()`` invocations for the same build folder, so ``fast_call`` doesn't re-run ``construct_arguments`` every time. Compile-time pre-population removed; the cache is now load-time only. ``library_path`` is resolved inside ``dace_context`` (which sets ``build_folder_mode``), since ``get_binary_name`` falls back to the global config when ``folder_mode`` is unset. --- .../runners/dace/workflow/compilation.py | 36 +++++++------------ .../dace_tests/test_dace_compilation.py | 1 - 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 9522db2b6b..a4b01539fd 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -135,15 +135,11 @@ def __call__(self, **kwargs: Any) -> None: assert result is None -# Hand off the live `CompiledDaceProgram` from `DaCeCompiler.__call__` to the -# subsequent `DaCeCompilationArtifact.load()` in the same process. Required for -# correctness in thread mode: `sdfg.compile()` dlopens the .so internally, so a -# second `get_program_handle(library_path, ...)` triggers dace's -# "library already loaded, renaming file" path — which renames the .so on disk -# and would invalidate `library_path` for any later load. -# TODO(havogt): drop this hand-off if dace stops renaming the .so on the -# second dlopen of an already-loaded library. The cache would then become -# a pure (modest) optimization and could be reconsidered on its own merits. +# Share the live ``CompiledDaceProgram`` across all ``load()`` calls for the same +# build folder. ``CompiledDaceProgram.csdfg_argv`` is mutable state populated on +# first call and reused via ``fast_call()``; if a second ``Backend.compile()`` for +# the same program returns a fresh instance, that state is lost and ``fast_call`` +# regresses to ``construct_arguments`` every call. _live_program_cache: dict[pathlib.Path, CompiledDaceProgram] = {} @@ -165,13 +161,6 @@ class DaCeCompilationArtifact: device_type: core_defs.DeviceType def load(self) -> stages.ExecutableProgram: - """Wrap the compiled program in gt4py's calling convention. - - On a miss, loads the precompiled .so directly via - ``dace.codegen.compiler.get_program_handle`` — no recompilation, - no ``dace.config`` re-entry. Must run in the process that will call - the returned program. - """ program = _live_program_cache.get(self.build_folder) if program is None: program = self._load_compiled_program() @@ -230,21 +219,22 @@ def __call__( sdfg.build_folder = sdfg_build_folder with locking.lock(sdfg_build_folder): - sdfg_program = sdfg.compile(validate=False) + sdfg.compile(validate=False, return_program_handle=False) + # ``build_folder_mode`` is set by ``dace_context``; resolve the library + # path here so ``get_binary_name`` sees the same mode dace built under. + library_path = dace_compiler.get_binary_name( + object_folder=sdfg_build_folder, sdfg_name=sdfg.name + ) assert inp.binding_source is not None - artifact = DaCeCompilationArtifact( + return DaCeCompilationArtifact( build_folder=pathlib.Path(sdfg_build_folder), - library_path=pathlib.Path(sdfg_program.filename), + library_path=library_path, sdfg_json=json.dumps(inp.program_source.source_code), binding_source_code=inp.binding_source.source_code, bind_func_name=self.bind_func_name, device_type=self.device_type, ) - _live_program_cache[artifact.build_folder] = CompiledDaceProgram( - sdfg_program, artifact.bind_func_name, artifact.binding_source_code - ) - return artifact class DaCeCompilationStepFactory(factory.Factory): diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index 28cf001090..6c3e18757e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -98,7 +98,6 @@ def _run_compiler( wraps=dace_wf_compilation._add_tx_markers, ) as spy, mock.patch.object(dace.SDFG, "compile", autospec=True) as compile_mock, - mock.patch.object(dace_wf_compilation, "CompiledDaceProgram"), mock.patch.object( dace_wf_compilation.gtx_wfdcommon, "dace_context", From f9e7f25f4ace544c73253386c19c1789ad6046b2 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 26 Jun 2026 15:29:10 +0200 Subject: [PATCH 35/35] refactor[next]: drop _live_program_cache in dace compilation artifact MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After ``return_program_handle=False``, the cache no longer serves the rename-workaround purpose. Its only remaining role was to share ``CompiledDaceProgram.csdfg_argv`` state across ``Backend.compile()`` invocations for the same build folder — a scenario only triggered by calling ``with_grid_type(...).with_backend(...)`` per program invocation, which is an antipattern that creates a fresh ``CompiledProgramsPool`` every call. The two fastcall tests were going through ``cases.verify`` → ``cases.run`` which rebinds on every call, masking the antipattern with the cache. Rewrite them to bind once at the top of the test and call the bound program directly. They now exercise fast_call reuse via the recommended pattern (single bind, many calls) and would catch a real regression. --- .../runners/dace/workflow/compilation.py | 18 +----- .../runners_tests/dace_tests/test_dace.py | 35 ++++++----- .../dace_tests/test_dace_compilation.py | 63 +------------------ 3 files changed, 25 insertions(+), 91 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index a4b01539fd..1c9daf6af2 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -135,14 +135,6 @@ def __call__(self, **kwargs: Any) -> None: assert result is None -# Share the live ``CompiledDaceProgram`` across all ``load()`` calls for the same -# build folder. ``CompiledDaceProgram.csdfg_argv`` is mutable state populated on -# first call and reused via ``fast_call()``; if a second ``Backend.compile()`` for -# the same program returns a fresh instance, that state is lost and ``fast_call`` -# regresses to ``construct_arguments`` every call. -_live_program_cache: dict[pathlib.Path, CompiledDaceProgram] = {} - - @dataclasses.dataclass(frozen=True) class DaCeCompilationArtifact: """Result of a DaCe compilation: build folder + library path + SDFG bindings + the SDFG itself. @@ -161,19 +153,13 @@ class DaCeCompilationArtifact: device_type: core_defs.DeviceType def load(self) -> stages.ExecutableProgram: - program = _live_program_cache.get(self.build_folder) - if program is None: - program = self._load_compiled_program() - _live_program_cache[self.build_folder] = program - return gtx_wfddecoration.convert_args(program, device=self.device_type) - - def _load_compiled_program(self) -> CompiledDaceProgram: # TODO(phimuell): Drop ``sdfg_json`` from the artifact once dace # exposes a load path that doesn't require an SDFG instance to wrap # into the returned ``CompiledSDFG``. sdfg = dace.SDFG.from_json(json.loads(self.sdfg_json)) sdfg_program = dace_compiler.get_program_handle(self.library_path, sdfg) - return CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) + program = CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) + return gtx_wfddecoration.convert_args(program, device=self.device_type) @dataclasses.dataclass(frozen=True) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index 7463a859b1..a73497cb42 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -149,20 +149,26 @@ def testee( mock_fast_call, mock_construct_arguments = make_mocks(monkeypatch) - # Reset mock objects and run/verify GT4Py program + # Bind once: cases.verify would re-call with_grid_type().with_backend() every + # invocation (creating a fresh CompiledProgramsPool each time and dropping + # dace's fast_call csdfg_argv state). Calling the bound program directly is + # the recommended pattern this test is checking the behavior of. + testee_bound = testee.with_grid_type(cartesian_case.grid_type).with_backend( + cartesian_case.backend + ) + def verify_testee(): mock_construct_arguments.reset_mock() mock_fast_call.reset_mock() - cases.verify( - cartesian_case, - testee, + testee_bound( a, a_index, unused_field, *a_offset, out=out, - ref=numpy_ref(a.asnumpy(), *a_offset[0:3]), + offset_provider=cartesian_case.offset_provider, ) + np.testing.assert_allclose(out.asnumpy(), numpy_ref(a.asnumpy(), *a_offset[0:3])) mock_fast_call.assert_called_once() # On first run, the SDFG arguments will have to be constructed @@ -201,18 +207,19 @@ def testee(a: cases.VField) -> cases.EField: mock_fast_call, mock_construct_arguments = make_mocks(monkeypatch) - # Reset mock objects and run/verify GT4Py program + # Bind once: cases.verify would re-call with_grid_type().with_backend() every + # invocation (creating a fresh CompiledProgramsPool each time and dropping + # dace's fast_call csdfg_argv state). Calling the bound program directly is + # the recommended pattern this test is checking the behavior of. + testee_bound = testee.with_grid_type(unstructured_case.grid_type).with_backend( + unstructured_case.backend + ) + def verify_testee(): mock_construct_arguments.reset_mock() mock_fast_call.reset_mock() - cases.verify( - unstructured_case, - testee, - a, - **kwfields, - offset_provider=unstructured_case.offset_provider, - ref=numpy_ref(a.asnumpy()), - ) + testee_bound(a, **kwfields, offset_provider=unstructured_case.offset_provider) + np.testing.assert_allclose(kwfields["out"].asnumpy(), numpy_ref(a.asnumpy())) mock_fast_call.assert_called_once() verify_testee() diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index 6c3e18757e..488c371193 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -8,9 +8,8 @@ """Tests for the compilation stage of the dace backend workflow. -Covers the GPU TX-marker instrumentation, the picklability of -``DaCeCompilationArtifact``, and the process-local live-program cache -that backs its ``load`` method. +Covers the GPU TX-marker instrumentation and the picklability of +``DaCeCompilationArtifact``. """ import contextlib @@ -168,61 +167,3 @@ def test_dace_compilation_artifact_pickle_round_trip(tmp_path: pathlib.Path): restored = pickle.loads(pickle.dumps(artifact)) assert restored == artifact - - -def test_dace_compilation_artifact_load_uses_live_program_cache(tmp_path: pathlib.Path): - """``load()`` returns the cached live program without touching the disk.""" - artifact = dace_wf_compilation.DaCeCompilationArtifact( - build_folder=tmp_path, - library_path=tmp_path / "build" / "libprogram.so", - sdfg_json="{}", - binding_source_code="def update_sdfg_args(*a, **k): ...", - bind_func_name="update_sdfg_args", - device_type=core_defs.DeviceType.CPU, - ) - sentinel = mock.MagicMock(name="CompiledDaceProgram") - - with ( - mock.patch.dict( - dace_wf_compilation._live_program_cache, {artifact.build_folder: sentinel}, clear=True - ), - mock.patch.object( - dace_wf_compilation.DaCeCompilationArtifact, "_load_compiled_program" - ) as load_mock, - mock.patch.object(dace_wf_compilation.gtx_wfddecoration, "convert_args") as convert_mock, - ): - result = artifact.load() - - load_mock.assert_not_called() - convert_mock.assert_called_once_with(sentinel, device=core_defs.DeviceType.CPU) - assert result is convert_mock.return_value - - -def test_dace_compilation_artifact_load_falls_back_to_disk_on_cache_miss( - tmp_path: pathlib.Path, -): - """On a cold cache, ``load()`` reloads from disk and warms the cache.""" - artifact = dace_wf_compilation.DaCeCompilationArtifact( - build_folder=tmp_path, - library_path=tmp_path / "build" / "libprogram.so", - sdfg_json="{}", - binding_source_code="def update_sdfg_args(*a, **k): ...", - bind_func_name="update_sdfg_args", - device_type=core_defs.DeviceType.CPU, - ) - reloaded = mock.MagicMock(name="CompiledDaceProgram") - - with ( - mock.patch.dict(dace_wf_compilation._live_program_cache, {}, clear=True), - mock.patch.object( - dace_wf_compilation.DaCeCompilationArtifact, - "_load_compiled_program", - return_value=reloaded, - ) as load_mock, - mock.patch.object(dace_wf_compilation.gtx_wfddecoration, "convert_args"), - ): - artifact.load() - assert dace_wf_compilation._live_program_cache[artifact.build_folder] is reloaded - # Second call must come from the cache. - artifact.load() - load_mock.assert_called_once()