From 9798dc577fed935c7e0ac4ff8842e635874968dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Mon, 22 Jun 2026 18:35:20 +0200 Subject: [PATCH 1/3] refactor[next]: replace factory-boy with plain backend builders factory-boy was repurposed as a production composition mechanism for the GTFN and DaCe backends and their OTF compilation workflows. Every model is already a plain frozen dataclass, so the Trait/SubFactory/SelfAttribute/ LazyAttribute machinery and the stringly-typed `__`-path overrides added indirection without value, and carried `type: ignore`s plus a note that factory-boy is broken. This replaces all production `factory.Factory` classes with plain builder functions or direct dataclass construction: - Leaf steps (Compiler, GTFNTranslationStep, DaCeTranslator, DaCeCompiler) are now constructed directly. - GTFN: `make_gtfn_workflow(...)` / `make_gtfn_backend(...)`; Traits became bool kwargs (gpu, cached) and `__`-paths became real parameters. - DaCe: `make_dace_workflow(...)`; `DaCeBackendFactory` folded into the existing `make_dace_backend(...)` (public signature unchanged). All public pre-built backends keep identical names and wiring (run_gtfn_*, run_dace_*, gtfn_cpu/gtfn_gpu). The few tests that called the production factories were migrated to the new builders. factory-boy moves from runtime `dependencies` to the `test` group (still used by the cartesian/eve IR test-data factories). Advanced docs updated accordingly. Aligned with the layered-architecture refactoring proposal ("plain-code backend builders, not factory-boy"); the Backend->toolchain rename and the pipeline collapse are intentionally left for later. --- docs/user/next/advanced/HackTheToolchain.md | 22 +-- docs/user/next/advanced/WorkflowPatterns.md | 2 +- pyproject.toml | 8 +- src/gt4py/next/otf/compilation/compiler.py | 7 - .../codegens/gtfn/gtfn_module.py | 10 +- .../program_processors/formatters/gtfn.py | 2 +- .../runners/dace/workflow/backend.py | 99 ++++------ .../runners/dace/workflow/compilation.py | 6 - .../runners/dace/workflow/factory.py | 108 ++++++----- .../runners/dace/workflow/translation.py | 6 - .../next/program_processors/runners/gtfn.py | 170 +++++++++--------- .../otf_tests/test_compiled_program.py | 2 +- .../gtfn_tests/test_gtfn_module.py | 12 +- .../runners_tests/test_gtfn.py | 14 +- uv.lock | 6 +- 15 files changed, 210 insertions(+), 264 deletions(-) diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md index 785cc0b24d..2e5ee4821e 100644 --- a/docs/user/next/advanced/HackTheToolchain.md +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -46,7 +46,11 @@ skip_linting_transforms = SkipLinting(**same_steps) skip_linting_transforms.step_order(DUMMY_FOP) ``` -## Alternative Factory +## Alternative Workflow Steps + +The compiled GTFN workflow is a frozen dataclass built by `make_gtfn_workflow`. +To swap individual steps, build the default workflow and use `dataclasses.replace` +to override the steps you care about. ```python class MyCodeGen: ... @@ -55,16 +59,12 @@ class MyCodeGen: ... class Cpp2BindingsGen: ... -class PureCpp2WorkflowFactory(gtx.program_processors.runners.gtfn.GTFNCompileWorkflowFactory): - translation: workflow.Workflow[ - gtx.otf.definitions.CompilableProgramDef, gtx.otf.stages.ProgramSource - ] = MyCodeGen() - bindings: workflow.Workflow[gtx.otf.stages.ProgramSource, gtx.otf.stages.CompilableProject] = ( - Cpp2BindingsGen() - ) - - -PureCpp2WorkflowFactory(cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG) +base_workflow = gtx.program_processors.runners.gtfn.make_gtfn_workflow( + cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG +) +pure_cpp2_workflow = dataclasses.replace( + base_workflow, translation=MyCodeGen(), bindings=Cpp2BindingsGen() +) ``` ## Invent new Workflow Types diff --git a/docs/user/next/advanced/WorkflowPatterns.md b/docs/user/next/advanced/WorkflowPatterns.md index e65114a479..eb320cbe51 100644 --- a/docs/user/next/advanced/WorkflowPatterns.md +++ b/docs/user/next/advanced/WorkflowPatterns.md @@ -491,5 +491,5 @@ gtx.program_processors.runners.gtfn.run_gtfn_gpu.executor.otf_workflow?? ``` ```python -gtx.program_processors.runners.gtfn.GTFNBackendFactory?? +gtx.program_processors.runners.gtfn.make_gtfn_backend?? ``` diff --git a/pyproject.toml b/pyproject.toml index ec15e98106..7180c4da8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ profiling = [ scripts = ["pyyaml>=6.0.1", "typer>=0.12.3", "packaging"] test = [ 'coverage[toml]>=7.6.1', + 'factory-boy>=3.3.3', 'hypothesis>=6.0.0', 'nbmake>=1.4.6', 'nox>=2025.02.09', @@ -103,7 +104,6 @@ dependencies = [ 'dace>=2.0.0a3', 'deepdiff>=8.1.0', 'devtools>=0.6', - 'factory-boy>=3.3.3', "filelock>=3.18.0", 'frozendict>=2.3', 'gridtools-cpp>=2.3.9,==2.*', @@ -263,12 +263,6 @@ module = 'gt4py.next.iterator.*' ignore_errors = true module = 'gt4py.next.iterator.runtime' -[[tool.mypy.overrides]] -ignore_missing_imports = true -implicit_reexport = true -# factory-boy is broken, see https://github.com/FactoryBoy/factory_boy/pull/1114 -module = "factory.*" - [[tool.mypy.overrides]] disallow_incomplete_defs = false disallow_untyped_defs = false diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 3748d95192..42e7bd5545 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -12,8 +12,6 @@ import pathlib from typing import Protocol, TypeVar -import factory - from gt4py._core import locking from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow @@ -91,9 +89,4 @@ def __call__( return func -class CompilerFactory(factory.Factory): - class Meta: - model = Compiler - - class CompilationError(RuntimeError): ... diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index a452bf53fe..c525271088 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -12,7 +12,6 @@ import functools from typing import Any, Final, Optional -import factory import numpy as np from gt4py._core import definitions as core_defs @@ -284,13 +283,8 @@ def _not_implemented_for_device_type(self) -> NotImplementedError: ) -class GTFNTranslationStepFactory(factory.Factory[GTFNTranslationStep]): - class Meta: - model = GTFNTranslationStep +translate_program_cpu: Final[definitions.TranslationStep] = GTFNTranslationStep() - -translate_program_cpu: Final[definitions.TranslationStep] = GTFNTranslationStepFactory() # type: ignore[assignment] # factory-boy typing not precise enough - -translate_program_gpu: Final[definitions.TranslationStep] = GTFNTranslationStepFactory( # type: ignore[assignment] # factory-boy typing not precise enough +translate_program_gpu: Final[definitions.TranslationStep] = GTFNTranslationStep( device_type=core_defs.DeviceType.CUDA ) diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index 1d65b8d8d0..9d2e610ab2 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -17,7 +17,7 @@ @program_formatter.program_formatter def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str: # TODO(tehrengruber): This is a little ugly. Revisit. - gtfn_translation = gtfn.GTFNBackendFactory().executor.translation # type: ignore[attr-defined] + gtfn_translation = gtfn.make_gtfn_workflow().translation assert isinstance(gtfn_translation, GTFNTranslationStep) return gtfn_translation.generate_stencil_source( program, diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index 18ba132324..74c8913b1b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -11,61 +11,11 @@ import warnings from typing import Any, Final -import factory - import gt4py.next.custom_layout_allocators as next_allocators from gt4py._core import definitions as core_defs from gt4py.next import backend, common, config from gt4py.next.otf import stages, workflow -from gt4py.next.program_processors.runners.dace.workflow.factory import DaCeWorkflowFactory - - -class DaCeBackendFactory(factory.Factory): - """ - Workflow factory for the GTIR-DaCe backend. - - Several parameters are inherithed from `backend.Backend`, see below the specific ones. - - Args: - auto_optimize: Enables the SDFG transformation pipeline. - """ - - class Meta: - model = backend.Backend - - class Params: - name_device = "cpu" - name_cached = "" - name_postfix = "" - gpu = factory.Trait( - allocator=next_allocators.StandardGPUFieldBufferAllocator(), - device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA, - name_device="gpu", - ) - cached = factory.Trait( - executor=factory.LazyAttribute( - lambda o: workflow.CachedStep.in_memory( - o.otf_workflow, input_fingerprinter=o.key_function - ) - ), - name_cached="_cached", - ) - device_type = core_defs.DeviceType.CPU - key_function = stages.fast_compilable_program_fingerprinter - otf_workflow = factory.SubFactory( - DaCeWorkflowFactory, - device_type=factory.SelfAttribute("..device_type"), - auto_optimize=factory.SelfAttribute("..auto_optimize"), - ) - auto_optimize = factory.Trait(name_postfix="_opt") - - name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_cached}{o.name_postfix}" - ) - - executor = factory.LazyAttribute(lambda o: o.otf_workflow) - allocator = next_allocators.StandardCPUFieldBufferAllocator() - transforms = backend.DEFAULT_TRANSFORMS +from gt4py.next.program_processors.runners.dace.workflow.factory import make_dace_workflow def make_dace_backend( @@ -125,17 +75,44 @@ def make_dace_backend( else None } - return DaCeBackendFactory( # type: ignore[return-value] # factory-boy typing not precise enough - gpu=gpu, - cached=cached, + allocator: next_allocators.FieldBufferAllocatorProtocol + device_type: core_defs.DeviceType + if gpu: + allocator = next_allocators.StandardGPUFieldBufferAllocator() + device_type = core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA + name_device = "gpu" + else: + allocator = next_allocators.StandardCPUFieldBufferAllocator() + device_type = core_defs.DeviceType.CPU + name_device = "cpu" + + otf_workflow = make_dace_workflow( + device_type=device_type, auto_optimize=auto_optimize, - otf_workflow__cached_translation=cached, - otf_workflow__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False), - otf_workflow__bare_translation__auto_optimize_args=optimization_args, - otf_workflow__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, - otf_workflow__bare_translation__use_metrics=use_metrics, - otf_workflow__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin, - otf_workflow__bare_translation__use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, + cached_translation=cached, + async_sdfg_call=(async_sdfg_call if gpu else False), + auto_optimize_args=optimization_args, + unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, + use_metrics=use_metrics, + disable_field_origin_on_program_arguments=use_zero_origin, + use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, + ) + + executor = ( + workflow.CachedStep.in_memory( + otf_workflow, input_fingerprinter=stages.fast_compilable_program_fingerprinter + ) + if cached + else otf_workflow + ) + + name_cached = "_cached" if cached else "" + name_postfix = "_opt" if auto_optimize else "" + return backend.Backend( + name=f"run_dace_{name_device}{name_cached}{name_postfix}", + executor=executor, + allocator=allocator, + transforms=backend.DEFAULT_TRANSFORMS, ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index b8e18382d8..9aeda9cc82 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -15,7 +15,6 @@ from typing import Any import dace -import factory from gt4py._core import definitions as core_defs, locking from gt4py.next import common, config @@ -179,8 +178,3 @@ def __call__( self.bind_func_name, inp.binding_source, ) - - -class DaCeCompilationStepFactory(factory.Factory): - class Meta: - model = DaCeCompiler diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index e473e1493c..205fd840f3 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -9,78 +9,74 @@ from __future__ import annotations import functools -from typing import Final - -import factory +from typing import Any, Final from gt4py._core import definitions as core_defs, filecache from gt4py.next import config -from gt4py.next.otf import recipes, stages, workflow +from gt4py.next.otf import definitions, recipes, stages, workflow from gt4py.next.otf.compilation import cache from gt4py.next.program_processors.runners.dace.workflow import ( bindings as bindings_step, decoration as decoration_step, ) -from gt4py.next.program_processors.runners.dace.workflow.compilation import ( - DaCeCompilationStepFactory, -) -from gt4py.next.program_processors.runners.dace.workflow.translation import ( - DaCeTranslationStepFactory, -) +from gt4py.next.program_processors.runners.dace.workflow.compilation import DaCeCompiler +from gt4py.next.program_processors.runners.dace.workflow.translation import DaCeTranslator _GT_DACE_BINDING_FUNCTION_NAME: Final[str] = "update_sdfg_args" -class DaCeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow +def make_dace_workflow( + *, + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU, + auto_optimize: bool = False, + cached_translation: bool = False, + async_sdfg_call: bool = False, + auto_optimize_args: dict[str, Any] | None = None, + unstructured_horizontal_has_unit_stride: bool = False, + use_metrics: bool = True, + disable_field_origin_on_program_arguments: bool = False, + use_max_domain_range_on_unstructured_shift: bool | None = None, + cmake_build_type: config.CMakeBuildType | None = None, +) -> recipes.OTFCompileWorkflow: + """Build the DaCe translation -> bindings -> compilation -> decoration workflow.""" + cmake_build_type = config.CMAKE_BUILD_TYPE if cmake_build_type is None else cmake_build_type - class Params: - auto_optimize: bool = False - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough - lambda: config.CMAKE_BUILD_TYPE - ) - - cached_translation = factory.Trait( - translation=factory.LazyAttribute( - lambda o: workflow.CachedStep.persistent( - o.bare_translation, - input_fingerprinter=stages.compilable_program_fingerprinter, - cache=filecache.FileCache( - str( - cache.get_cache_base_path(config.BUILD_CACHE_LIFETIME) - / "translation_cache" - ) - ), - ) + bare_translation = DaCeTranslator( + device_type=device_type, + auto_optimize=auto_optimize, + auto_optimize_args=auto_optimize_args, + async_sdfg_call=async_sdfg_call, + unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, + use_metrics=use_metrics, + disable_field_origin_on_program_arguments=disable_field_origin_on_program_arguments, + use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, + ) + translation: definitions.TranslationStep + if cached_translation: + translation = workflow.CachedStep.persistent( + bare_translation, + # mypy cannot solve `CachedStep`'s `HashT` type variable here (it only + # appears in the fingerprinter's return), so the `str` fingerprint is + # not recognized as a valid `HashT`. + input_fingerprinter=stages.compilable_program_fingerprinter, # type: ignore[arg-type] + cache=filecache.FileCache( + str(cache.get_cache_base_path(config.BUILD_CACHE_LIFETIME) / "translation_cache") ), ) + else: + translation = bare_translation - bare_translation = factory.SubFactory( - DaCeTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - auto_optimize=factory.SelfAttribute("..auto_optimize"), - ) - - translation = factory.LazyAttribute(lambda o: o.bare_translation) - bindings = factory.LazyAttribute( - lambda o: functools.partial( - bindings_step.bind_sdfg, + return recipes.OTFCompileWorkflow( + translation=translation, + bindings=functools.partial( + bindings_step.bind_sdfg, bind_func_name=_GT_DACE_BINDING_FUNCTION_NAME + ), + compilation=DaCeCompiler( bind_func_name=_GT_DACE_BINDING_FUNCTION_NAME, - ) - ) - compilation = factory.SubFactory( - DaCeCompilationStepFactory, - bind_func_name=_GT_DACE_BINDING_FUNCTION_NAME, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), - device_type=factory.SelfAttribute("..device_type"), - cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - decoration_step.convert_args, - device=o.device_type, - ) + cache_lifetime=config.BUILD_CACHE_LIFETIME, + device_type=device_type, + cmake_build_type=cmake_build_type, + ), + decoration=functools.partial(decoration_step.convert_args, device=device_type), ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 5c8e0cc260..e7480d23fe 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -12,7 +12,6 @@ from typing import Any, Optional import dace -import factory from gt4py._core import definitions as core_defs from gt4py.next import common @@ -465,8 +464,3 @@ def __call__( code_spec=code_specs.SDFGCodeSpec(), ) return module - - -class DaCeTranslationStepFactory(factory.Factory): - class Meta: - model = DaCeTranslator diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index d4b86463f2..52e5dcc039 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -9,7 +9,6 @@ import functools from typing import Any -import factory import numpy as np import gt4py._core.definitions as core_defs @@ -18,7 +17,7 @@ from gt4py.next import backend, common, config, field_utils from gt4py.next.embedded import nd_array_field from gt4py.next.instrumentation import metrics -from gt4py.next.otf import recipes, stages, workflow +from gt4py.next.otf import definitions, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import cache, compiler from gt4py.next.otf.compilation.build_systems import compiledb @@ -103,102 +102,107 @@ def extract_connectivity_args( return args -class GTFNCompileWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough - lambda: config.CMAKE_BUILD_TYPE - ) - builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing not precise enough - lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) - ) +def make_gtfn_workflow( + *, + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU, + cached_translation: bool = False, + enable_itir_transforms: bool = True, + use_imperative_backend: bool = False, + cmake_build_type: config.CMakeBuildType | None = None, +) -> recipes.OTFCompileWorkflow: + """Build the GTFN translation -> bindings -> compilation -> decoration workflow.""" + cmake_build_type = config.CMAKE_BUILD_TYPE if cmake_build_type is None else cmake_build_type + builder_factory: compiler.BuildSystemProjectGenerator = compiledb.CompiledbFactory( + cmake_build_type=cmake_build_type + ) - cached_translation = factory.Trait( - translation=factory.LazyAttribute( - lambda o: workflow.CachedStep.persistent( - o.bare_translation, - input_fingerprinter=stages.compilable_program_fingerprinter, - cache=filecache.FileCache( - str(cache.get_cache_base_path(config.BUILD_CACHE_LIFETIME) / "gtfn_cache") - ), - ) + bare_translation = gtfn_module.GTFNTranslationStep( + device_type=device_type, + enable_itir_transforms=enable_itir_transforms, + use_imperative_backend=use_imperative_backend, + ) + translation: definitions.TranslationStep + if cached_translation: + translation = workflow.CachedStep.persistent( + bare_translation, + # mypy cannot solve `CachedStep`'s `HashT` type variable here (it only + # appears in the fingerprinter's return), so the `str` fingerprint is + # not recognized as a valid `HashT`. + input_fingerprinter=stages.compilable_program_fingerprinter, # type: ignore[arg-type] + cache=filecache.FileCache( + str(cache.get_cache_base_path(config.BUILD_CACHE_LIFETIME) / "gtfn_cache") ), ) + else: + translation = bare_translation + + return recipes.OTFCompileWorkflow( + translation=translation, + bindings=nanobind.bind_source, + compilation=compiler.Compiler( + cache_lifetime=config.BUILD_CACHE_LIFETIME, + builder_factory=builder_factory, + ), + decoration=functools.partial(convert_args, device=device_type), + ) - bare_translation = factory.SubFactory( - gtfn_module.GTFNTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - ) - translation = factory.LazyAttribute(lambda o: o.bare_translation) +def make_gtfn_backend( + *, + gpu: bool = False, + cached: bool = False, + cached_translation: bool = False, + enable_itir_transforms: bool = True, + use_imperative_backend: bool = False, + name_postfix: str = "", + executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram] + | None = None, +) -> backend.Backend: + """Build a GTFN backend for the given device and caching configuration.""" + allocator: next_allocators.FieldBufferAllocatorProtocol + device_type: core_defs.DeviceType + if gpu: + allocator = next_allocators.StandardGPUFieldBufferAllocator() + device_type = core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA + name_device = "gpu" + else: + allocator = next_allocators.StandardCPUFieldBufferAllocator() + device_type = core_defs.DeviceType.CPU + name_device = "cpu" - bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] = ( - nanobind.bind_source - ) - compilation = factory.SubFactory( - compiler.CompilerFactory, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), - builder_factory=factory.SelfAttribute("..builder_factory"), + otf_workflow = make_gtfn_workflow( + device_type=device_type, + cached_translation=cached_translation, + enable_itir_transforms=enable_itir_transforms, + use_imperative_backend=use_imperative_backend, ) - decoration = factory.LazyAttribute( - lambda o: functools.partial(convert_args, device=o.device_type) - ) - - -class GTFNBackendFactory(factory.Factory): - class Meta: - model = backend.Backend - class Params: - name_device = "cpu" - name_cached = "" - name_temps = "" - name_postfix = "" - gpu = factory.Trait( - allocator=next_allocators.StandardGPUFieldBufferAllocator(), - device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA, - name_device="gpu", - ) - cached = factory.Trait( - executor=factory.LazyAttribute( - lambda o: workflow.CachedStep.in_memory( - o.otf_workflow, input_fingerprinter=o.key_function - ) - ), - name_cached="_cached", - ) - device_type = core_defs.DeviceType.CPU - key_function = stages.fast_compilable_program_fingerprinter - otf_workflow = factory.SubFactory( - GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") + if executor is None: + executor = ( + workflow.CachedStep.in_memory( + otf_workflow, input_fingerprinter=stages.fast_compilable_program_fingerprinter + ) + if cached + else otf_workflow ) - name = factory.LazyAttribute( - lambda o: f"run_gtfn_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" + name_cached = "_cached" if cached else "" + return backend.Backend( + name=f"run_gtfn_{name_device}{name_cached}{name_postfix}", + executor=executor, + allocator=allocator, + transforms=backend.DEFAULT_TRANSFORMS, ) - executor = factory.LazyAttribute(lambda o: o.otf_workflow) - allocator = next_allocators.StandardCPUFieldBufferAllocator() - transforms = backend.DEFAULT_TRANSFORMS - -run_gtfn = GTFNBackendFactory() +run_gtfn = make_gtfn_backend() -run_gtfn_imperative = GTFNBackendFactory( - name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True -) +run_gtfn_imperative = make_gtfn_backend(name_postfix="_imperative", use_imperative_backend=True) -run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) +run_gtfn_cached = make_gtfn_backend(cached=True, cached_translation=True) -run_gtfn_gpu = GTFNBackendFactory(gpu=True) +run_gtfn_gpu = make_gtfn_backend(gpu=True) -run_gtfn_gpu_cached = GTFNBackendFactory( - gpu=True, cached=True, otf_workflow__cached_translation=True -) +run_gtfn_gpu_cached = make_gtfn_backend(gpu=True, cached=True, cached_translation=True) -run_gtfn_no_transforms = GTFNBackendFactory( - otf_workflow__bare_translation__enable_itir_transforms=False -) +run_gtfn_no_transforms = make_gtfn_backend(enable_itir_transforms=False) diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index def8800c98..b25e5e6897 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -121,7 +121,7 @@ def pirate(program: toolchain.ConcreteArtifact): hijacked_program = program return lambda *args, **kwargs: None - hacked_gtfn_backend = gtfn.GTFNBackendFactory(name_postfix="_custom", executor=pirate) + hacked_gtfn_backend = gtfn.make_gtfn_backend(name_postfix="_custom", executor=pirate) testee = testee_prog.with_backend(hacked_gtfn_backend).compile(cond=[True], offset_provider={}) testee( diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 2e795d4fb9..18efe7c399 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -134,12 +134,12 @@ def test_gtfn_file_cache(program_example): data=fencil, args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}), ) - cached_gtfn_translation_step = gtfn.GTFNBackendFactory( - gpu=False, cached=True, otf_workflow__cached_translation=True + cached_gtfn_translation_step = gtfn.make_gtfn_backend( + gpu=False, cached=True, cached_translation=True ).executor.step.translation - bare_gtfn_translation_step = gtfn.GTFNBackendFactory( - gpu=False, cached=True, otf_workflow__cached_translation=False + bare_gtfn_translation_step = gtfn.make_gtfn_backend( + gpu=False, cached=True, cached_translation=False ).executor.step.translation cache_key = cached_gtfn_translation_step.cache_key(compilable_program) @@ -162,9 +162,7 @@ def test_gtfn_file_cache(program_example): # TODO(egparedes): we should switch to use the cached backend by default and then remove this test def test_gtfn_file_cache_whole_workflow(cartesian_case_no_backend): cartesian_case = cartesian_case_no_backend - cartesian_case.backend = gtfn.GTFNBackendFactory( - gpu=False, cached=True, otf_workflow__cached_translation=True - ) + cartesian_case.backend = gtfn.make_gtfn_backend(gpu=False, cached=True, cached_translation=True) cartesian_case.allocator = next_allocators.StandardCPUFieldBufferAllocator() assert cartesian_case.backend is not None diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index 96d8c6e27c..55204ef8eb 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -28,8 +28,8 @@ def test_backend_factory_trait_device(): - cpu_version = gtfn.GTFNBackendFactory(gpu=False, cached=False) - gpu_version = gtfn.GTFNBackendFactory(gpu=True, cached=False) + cpu_version = gtfn.make_gtfn_backend(gpu=False, cached=False) + gpu_version = gtfn.make_gtfn_backend(gpu=True, cached=False) assert cpu_version.name == "run_gtfn_cpu" assert gpu_version.name == "run_gtfn_gpu" @@ -49,16 +49,16 @@ def test_backend_factory_trait_device(): def test_backend_factory_trait_cached(): - cached_version = gtfn.GTFNBackendFactory(gpu=False, cached=True) + cached_version = gtfn.make_gtfn_backend(gpu=False, cached=True) assert isinstance(cached_version.executor, workflow.CachedStep) assert cached_version.name == "run_gtfn_cpu_cached" def test_backend_factory_build_cache_config(monkeypatch): monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.SESSION) - session_version = gtfn.GTFNBackendFactory() + session_version = gtfn.make_gtfn_backend() monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.PERSISTENT) - persistent_version = gtfn.GTFNBackendFactory() + persistent_version = gtfn.make_gtfn_backend() assert session_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.SESSION assert ( @@ -69,9 +69,9 @@ def test_backend_factory_build_cache_config(monkeypatch): def test_backend_factory_build_type_config(monkeypatch): monkeypatch.setattr(config, "CMAKE_BUILD_TYPE", config.CMakeBuildType.RELEASE) - release_version = gtfn.GTFNBackendFactory() + release_version = gtfn.make_gtfn_backend() monkeypatch.setattr(config, "CMAKE_BUILD_TYPE", config.CMakeBuildType.MIN_SIZE_REL) - min_size_version = gtfn.GTFNBackendFactory() + min_size_version = gtfn.make_gtfn_backend() assert ( release_version.executor.compilation.builder_factory.cmake_build_type diff --git a/uv.lock b/uv.lock index 7027f7d4f5..b4ef002d33 100644 --- a/uv.lock +++ b/uv.lock @@ -1652,7 +1652,6 @@ dependencies = [ { name = "dace" }, { name = "deepdiff" }, { name = "devtools" }, - { name = "factory-boy" }, { name = "filelock" }, { name = "frozendict" }, { name = "gridtools-cpp" }, @@ -1739,6 +1738,7 @@ dev = [ { name = "coverage", extra = ["toml"] }, { name = "cython" }, { name = "esbonio" }, + { name = "factory-boy" }, { name = "hypothesis" }, { name = "jupytext" }, { name = "matplotlib" }, @@ -1811,6 +1811,7 @@ scripts = [ ] test = [ { name = "coverage", extra = ["toml"] }, + { name = "factory-boy" }, { name = "hypothesis" }, { name = "nbmake" }, { name = "nox" }, @@ -1861,7 +1862,6 @@ requires-dist = [ { name = "dace", specifier = ">=2.0.0a3" }, { name = "deepdiff", specifier = ">=8.1.0" }, { name = "devtools", specifier = ">=0.6" }, - { name = "factory-boy", specifier = ">=3.3.3" }, { name = "filelock", specifier = ">=3.18.0" }, { name = "frozendict", specifier = ">=2.3" }, { name = "gridtools-cpp", specifier = "==2.*,>=2.3.9" }, @@ -1907,6 +1907,7 @@ dev = [ { name = "coverage", extras = ["toml"], specifier = ">=7.6.1" }, { name = "cython", specifier = ">=3.0.0" }, { name = "esbonio", specifier = ">=0.16.0" }, + { name = "factory-boy", specifier = ">=3.3.3" }, { name = "hypothesis", specifier = ">=6.0.0" }, { name = "jupytext", specifier = ">=1.14" }, { name = "matplotlib", specifier = ">=3.9.0" }, @@ -1971,6 +1972,7 @@ scripts = [ ] test = [ { name = "coverage", extras = ["toml"], specifier = ">=7.6.1" }, + { name = "factory-boy", specifier = ">=3.3.3" }, { name = "hypothesis", specifier = ">=6.0.0" }, { name = "nbmake", specifier = ">=1.4.6" }, { name = "nox", specifier = ">=2025.2.9" }, From 6b4321cbb07d4cc9aec7bcad64e384b79cfa0563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Tue, 23 Jun 2026 07:26:17 +0200 Subject: [PATCH 2/3] refactor[next]: inject sub-components instead of forwarding nested config Follow-up to the factory-boy removal. Flattening the old `__`-path overrides into explicit maker kwargs re-coupled each maker to its sub-components' internals (most visibly `make_dace_backend`, forwarding ~6 translator knobs). Replace that with lightweight dependency injection: a maker exposes only cross-cutting configuration (device, caching, auto-optimize, build type) and accepts pre-built sub-components. A single-component knob is set by injecting that component, e.g. make_gtfn_backend(translation=GTFNTranslationStep(enable_itir_transforms=False)) Details: - `workflow.with_fields(step, **fields)`: guarded helper that stamps cross-cutting fields (device_type; dace auto_optimize) onto a step only when it is a dataclass declaring them, so injecting an arbitrary callable is safe. - GTFN makers drop `enable_itir_transforms`/`use_imperative_backend`, gain `translation`/`bindings`/`compilation`/`decoration` injection params. - `DaCeTranslator` gets defaults for its formerly-required fields; a new `make_dace_translator(...)` owns the translator-local knobs and the relocated `optimization_args` validation / `unit_strides_kind` derivation. - `make_dace_workflow`/`make_dace_backend` shed their translator-forward kwargs and accept injection (breaking change to `make_dace_backend`'s signature for external callers passing translator knobs). All pre-built backends keep identical names and behavior. Migrated the in-repo tests that passed removed kwargs to `make_dace_translator(...)` injection and added an injection unit test. --- docs/user/next/advanced/HackTheToolchain.md | 24 +++- src/gt4py/next/otf/workflow.py | 19 +++ .../runners/dace/workflow/backend.py | 127 +++++------------- .../runners/dace/workflow/factory.py | 55 ++++---- .../runners/dace/workflow/translation.py | 86 ++++++++++-- .../next/program_processors/runners/gtfn.py | 74 ++++++---- .../dace_tests/test_dace_backend.py | 12 +- .../dace_tests/test_dace_bindings.py | 10 +- .../runners_tests/test_gtfn.py | 12 ++ 9 files changed, 254 insertions(+), 165 deletions(-) diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md index 2e5ee4821e..27ee385763 100644 --- a/docs/user/next/advanced/HackTheToolchain.md +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -48,9 +48,10 @@ skip_linting_transforms.step_order(DUMMY_FOP) ## Alternative Workflow Steps -The compiled GTFN workflow is a frozen dataclass built by `make_gtfn_workflow`. -To swap individual steps, build the default workflow and use `dataclasses.replace` -to override the steps you care about. +The builders (`make_gtfn_workflow` / `make_gtfn_backend`) expose only cross-cutting +configuration (device, caching, build type). To customize a single sub-component, +**inject** a pre-built step — the builder uses it instead of its default and stamps +the cross-cutting fields (e.g. `device_type`) onto it. ```python class MyCodeGen: ... @@ -59,14 +60,23 @@ class MyCodeGen: ... class Cpp2BindingsGen: ... -base_workflow = gtx.program_processors.runners.gtfn.make_gtfn_workflow( - cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG +# Inject whole custom steps: +pure_cpp2_workflow = gtx.program_processors.runners.gtfn.make_gtfn_workflow( + cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG, + translation=MyCodeGen(), + bindings=Cpp2BindingsGen(), ) -pure_cpp2_workflow = dataclasses.replace( - base_workflow, translation=MyCodeGen(), bindings=Cpp2BindingsGen() + +# Or inject a configured default step to flip one knob: +GTFNTranslationStep = gtx.program_processors.runners.gtfn.gtfn_module.GTFNTranslationStep +no_transforms = gtx.program_processors.runners.gtfn.make_gtfn_backend( + translation=GTFNTranslationStep(enable_itir_transforms=False), ) ``` +A pre-built backend is itself a frozen dataclass, so you can also override after the +fact with `dataclasses.replace` (e.g. `dataclasses.replace(backend, executor=...)`). + ## Invent new Workflow Types ```mermaid diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index e2426450a0..4523c50ea9 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -31,6 +31,25 @@ HashT = TypeVar("HashT") DataT = TypeVar("DataT") ArgT = TypeVar("ArgT") +StepT = TypeVar("StepT") + + +def with_fields(step: StepT, **fields: Any) -> StepT: + """ + Return a copy of ``step`` with ``fields`` applied, where it declares them. + + Used by the backend/workflow builders to stamp cross-cutting configuration + (e.g. ``device_type``) onto a sub-component, whether it was built by default + or injected by the caller. Only fields that ``step`` declares as a dataclass + are applied; any other ``step`` (e.g. a plain callable) is returned unchanged, + so injecting an arbitrary workflow step stays safe. + """ + if dataclasses.is_dataclass(step) and not isinstance(step, type): + names = {f.name for f in dataclasses.fields(step)} + valid = {key: value for key, value in fields.items() if key in names} + if valid: + return dataclasses.replace(step, **valid) + return step def make_step(function: Workflow[StartT, EndT]) -> ChainableWorkflowMixin[StartT, EndT]: diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index 74c8913b1b..c47faf3ba2 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -8,73 +8,40 @@ from __future__ import annotations -import warnings -from typing import Any, Final - import gt4py.next.custom_layout_allocators as next_allocators from gt4py._core import definitions as core_defs -from gt4py.next import backend, common, config -from gt4py.next.otf import stages, workflow +from gt4py.next import backend +from gt4py.next.otf import definitions, stages, workflow from gt4py.next.program_processors.runners.dace.workflow.factory import make_dace_workflow +from gt4py.next.program_processors.runners.dace.workflow.translation import make_dace_translator def make_dace_backend( gpu: bool, + *, cached: bool = True, auto_optimize: bool = True, - async_sdfg_call: bool = True, - optimization_args: dict[str, Any] | None = None, - unstructured_horizontal_has_unit_stride: bool = config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE, - use_metrics: bool = True, - use_zero_origin: bool = False, - use_max_domain_range_on_unstructured_shift: bool | None = None, + name_postfix: str = "", + translation: definitions.TranslationStep | None = None, + executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram] + | None = None, ) -> backend.Backend: - """Customize the dace backend with the given configuration parameters. + """Build the GTIR-DaCe backend for the given device and configuration. + + Cross-cutting configuration is passed as keyword arguments. To customize the + translator-local options (async SDFG call, metrics, zero-origin, ...), inject a + pre-built translator via ``translation=`` (see `make_dace_translator`); its + ``device_type`` and ``auto_optimize`` are set to match ``gpu``/``auto_optimize``. + Pass ``executor=`` to replace the whole executor workflow. Args: gpu: Enable GPU transformations and code generation. cached: Cache the lowered SDFG as a JSON file and the compiled programs. auto_optimize: Enable the SDFG auto-optimize pipeline. - async_sdfg_call: Make an asynchronous SDFG call on GPU to allow overlapping - of GPU kernel execution with the Python driver code. - optimization_args: A `dict` containing configuration parameters for - the SDFG auto-optimize pipeline, see `gt_auto_optimize()`. - unstructured_horizontal_has_unit_stride: When the memory layout has unit stride - in the horizontal dimension, replace the field stride symbol with '1'. - use_metrics: Add SDFG instrumentation to collect the metric for stencil - compute time. - use_zero_origin: Can be set to `True` when all fields passed as program - arguments have zero-based origin. This setting will skip generation - of range start-symbols `_range_0` since they can be assumed to be zero. - - Note that `gt_auto_optimize()` parameters that are derived from GT4Py configuration - cannot be overriden, and therefore cannot appear here. Thus, this function will - throw an exception if called with any argument included in `gt_optimization_args`. Returns: - A dace backend with custom configuration for the target device. + A dace backend for the target device. """ - - # The `gt_optimization_args` set contains the parameters of `gt_auto_optimize()` - # that are derived from the gt4py configuration, and therefore cannot be customized. - gt_optimization_args: Final[set[str]] = {"gpu", "constant_symbols", "unit_strides_kind"} - - if optimization_args is None: - optimization_args = {} - elif optimization_args and not auto_optimize: - warnings.warn("Optimizations args given, but auto-optimize is disabled.", stacklevel=2) - elif intersect_args := gt_optimization_args.intersection(optimization_args.keys()): - raise ValueError( - f"The following optimization arguments cannot be overriden: {intersect_args}." - ) - - # Set `unit_strides_kind` based on the gt4py env configuration. - optimization_args = optimization_args | { - "unit_strides_kind": common.DimensionKind.HORIZONTAL - if unstructured_horizontal_has_unit_stride - else None - } - allocator: next_allocators.FieldBufferAllocatorProtocol device_type: core_defs.DeviceType if gpu: @@ -86,70 +53,50 @@ def make_dace_backend( device_type = core_defs.DeviceType.CPU name_device = "cpu" - otf_workflow = make_dace_workflow( - device_type=device_type, - auto_optimize=auto_optimize, - cached_translation=cached, - async_sdfg_call=(async_sdfg_call if gpu else False), - auto_optimize_args=optimization_args, - unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, - use_metrics=use_metrics, - disable_field_origin_on_program_arguments=use_zero_origin, - use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, - ) - - executor = ( - workflow.CachedStep.in_memory( - otf_workflow, input_fingerprinter=stages.fast_compilable_program_fingerprinter + if executor is None: + otf_workflow = make_dace_workflow( + device_type=device_type, + auto_optimize=auto_optimize, + cached_translation=cached, + translation=translation, + ) + executor = ( + workflow.CachedStep.in_memory( + otf_workflow, input_fingerprinter=stages.fast_compilable_program_fingerprinter + ) + if cached + else otf_workflow ) - if cached - else otf_workflow - ) name_cached = "_cached" if cached else "" - name_postfix = "_opt" if auto_optimize else "" + name_opt = "_opt" if auto_optimize else "" return backend.Backend( - name=f"run_dace_{name_device}{name_cached}{name_postfix}", + name=f"run_dace_{name_device}{name_cached}{name_opt}{name_postfix}", executor=executor, allocator=allocator, transforms=backend.DEFAULT_TRANSFORMS, ) -run_dace_cpu = make_dace_backend( - gpu=False, - cached=False, - auto_optimize=True, - async_sdfg_call=False, -) -run_dace_cpu_noopt = make_dace_backend( - gpu=False, - cached=False, - auto_optimize=False, - async_sdfg_call=False, -) -run_dace_cpu_cached = make_dace_backend( - gpu=False, - cached=True, - auto_optimize=True, - async_sdfg_call=False, -) +run_dace_cpu = make_dace_backend(gpu=False, cached=False, auto_optimize=True) +run_dace_cpu_noopt = make_dace_backend(gpu=False, cached=False, auto_optimize=False) +run_dace_cpu_cached = make_dace_backend(gpu=False, cached=True, auto_optimize=True) run_dace_gpu = make_dace_backend( gpu=True, cached=False, auto_optimize=True, - async_sdfg_call=True, + translation=make_dace_translator(async_sdfg_call=True), ) run_dace_gpu_noopt = make_dace_backend( gpu=True, cached=False, auto_optimize=False, - async_sdfg_call=True, + translation=make_dace_translator(async_sdfg_call=True), ) run_dace_gpu_cached = make_dace_backend( gpu=True, cached=True, auto_optimize=True, - async_sdfg_call=True, + translation=make_dace_translator(async_sdfg_call=True), ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 205fd840f3..6c55e9fed1 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -9,7 +9,7 @@ from __future__ import annotations import functools -from typing import Any, Final +from typing import Final from gt4py._core import definitions as core_defs, filecache from gt4py.next import config @@ -20,7 +20,7 @@ decoration as decoration_step, ) from gt4py.next.program_processors.runners.dace.workflow.compilation import DaCeCompiler -from gt4py.next.program_processors.runners.dace.workflow.translation import DaCeTranslator +from gt4py.next.program_processors.runners.dace.workflow.translation import make_dace_translator _GT_DACE_BINDING_FUNCTION_NAME: Final[str] = "update_sdfg_args" @@ -31,30 +31,33 @@ def make_dace_workflow( device_type: core_defs.DeviceType = core_defs.DeviceType.CPU, auto_optimize: bool = False, cached_translation: bool = False, - async_sdfg_call: bool = False, - auto_optimize_args: dict[str, Any] | None = None, - unstructured_horizontal_has_unit_stride: bool = False, - use_metrics: bool = True, - disable_field_origin_on_program_arguments: bool = False, - use_max_domain_range_on_unstructured_shift: bool | None = None, cmake_build_type: config.CMakeBuildType | None = None, + translation: definitions.TranslationStep | None = None, + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] | None = None, + compilation: workflow.Workflow[stages.CompilableProject, stages.ExecutableProgram] + | None = None, + decoration: workflow.Workflow[stages.ExecutableProgram, stages.ExecutableProgram] | None = None, ) -> recipes.OTFCompileWorkflow: - """Build the DaCe translation -> bindings -> compilation -> decoration workflow.""" + """Build the DaCe translation -> bindings -> compilation -> decoration workflow. + + Cross-cutting configuration (device, auto-optimize, translation caching, build + type) is passed as keyword arguments. To customize translator-local options, + inject a pre-built translator, e.g. + ``translation=make_dace_translator(async_sdfg_call=True)``; its ``device_type`` + and ``auto_optimize`` are set to match the cross-cutting arguments. + """ cmake_build_type = config.CMAKE_BUILD_TYPE if cmake_build_type is None else cmake_build_type - bare_translation = DaCeTranslator( + bare_translation = workflow.with_fields( + translation + if translation is not None + else make_dace_translator(auto_optimize=auto_optimize), device_type=device_type, auto_optimize=auto_optimize, - auto_optimize_args=auto_optimize_args, - async_sdfg_call=async_sdfg_call, - unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, - use_metrics=use_metrics, - disable_field_origin_on_program_arguments=disable_field_origin_on_program_arguments, - use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, ) - translation: definitions.TranslationStep + translation_step: definitions.TranslationStep if cached_translation: - translation = workflow.CachedStep.persistent( + translation_step = workflow.CachedStep.persistent( bare_translation, # mypy cannot solve `CachedStep`'s `HashT` type variable here (it only # appears in the fingerprinter's return), so the `str` fingerprint is @@ -65,18 +68,24 @@ def make_dace_workflow( ), ) else: - translation = bare_translation + translation_step = bare_translation return recipes.OTFCompileWorkflow( - translation=translation, - bindings=functools.partial( + translation=translation_step, + bindings=bindings + if bindings is not None + else functools.partial( bindings_step.bind_sdfg, bind_func_name=_GT_DACE_BINDING_FUNCTION_NAME ), - compilation=DaCeCompiler( + compilation=compilation + if compilation is not None + else DaCeCompiler( bind_func_name=_GT_DACE_BINDING_FUNCTION_NAME, cache_lifetime=config.BUILD_CACHE_LIFETIME, device_type=device_type, cmake_build_type=cmake_build_type, ), - decoration=functools.partial(decoration_step.convert_args, device=device_type), + decoration=decoration + if decoration is not None + else functools.partial(decoration_step.convert_args, device=device_type), ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index e7480d23fe..9baa53f154 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -9,12 +9,13 @@ from __future__ import annotations import dataclasses -from typing import Any, Optional +import warnings +from typing import Any, Final, Optional import dace from gt4py._core import definitions as core_defs -from gt4py.next import common +from gt4py.next import common, config from gt4py.next.instrumentation import metrics from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import code_specs, definitions, stages, workflow @@ -345,12 +346,14 @@ class DaCeTranslator( ], definitions.TranslationStep[code_specs.SDFGCodeSpec], ): - device_type: core_defs.DeviceType - auto_optimize: bool - auto_optimize_args: dict[str, Any] | None - async_sdfg_call: bool - unstructured_horizontal_has_unit_stride: bool - use_metrics: bool + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + auto_optimize: bool = False + auto_optimize_args: dict[str, Any] | None = None + async_sdfg_call: bool = False + unstructured_horizontal_has_unit_stride: bool = dataclasses.field( + default_factory=lambda: config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE + ) + use_metrics: bool = True disable_itir_transforms: bool = False disable_field_origin_on_program_arguments: bool = False @@ -464,3 +467,70 @@ def __call__( code_spec=code_specs.SDFGCodeSpec(), ) return module + + +def make_dace_translator( + *, + auto_optimize: bool = False, + async_sdfg_call: bool = False, + optimization_args: dict[str, Any] | None = None, + unstructured_horizontal_has_unit_stride: bool = config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE, + use_metrics: bool = True, + use_zero_origin: bool = False, + use_max_domain_range_on_unstructured_shift: bool | None = None, +) -> DaCeTranslator: + """Build a `DaCeTranslator` with validated auto-optimize configuration. + + Translator-local configuration is set here; the cross-cutting `device_type` + and `auto_optimize` are stamped on by the workflow/backend builders, so a + caller injecting a translator usually leaves them at their defaults. + + Args: + auto_optimize: Enable the SDFG auto-optimize pipeline. + async_sdfg_call: Make an asynchronous SDFG call on GPU to allow overlapping + of GPU kernel execution with the Python driver code. + optimization_args: A `dict` containing configuration parameters for + the SDFG auto-optimize pipeline, see `gt_auto_optimize()`. + unstructured_horizontal_has_unit_stride: When the memory layout has unit + stride in the horizontal dimension, replace the field stride symbol + with '1'. + use_metrics: Add SDFG instrumentation to collect the metric for stencil + compute time. + use_zero_origin: Can be set to `True` when all fields passed as program + arguments have zero-based origin. This setting will skip generation + of range start-symbols `_range_0` since they can be assumed to be zero. + + Note that `gt_auto_optimize()` parameters that are derived from GT4Py + configuration cannot be overriden, and therefore cannot appear in + `optimization_args`. Thus, this function will throw an exception if called with + any such argument. + """ + # The `gt_optimization_args` set contains the parameters of `gt_auto_optimize()` + # that are derived from the gt4py configuration, and therefore cannot be customized. + gt_optimization_args: Final[set[str]] = {"gpu", "constant_symbols", "unit_strides_kind"} + + if optimization_args is None: + optimization_args = {} + elif optimization_args and not auto_optimize: + warnings.warn("Optimizations args given, but auto-optimize is disabled.", stacklevel=2) + elif intersect_args := gt_optimization_args.intersection(optimization_args.keys()): + raise ValueError( + f"The following optimization arguments cannot be overriden: {intersect_args}." + ) + + # Set `unit_strides_kind` based on the gt4py env configuration. + optimization_args = optimization_args | { + "unit_strides_kind": common.DimensionKind.HORIZONTAL + if unstructured_horizontal_has_unit_stride + else None + } + + return DaCeTranslator( + auto_optimize=auto_optimize, + auto_optimize_args=optimization_args, + async_sdfg_call=async_sdfg_call, + unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, + use_metrics=use_metrics, + disable_field_origin_on_program_arguments=use_zero_origin, + use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift, + ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 52e5dcc039..e1428c3282 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -106,24 +106,29 @@ def make_gtfn_workflow( *, device_type: core_defs.DeviceType = core_defs.DeviceType.CPU, cached_translation: bool = False, - enable_itir_transforms: bool = True, - use_imperative_backend: bool = False, cmake_build_type: config.CMakeBuildType | None = None, + translation: definitions.TranslationStep | None = None, + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] | None = None, + compilation: workflow.Workflow[stages.CompilableProject, stages.ExecutableProgram] + | None = None, + decoration: workflow.Workflow[stages.ExecutableProgram, stages.ExecutableProgram] | None = None, ) -> recipes.OTFCompileWorkflow: - """Build the GTFN translation -> bindings -> compilation -> decoration workflow.""" + """Build the GTFN translation -> bindings -> compilation -> decoration workflow. + + Cross-cutting configuration (device, translation caching, build type) is passed + as keyword arguments. To customize a single step, inject a pre-built one, e.g. + ``translation=GTFNTranslationStep(enable_itir_transforms=False)``; its + ``device_type`` is set to match ``device_type``. + """ cmake_build_type = config.CMAKE_BUILD_TYPE if cmake_build_type is None else cmake_build_type - builder_factory: compiler.BuildSystemProjectGenerator = compiledb.CompiledbFactory( - cmake_build_type=cmake_build_type - ) - bare_translation = gtfn_module.GTFNTranslationStep( + bare_translation = workflow.with_fields( + translation if translation is not None else gtfn_module.GTFNTranslationStep(), device_type=device_type, - enable_itir_transforms=enable_itir_transforms, - use_imperative_backend=use_imperative_backend, ) - translation: definitions.TranslationStep + translation_step: definitions.TranslationStep if cached_translation: - translation = workflow.CachedStep.persistent( + translation_step = workflow.CachedStep.persistent( bare_translation, # mypy cannot solve `CachedStep`'s `HashT` type variable here (it only # appears in the fingerprinter's return), so the `str` fingerprint is @@ -134,16 +139,20 @@ def make_gtfn_workflow( ), ) else: - translation = bare_translation + translation_step = bare_translation return recipes.OTFCompileWorkflow( - translation=translation, - bindings=nanobind.bind_source, - compilation=compiler.Compiler( + translation=translation_step, + bindings=bindings if bindings is not None else nanobind.bind_source, + compilation=compilation + if compilation is not None + else compiler.Compiler( cache_lifetime=config.BUILD_CACHE_LIFETIME, - builder_factory=builder_factory, + builder_factory=compiledb.CompiledbFactory(cmake_build_type=cmake_build_type), ), - decoration=functools.partial(convert_args, device=device_type), + decoration=decoration + if decoration is not None + else functools.partial(convert_args, device=device_type), ) @@ -152,13 +161,17 @@ def make_gtfn_backend( gpu: bool = False, cached: bool = False, cached_translation: bool = False, - enable_itir_transforms: bool = True, - use_imperative_backend: bool = False, name_postfix: str = "", + translation: definitions.TranslationStep | None = None, executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram] | None = None, ) -> backend.Backend: - """Build a GTFN backend for the given device and caching configuration.""" + """Build a GTFN backend for the given device and caching configuration. + + Cross-cutting configuration is passed as keyword arguments. To customize the + translation step, inject a pre-built one via ``translation=`` (its device is set + to match ``gpu``). Pass ``executor=`` to replace the whole executor workflow. + """ allocator: next_allocators.FieldBufferAllocatorProtocol device_type: core_defs.DeviceType if gpu: @@ -170,14 +183,12 @@ def make_gtfn_backend( device_type = core_defs.DeviceType.CPU name_device = "cpu" - otf_workflow = make_gtfn_workflow( - device_type=device_type, - cached_translation=cached_translation, - enable_itir_transforms=enable_itir_transforms, - use_imperative_backend=use_imperative_backend, - ) - if executor is None: + otf_workflow = make_gtfn_workflow( + device_type=device_type, + cached_translation=cached_translation, + translation=translation, + ) executor = ( workflow.CachedStep.in_memory( otf_workflow, input_fingerprinter=stages.fast_compilable_program_fingerprinter @@ -197,7 +208,10 @@ def make_gtfn_backend( run_gtfn = make_gtfn_backend() -run_gtfn_imperative = make_gtfn_backend(name_postfix="_imperative", use_imperative_backend=True) +run_gtfn_imperative = make_gtfn_backend( + name_postfix="_imperative", + translation=gtfn_module.GTFNTranslationStep(use_imperative_backend=True), +) run_gtfn_cached = make_gtfn_backend(cached=True, cached_translation=True) @@ -205,4 +219,6 @@ def make_gtfn_backend( run_gtfn_gpu_cached = make_gtfn_backend(gpu=True, cached=True, cached_translation=True) -run_gtfn_no_transforms = make_gtfn_backend(enable_itir_transforms=False) +run_gtfn_no_transforms = make_gtfn_backend( + translation=gtfn_module.GTFNTranslationStep(enable_itir_transforms=False), +) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py index 64b8be3a8b..249f0f1313 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py @@ -17,6 +17,7 @@ from gt4py._core import definitions as core_defs from gt4py.next.program_processors.runners.dace.workflow import ( backend as dace_wf_backend, + translation as dace_wf_translation, ) from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations @@ -95,10 +96,13 @@ def mocked_gpu_transformation(*args, **kwargs) -> dace.SDFG: gpu=on_gpu, cached=False, auto_optimize=auto_optimize, - async_sdfg_call=True, - optimization_args=optimization_args, - unstructured_horizontal_has_unit_stride=on_gpu, - use_metrics=True, + translation=dace_wf_translation.make_dace_translator( + auto_optimize=auto_optimize, + async_sdfg_call=True, + optimization_args=optimization_args, + unstructured_horizontal_has_unit_stride=on_gpu, + use_metrics=True, + ), ) testee.with_backend(custom_backend).compile(offset_provider={}) gtx.wait_for_compilation() diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py index 87d942a547..48c9d2f66b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py @@ -299,8 +299,9 @@ def testee( gpu=False, cached=False, auto_optimize=True, - use_metrics=use_metrics, - use_zero_origin=use_zero_origin, + translation=dace_workflow.translation.make_dace_translator( + auto_optimize=True, use_metrics=use_metrics, use_zero_origin=use_zero_origin + ), ) monkeypatch.setattr( dace_workflow.compilation.DaCeCompiler, @@ -355,8 +356,9 @@ def testee(a: cases.VField, b: cases.VField): gpu=False, cached=False, auto_optimize=True, - use_metrics=use_metrics, - use_zero_origin=use_zero_origin, + translation=dace_workflow.translation.make_dace_translator( + auto_optimize=True, use_metrics=use_metrics, use_zero_origin=use_zero_origin + ), ) monkeypatch.setattr( dace_workflow.compilation.DaCeCompiler, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index 55204ef8eb..dc4395027f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -24,9 +24,21 @@ from gt4py.next.iterator import transforms from gt4py.next.iterator.transforms import global_tmps from gt4py.next.otf import workflow +from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.program_processors.runners import gtfn +def test_backend_inject_translation(): + # A single-sub-component knob is set by injecting a pre-built step, not via a + # maker keyword argument; the cross-cutting device is stamped onto it. + backend = gtfn.make_gtfn_backend( + gpu=True, + translation=gtfn_module.GTFNTranslationStep(use_imperative_backend=True), + ) + assert backend.executor.translation.use_imperative_backend is True + assert backend.executor.translation.device_type is core_defs.DeviceType.CUDA + + def test_backend_factory_trait_device(): cpu_version = gtfn.make_gtfn_backend(gpu=False, cached=False) gpu_version = gtfn.make_gtfn_backend(gpu=True, cached=False) From 991693b00893a6854b0b8359c65ff41632a6e13a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Tue, 23 Jun 2026 07:44:51 +0200 Subject: [PATCH 3/3] refactor[next]: rename workflow.with_fields to with_changes --- src/gt4py/next/otf/workflow.py | 6 +++--- .../program_processors/runners/dace/workflow/factory.py | 2 +- src/gt4py/next/program_processors/runners/gtfn.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 4523c50ea9..e1e05039a4 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -34,9 +34,9 @@ StepT = TypeVar("StepT") -def with_fields(step: StepT, **fields: Any) -> StepT: +def with_changes(step: StepT, **changes: Any) -> StepT: """ - Return a copy of ``step`` with ``fields`` applied, where it declares them. + Return a copy of ``step`` with ``changes`` applied, where it declares them. Used by the backend/workflow builders to stamp cross-cutting configuration (e.g. ``device_type``) onto a sub-component, whether it was built by default @@ -46,7 +46,7 @@ def with_fields(step: StepT, **fields: Any) -> StepT: """ if dataclasses.is_dataclass(step) and not isinstance(step, type): names = {f.name for f in dataclasses.fields(step)} - valid = {key: value for key, value in fields.items() if key in names} + valid = {key: value for key, value in changes.items() if key in names} if valid: return dataclasses.replace(step, **valid) return step diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 6c55e9fed1..955d88b5bc 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -48,7 +48,7 @@ def make_dace_workflow( """ cmake_build_type = config.CMAKE_BUILD_TYPE if cmake_build_type is None else cmake_build_type - bare_translation = workflow.with_fields( + bare_translation = workflow.with_changes( translation if translation is not None else make_dace_translator(auto_optimize=auto_optimize), diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index e1428c3282..8f1d0bde85 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -122,7 +122,7 @@ def make_gtfn_workflow( """ cmake_build_type = config.CMAKE_BUILD_TYPE if cmake_build_type is None else cmake_build_type - bare_translation = workflow.with_fields( + bare_translation = workflow.with_changes( translation if translation is not None else gtfn_module.GTFNTranslationStep(), device_type=device_type, )