Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions docs/user/next/advanced/HackTheToolchain.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ skip_linting_transforms = SkipLinting(**same_steps)
skip_linting_transforms.step_order(DUMMY_FOP)
```

## Alternative Factory
## Alternative Workflow Steps

The builders (`make_gtfn_workflow` / `make_gtfn_backend`) expose only cross-cutting
configuration (device, caching, build type). To customize a single sub-component,
**inject** a pre-built step — the builder uses it instead of its default and stamps
the cross-cutting fields (e.g. `device_type`) onto it.

```python
class MyCodeGen: ...
Expand All @@ -55,18 +60,23 @@ class MyCodeGen: ...
class Cpp2BindingsGen: ...


class PureCpp2WorkflowFactory(gtx.program_processors.runners.gtfn.GTFNCompileWorkflowFactory):
translation: workflow.Workflow[
gtx.otf.definitions.CompilableProgramDef, gtx.otf.stages.ProgramSource
] = MyCodeGen()
bindings: workflow.Workflow[gtx.otf.stages.ProgramSource, gtx.otf.stages.CompilableProject] = (
Cpp2BindingsGen()
)

# Inject whole custom steps:
pure_cpp2_workflow = gtx.program_processors.runners.gtfn.make_gtfn_workflow(
cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG,
translation=MyCodeGen(),
bindings=Cpp2BindingsGen(),
)

PureCpp2WorkflowFactory(cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG)
# Or inject a configured default step to flip one knob:
GTFNTranslationStep = gtx.program_processors.runners.gtfn.gtfn_module.GTFNTranslationStep
no_transforms = gtx.program_processors.runners.gtfn.make_gtfn_backend(
translation=GTFNTranslationStep(enable_itir_transforms=False),
)
```

A pre-built backend is itself a frozen dataclass, so you can also override after the
fact with `dataclasses.replace` (e.g. `dataclasses.replace(backend, executor=...)`).

## Invent new Workflow Types

```mermaid
Expand Down
2 changes: 1 addition & 1 deletion docs/user/next/advanced/WorkflowPatterns.md
Original file line number Diff line number Diff line change
Expand Up @@ -491,5 +491,5 @@ gtx.program_processors.runners.gtfn.run_gtfn_gpu.executor.otf_workflow??
```

```python
gtx.program_processors.runners.gtfn.GTFNBackendFactory??
gtx.program_processors.runners.gtfn.make_gtfn_backend??
```
8 changes: 1 addition & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ profiling = [
scripts = ["pyyaml>=6.0.1", "typer>=0.12.3", "packaging"]
test = [
'coverage[toml]>=7.6.1',
'factory-boy>=3.3.3',
'hypothesis>=6.0.0',
'nbmake>=1.4.6',
'nox>=2025.02.09',
Expand Down Expand Up @@ -103,7 +104,6 @@ dependencies = [
'dace>=2.0.0a3',
'deepdiff>=8.1.0',
'devtools>=0.6',
'factory-boy>=3.3.3',
"filelock>=3.18.0",
'frozendict>=2.3',
'gridtools-cpp>=2.3.9,==2.*',
Expand Down Expand Up @@ -263,12 +263,6 @@ module = 'gt4py.next.iterator.*'
ignore_errors = true
module = 'gt4py.next.iterator.runtime'

[[tool.mypy.overrides]]
ignore_missing_imports = true
implicit_reexport = true
# factory-boy is broken, see https://github.com/FactoryBoy/factory_boy/pull/1114
module = "factory.*"

[[tool.mypy.overrides]]
disallow_incomplete_defs = false
disallow_untyped_defs = false
Expand Down
7 changes: 0 additions & 7 deletions src/gt4py/next/otf/compilation/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import pathlib
from typing import Protocol, TypeVar

import factory

from gt4py._core import locking
from gt4py.next import config
from gt4py.next.otf import code_specs, definitions, stages, workflow
Expand Down Expand Up @@ -91,9 +89,4 @@ def __call__(
return func


class CompilerFactory(factory.Factory):
class Meta:
model = Compiler


class CompilationError(RuntimeError): ...
19 changes: 19 additions & 0 deletions src/gt4py/next/otf/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@
HashT = TypeVar("HashT")
DataT = TypeVar("DataT")
ArgT = TypeVar("ArgT")
StepT = TypeVar("StepT")


def with_changes(step: StepT, **changes: Any) -> StepT:
"""
Return a copy of ``step`` with ``changes`` applied, where it declares them.

Used by the backend/workflow builders to stamp cross-cutting configuration
(e.g. ``device_type``) onto a sub-component, whether it was built by default
or injected by the caller. Only fields that ``step`` declares as a dataclass
are applied; any other ``step`` (e.g. a plain callable) is returned unchanged,
so injecting an arbitrary workflow step stays safe.
"""
if dataclasses.is_dataclass(step) and not isinstance(step, type):
names = {f.name for f in dataclasses.fields(step)}
valid = {key: value for key, value in changes.items() if key in names}
if valid:
return dataclasses.replace(step, **valid)
return step


def make_step(function: Workflow[StartT, EndT]) -> ChainableWorkflowMixin[StartT, EndT]:
Expand Down
10 changes: 2 additions & 8 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import functools
from typing import Any, Final, Optional

import factory
import numpy as np

from gt4py._core import definitions as core_defs
Expand Down Expand Up @@ -284,13 +283,8 @@ def _not_implemented_for_device_type(self) -> NotImplementedError:
)


class GTFNTranslationStepFactory(factory.Factory[GTFNTranslationStep]):
class Meta:
model = GTFNTranslationStep
translate_program_cpu: Final[definitions.TranslationStep] = GTFNTranslationStep()


translate_program_cpu: Final[definitions.TranslationStep] = GTFNTranslationStepFactory() # type: ignore[assignment] # factory-boy typing not precise enough

translate_program_gpu: Final[definitions.TranslationStep] = GTFNTranslationStepFactory( # type: ignore[assignment] # factory-boy typing not precise enough
translate_program_gpu: Final[definitions.TranslationStep] = GTFNTranslationStep(
device_type=core_defs.DeviceType.CUDA
)
2 changes: 1 addition & 1 deletion src/gt4py/next/program_processors/formatters/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
@program_formatter.program_formatter
def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str:
# TODO(tehrengruber): This is a little ugly. Revisit.
gtfn_translation = gtfn.GTFNBackendFactory().executor.translation # type: ignore[attr-defined]
gtfn_translation = gtfn.make_gtfn_workflow().translation
assert isinstance(gtfn_translation, GTFNTranslationStep)
return gtfn_translation.generate_stencil_source(
program,
Expand Down
182 changes: 53 additions & 129 deletions src/gt4py/next/program_processors/runners/dace/workflow/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,171 +8,95 @@

from __future__ import annotations

import warnings
from typing import Any, Final

import factory

import gt4py.next.custom_layout_allocators as next_allocators
from gt4py._core import definitions as core_defs
from gt4py.next import backend, common, config
from gt4py.next.otf import stages, workflow
from gt4py.next.program_processors.runners.dace.workflow.factory import DaCeWorkflowFactory


class DaCeBackendFactory(factory.Factory):
"""
Workflow factory for the GTIR-DaCe backend.

Several parameters are inherithed from `backend.Backend`, see below the specific ones.

Args:
auto_optimize: Enables the SDFG transformation pipeline.
"""

class Meta:
model = backend.Backend

class Params:
name_device = "cpu"
name_cached = ""
name_postfix = ""
gpu = factory.Trait(
allocator=next_allocators.StandardGPUFieldBufferAllocator(),
device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA,
name_device="gpu",
)
cached = factory.Trait(
executor=factory.LazyAttribute(
lambda o: workflow.CachedStep.in_memory(
o.otf_workflow, input_fingerprinter=o.key_function
)
),
name_cached="_cached",
)
device_type = core_defs.DeviceType.CPU
key_function = stages.fast_compilable_program_fingerprinter
otf_workflow = factory.SubFactory(
DaCeWorkflowFactory,
device_type=factory.SelfAttribute("..device_type"),
auto_optimize=factory.SelfAttribute("..auto_optimize"),
)
auto_optimize = factory.Trait(name_postfix="_opt")

name = factory.LazyAttribute(
lambda o: f"run_dace_{o.name_device}{o.name_cached}{o.name_postfix}"
)

executor = factory.LazyAttribute(lambda o: o.otf_workflow)
allocator = next_allocators.StandardCPUFieldBufferAllocator()
transforms = backend.DEFAULT_TRANSFORMS
from gt4py.next import backend
from gt4py.next.otf import definitions, stages, workflow
from gt4py.next.program_processors.runners.dace.workflow.factory import make_dace_workflow
from gt4py.next.program_processors.runners.dace.workflow.translation import make_dace_translator


def make_dace_backend(
gpu: bool,
*,
cached: bool = True,
auto_optimize: bool = True,
async_sdfg_call: bool = True,
optimization_args: dict[str, Any] | None = None,
unstructured_horizontal_has_unit_stride: bool = config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE,
use_metrics: bool = True,
use_zero_origin: bool = False,
use_max_domain_range_on_unstructured_shift: bool | None = None,
name_postfix: str = "",
translation: definitions.TranslationStep | None = None,
executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram]
| None = None,
) -> backend.Backend:
"""Customize the dace backend with the given configuration parameters.
"""Build the GTIR-DaCe backend for the given device and configuration.

Cross-cutting configuration is passed as keyword arguments. To customize the
translator-local options (async SDFG call, metrics, zero-origin, ...), inject a
pre-built translator via ``translation=`` (see `make_dace_translator`); its
``device_type`` and ``auto_optimize`` are set to match ``gpu``/``auto_optimize``.
Pass ``executor=`` to replace the whole executor workflow.

Args:
gpu: Enable GPU transformations and code generation.
cached: Cache the lowered SDFG as a JSON file and the compiled programs.
auto_optimize: Enable the SDFG auto-optimize pipeline.
async_sdfg_call: Make an asynchronous SDFG call on GPU to allow overlapping
of GPU kernel execution with the Python driver code.
optimization_args: A `dict` containing configuration parameters for
the SDFG auto-optimize pipeline, see `gt_auto_optimize()`.
unstructured_horizontal_has_unit_stride: When the memory layout has unit stride
in the horizontal dimension, replace the field stride symbol with '1'.
use_metrics: Add SDFG instrumentation to collect the metric for stencil
compute time.
use_zero_origin: Can be set to `True` when all fields passed as program
arguments have zero-based origin. This setting will skip generation
of range start-symbols `_range_0` since they can be assumed to be zero.

Note that `gt_auto_optimize()` parameters that are derived from GT4Py configuration
cannot be overriden, and therefore cannot appear here. Thus, this function will
throw an exception if called with any argument included in `gt_optimization_args`.

Returns:
A dace backend with custom configuration for the target device.
A dace backend for the target device.
"""
allocator: next_allocators.FieldBufferAllocatorProtocol
device_type: core_defs.DeviceType
if gpu:
allocator = next_allocators.StandardGPUFieldBufferAllocator()
device_type = core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA
name_device = "gpu"
else:
allocator = next_allocators.StandardCPUFieldBufferAllocator()
device_type = core_defs.DeviceType.CPU
name_device = "cpu"

# The `gt_optimization_args` set contains the parameters of `gt_auto_optimize()`
# that are derived from the gt4py configuration, and therefore cannot be customized.
gt_optimization_args: Final[set[str]] = {"gpu", "constant_symbols", "unit_strides_kind"}

if optimization_args is None:
optimization_args = {}
elif optimization_args and not auto_optimize:
warnings.warn("Optimizations args given, but auto-optimize is disabled.", stacklevel=2)
elif intersect_args := gt_optimization_args.intersection(optimization_args.keys()):
raise ValueError(
f"The following optimization arguments cannot be overriden: {intersect_args}."
if executor is None:
otf_workflow = make_dace_workflow(
device_type=device_type,
auto_optimize=auto_optimize,
cached_translation=cached,
translation=translation,
)
executor = (
workflow.CachedStep.in_memory(
otf_workflow, input_fingerprinter=stages.fast_compilable_program_fingerprinter
)
if cached
else otf_workflow
)

# Set `unit_strides_kind` based on the gt4py env configuration.
optimization_args = optimization_args | {
"unit_strides_kind": common.DimensionKind.HORIZONTAL
if unstructured_horizontal_has_unit_stride
else None
}

return DaCeBackendFactory( # type: ignore[return-value] # factory-boy typing not precise enough
gpu=gpu,
cached=cached,
auto_optimize=auto_optimize,
otf_workflow__cached_translation=cached,
otf_workflow__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False),
otf_workflow__bare_translation__auto_optimize_args=optimization_args,
otf_workflow__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride,
otf_workflow__bare_translation__use_metrics=use_metrics,
otf_workflow__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin,
otf_workflow__bare_translation__use_max_domain_range_on_unstructured_shift=use_max_domain_range_on_unstructured_shift,
name_cached = "_cached" if cached else ""
name_opt = "_opt" if auto_optimize else ""
return backend.Backend(
name=f"run_dace_{name_device}{name_cached}{name_opt}{name_postfix}",
executor=executor,
allocator=allocator,
transforms=backend.DEFAULT_TRANSFORMS,
)


run_dace_cpu = make_dace_backend(
gpu=False,
cached=False,
auto_optimize=True,
async_sdfg_call=False,
)
run_dace_cpu_noopt = make_dace_backend(
gpu=False,
cached=False,
auto_optimize=False,
async_sdfg_call=False,
)
run_dace_cpu_cached = make_dace_backend(
gpu=False,
cached=True,
auto_optimize=True,
async_sdfg_call=False,
)
run_dace_cpu = make_dace_backend(gpu=False, cached=False, auto_optimize=True)
run_dace_cpu_noopt = make_dace_backend(gpu=False, cached=False, auto_optimize=False)
run_dace_cpu_cached = make_dace_backend(gpu=False, cached=True, auto_optimize=True)

run_dace_gpu = make_dace_backend(
gpu=True,
cached=False,
auto_optimize=True,
async_sdfg_call=True,
translation=make_dace_translator(async_sdfg_call=True),
)
run_dace_gpu_noopt = make_dace_backend(
gpu=True,
cached=False,
auto_optimize=False,
async_sdfg_call=True,
translation=make_dace_translator(async_sdfg_call=True),
)
run_dace_gpu_cached = make_dace_backend(
gpu=True,
cached=True,
auto_optimize=True,
async_sdfg_call=True,
translation=make_dace_translator(async_sdfg_call=True),
)
Loading
Loading