Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
6a35549
refactor[next]: otf split build and load
havogt Apr 27, 2026
d251057
refactorings to combine load+decoration
havogt Apr 27, 2026
b8efa44
BuildArtifact can materialize() itself
havogt Apr 27, 2026
915ce27
remove OTFCompileWorkflow
havogt Apr 27, 2026
c51ac55
cleanup docstrings
havogt Apr 27, 2026
9b842e2
refactor: remove lazy import
havogt Apr 27, 2026
c022bde
apply pre-commit
havogt Apr 27, 2026
a107dd6
update roundtrip
havogt Apr 27, 2026
fa18067
Merge remote-tracking branch 'upstream/main' into worktree-otf-build-…
havogt Apr 27, 2026
efecdf5
separate gtfn from generic compiler
havogt Apr 27, 2026
68ead69
don't serialize/deserialize dace in the same process
havogt Apr 27, 2026
4702660
don't search for sdfg in materialize
havogt Apr 27, 2026
77b6235
restore generic compiler
havogt Apr 27, 2026
3869d67
cleanup
havogt Apr 27, 2026
9017b33
more cleanup
havogt Apr 27, 2026
5bffec0
add tests
havogt Apr 27, 2026
7f219d7
cleanup
havogt Apr 27, 2026
69c5ae8
avoid decoration needs picklable
havogt Apr 27, 2026
e05a336
cleanup
havogt Apr 27, 2026
7e9fa08
cleanup
havogt Apr 27, 2026
8f35cb7
Build->Compile, materialize->load
havogt Apr 27, 2026
29508bf
cleanup
havogt Apr 27, 2026
9f4e776
use tmp_path fixture
havogt Apr 28, 2026
9c9234d
refactor roundtrip to resepect picklability
havogt Apr 28, 2026
528dcd5
Merge branch 'main' into worktree-otf-build-finalize-split
havogt Apr 28, 2026
edd9e93
sdfg as part of artifact (because of upcoming change)
havogt Apr 28, 2026
25f8cc1
Merge upstream/main into worktree-otf-build-finalize-split
havogt Jun 19, 2026
57c55ba
refactor[next]: lift dace live-program cache out of CompilationArtifact
havogt Jun 19, 2026
a3e2234
docs: drop reST roles from docstrings; emphasize Google style in AGEN…
havogt Jun 19, 2026
21de820
docs: drop remaining reST roles from PR-introduced docstrings
havogt Jun 19, 2026
c125a2a
refactor[next]: drop Path/str round-trip in DaCeCompiler
havogt Jun 25, 2026
3c65696
refactor[next]: store library_path on DaCeCompilationArtifact
havogt Jun 25, 2026
1a13a63
docs[next]: explain why DaCeCompilationArtifact has _live_program_cache
havogt Jun 25, 2026
c9dc7be
docs[next]: TODO to drop _live_program_cache if dace stops renaming
havogt Jun 25, 2026
13ca4f8
Merge remote-tracking branch 'upstream/main' into worktree-otf-build-…
havogt Jun 25, 2026
46a2578
feat[next]: make CompilationArtifact Protocol runtime-checkable
havogt Jun 26, 2026
bce7122
refactor[next]: drop unused CompilerFactory
havogt Jun 26, 2026
d8c1789
refactor[next]: skip dlopen at dace compile time
havogt Jun 26, 2026
f9e7f25
refactor[next]: drop _live_program_cache in dace compilation artifact
havogt Jun 26, 2026
0451bff
Merge remote-tracking branch 'upstream/main' into worktree-otf-build-…
havogt Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ If a command above is wrong for your environment, fix `pyproject.toml`,
## Do

- Run `uv run pre-commit run` on staged files before claiming a task done.
- Write docstrings in **Google style** (rendered via Sphinx-napoleon). Use
Comment thread
havogt marked this conversation as resolved.
`Args:` / `Returns:` / `Raises:` / `Examples:` sections — not the reST
`:param:` / `:returns:` / `:raises:` variants. **Do not** use Sphinx
cross-reference roles like `` :class:`Foo` ``, `` :meth:`bar` ``,
`` :func:`baz` `` inside docstrings; the only allowed inline markup is
plain `` `literal` `` backticks and bulleted lists (see
[`CODING_GUIDELINES.md`](CODING_GUIDELINES.md) §3.8).
- When touching `gt4py.cartesian`, `gt4py.next`, `gt4py.eve`, `gt4py.storage`,
or `gt4py._core`, run the matching `nox -s test_<subpackage>` session
before opening the PR.
Expand Down
5 changes: 3 additions & 2 deletions src/gt4py/next/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,17 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]:
@dataclasses.dataclass(frozen=True)
class Backend(Generic[core_defs.DeviceTypeT]):
name: str
executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram]
executor: workflow.Workflow[definitions.CompilableProgramDef, stages.CompilationArtifact]
allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]
transforms: workflow.Workflow[definitions.ConcreteProgramDef, definitions.CompilableProgramDef]

def compile(
self, program: definitions.IRDefinitionT, compile_time_args: arguments.CompileTimeArgs
) -> stages.ExecutableProgram:
return self.executor(
artifact = self.executor(
self.transforms(definitions.ConcreteProgramDef(data=program, args=compile_time_args))
)
return artifact.load()

@property
def __gt_allocator__(
Expand Down
68 changes: 47 additions & 21 deletions src/gt4py/next/otf/compilation/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,12 @@
import pathlib
from typing import Protocol, TypeVar

import factory

from gt4py._core import locking
from gt4py._core import definitions as core_defs, locking
from gt4py.next import config
from gt4py.next.otf import code_specs, definitions, stages, workflow
from gt4py.next.otf.compilation import build_data, cache, importer


T = TypeVar("T")


def is_compiled(data: build_data.BuildData) -> bool:
return data.status >= build_data.BuildStatus.COMPILED

Expand All @@ -45,27 +40,58 @@ def __call__(


@dataclasses.dataclass(frozen=True)
class Compiler(
class CPPCompilationArtifact:
"""On-disk result of a CPP-style compilation: a Python extension module.

The default ``load`` is an ``importlib`` import + entry-point lookup;
backends override to apply their own calling convention.
"""

src_dir: pathlib.Path
module: pathlib.Path
entry_point_name: str
device_type: core_defs.DeviceType

def load(self) -> stages.ExecutableProgram:
"""Import the .so and return the raw entry point.

Must run in the process that will call the returned program: the
module is registered in that process's ``sys.modules`` under the
``gt4py.__compiled_programs__.`` prefix.
"""
m = importer.import_from_path(
self.src_dir / self.module,
sys_modules_prefix="gt4py.__compiled_programs__.",
)
return getattr(m, self.entry_point_name)


@dataclasses.dataclass(frozen=True)
class CPPCompiler(
workflow.ChainableWorkflowMixin[
stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec],
stages.ExecutableProgram,
CPPCompilationArtifact,
],
workflow.ReplaceEnabledWorkflowMixin[
stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec],
stages.ExecutableProgram,
CPPCompilationArtifact,
],
definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec],
):
"""Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``."""
"""Drive a CPP-style build system into a ``CPPCompilationArtifact``.

Backends override ``_make_artifact`` to use their own artifact subclass.
"""

cache_lifetime: config.BuildCacheLifetime
builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec]
device_type: core_defs.DeviceType
force_recompile: bool = False

def __call__(
self,
inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec],
) -> stages.ExecutableProgram:
) -> CPPCompilationArtifact:
src_dir = cache.get_cache_folder(inp, self.cache_lifetime)

# If we are compiling the same program at the same time (e.g. multiple MPI ranks),
Expand All @@ -83,17 +109,17 @@ def __call__(
f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'."
)

m = importer.import_from_path(
src_dir / new_data.module, sys_modules_prefix="gt4py.__compiled_programs__."
)
func = getattr(m, new_data.entry_point_name)

return func
return self._make_artifact(src_dir, new_data.module, new_data.entry_point_name)


class CompilerFactory(factory.Factory):
class Meta:
model = Compiler
def _make_artifact(
self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str
) -> CPPCompilationArtifact:
return CPPCompilationArtifact(
src_dir=src_dir,
module=module,
entry_point_name=entry_point_name,
device_type=self.device_type,
)


class CompilationError(RuntimeError): ...
11 changes: 8 additions & 3 deletions src/gt4py/next/otf/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,17 @@ def __call__(

class CompilationStep(
workflow.Workflow[
stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.ExecutableProgram
stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.CompilationArtifact
],
Protocol[CodeSpecT, TargetCodeSpecT],
):
"""Compile program source code and bindings into a python callable (CompilableSource -> CompiledProgram)."""
"""Run the build system and produce a ``stages.CompilationArtifact``.

Each backend defines its own concrete artifact dataclass (frozen,
picklable, with a ``load`` method); they all satisfy the
``stages.CompilationArtifact`` Protocol structurally.
"""

def __call__(
self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT]
) -> stages.ExecutableProgram: ...
) -> stages.CompilationArtifact: ...
7 changes: 4 additions & 3 deletions src/gt4py/next/otf/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@


@dataclasses.dataclass(frozen=True)
class OTFCompileWorkflow(workflow.NamedStepSequence):
class OTFCompileWorkflow(
workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.CompilationArtifact]
):
"""The typical compiled backend steps composed into a workflow."""

translation: definitions.TranslationStep
bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject]
compilation: workflow.Workflow[stages.CompilableProject, stages.ExecutableProgram]
decoration: workflow.Workflow[stages.ExecutableProgram, stages.ExecutableProgram]
compilation: workflow.Workflow[stages.CompilableProject, stages.CompilationArtifact]
30 changes: 29 additions & 1 deletion src/gt4py/next/otf/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@

import dataclasses
from collections.abc import Callable
from typing import TYPE_CHECKING, Final, Generic, Optional, Protocol, TypeAlias, TypeVar
from typing import (
TYPE_CHECKING,
Final,
Generic,
Optional,
Protocol,
TypeAlias,
TypeVar,
runtime_checkable,
)

from gt4py.next import common, fingerprinting
from gt4py.next.iterator import ir as itir
Expand Down Expand Up @@ -129,6 +138,25 @@ def build(self) -> None: ...
ExecutableProgram: TypeAlias = Callable


@runtime_checkable
class CompilationArtifact(Protocol):
Comment thread
havogt marked this conversation as resolved.
"""The output of an ``OTFCompileWorkflow``.

Each backend defines its own concrete artifact dataclass; all share this
Protocol. Implementations are frozen dataclasses, picklable, and carry no
live process-bound state — that is reconstructed by ``load``, which
returns a directly-callable ``ExecutableProgram`` taking gt4py-shaped
arguments.

The one current exception is ``RoundtripArtifact`` when it is configured
with a ``dispatch_backend``: that field holds a ``Backend`` reference
whose role belongs at the runner / load-time seam, not in the artifact
itself.
"""

def load(self) -> ExecutableProgram: ...


def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]:
"""
Filter out multiple occurrences of the same ``interface.LibraryDependency``.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,25 @@
from __future__ import annotations

import dataclasses
import json
import os
import pathlib
import warnings
from collections.abc import Callable, MutableSequence, Sequence
from typing import Any

import dace
import dace.codegen.compiler as dace_compiler
import factory

from gt4py._core import definitions as core_defs, locking
from gt4py.next import common, config
from gt4py.next.otf import code_specs, definitions, stages, workflow
from gt4py.next.otf.compilation import cache as gtx_cache
from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon
from gt4py.next.program_processors.runners.dace.workflow import (
common as gtx_wfdcommon,
decoration as gtx_wfddecoration,
)


def _add_tx_markers(sdfg: dace.SDFG) -> None:
Expand Down Expand Up @@ -69,7 +75,7 @@ def __init__(
self,
program: dace.CompiledSDFG,
bind_func_name: str,
binding_source: stages.BindingSource[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec],
binding_source_code: str,
):
self.sdfg_program = program

Expand All @@ -78,9 +84,10 @@ def __init__(
# This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`.
self.sdfg_argtypes = list(program.sdfg.arglist().values())

# Note that `binding_source` contains Python code tailored to this specific SDFG.
# Here we dinamically compile this function and add it to the compiled program.
exec(binding_source.source_code, global_namespace := {}) # type: ignore[var-annotated]
# The binding source code is Python tailored to this specific SDFG.
# We dynamically compile that function and add it to the compiled program.
global_namespace: dict[str, Any] = {}
exec(binding_source_code, global_namespace)
self.update_sdfg_ctype_arglist = global_namespace[bind_func_name]
# For debug purpose, we set a unique module name on the compiled function.
self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder)
Expand Down Expand Up @@ -128,19 +135,46 @@ def __call__(self, **kwargs: Any) -> None:
assert result is None


@dataclasses.dataclass(frozen=True)
class DaCeCompilationArtifact:
"""Result of a DaCe compilation: build folder + library path + SDFG bindings + the SDFG itself.

The SDFG is carried inline as JSON because dace's load path
(``get_program_handle``) needs an SDFG instance to wrap into the
returned ``CompiledSDFG``, and the build folder may not contain a
``program.sdfg(z)`` dump under the upcoming minimal-build-dir mode.
"""

build_folder: pathlib.Path
library_path: pathlib.Path
sdfg_json: str
binding_source_code: str
bind_func_name: str
device_type: core_defs.DeviceType

def load(self) -> stages.ExecutableProgram:
# TODO(phimuell): Drop ``sdfg_json`` from the artifact once dace
# exposes a load path that doesn't require an SDFG instance to wrap
# into the returned ``CompiledSDFG``.
sdfg = dace.SDFG.from_json(json.loads(self.sdfg_json))
sdfg_program = dace_compiler.get_program_handle(self.library_path, sdfg)
program = CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code)
return gtx_wfddecoration.convert_args(program, device=self.device_type)


@dataclasses.dataclass(frozen=True)
class DaCeCompiler(
workflow.ChainableWorkflowMixin[
stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec],
CompiledDaceProgram,
DaCeCompilationArtifact,
],
workflow.ReplaceEnabledWorkflowMixin[
stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec],
CompiledDaceProgram,
DaCeCompilationArtifact,
],
definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec],
):
"""Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``."""
"""Run the DaCe build system and produce an on-disk ``DaCeCompilationArtifact``."""

bind_func_name: str
cache_lifetime: config.BuildCacheLifetime
Expand All @@ -155,13 +189,13 @@ class DaCeCompiler(
def __call__(
self,
inp: stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec],
) -> CompiledDaceProgram:
) -> DaCeCompilationArtifact:
with gtx_wfdcommon.dace_context(
device_type=self.device_type,
cmake_build_type=self.cmake_build_type,
):
sdfg_build_folder = gtx_cache.get_cache_folder(inp, self.cache_lifetime)
sdfg_build_folder.mkdir(parents=True, exist_ok=True)
pathlib.Path(sdfg_build_folder).mkdir(parents=True, exist_ok=True)

sdfg = dace.SDFG.from_json(inp.program_source.source_code)

Expand All @@ -171,13 +205,21 @@ def __call__(

sdfg.build_folder = sdfg_build_folder
with locking.lock(sdfg_build_folder):
sdfg_program = sdfg.compile(validate=False)
sdfg.compile(validate=False, return_program_handle=False)
# ``build_folder_mode`` is set by ``dace_context``; resolve the library
# path here so ``get_binary_name`` sees the same mode dace built under.
library_path = dace_compiler.get_binary_name(
object_folder=sdfg_build_folder, sdfg_name=sdfg.name
)

assert inp.binding_source is not None
return CompiledDaceProgram(
sdfg_program,
self.bind_func_name,
inp.binding_source,
return DaCeCompilationArtifact(
build_folder=pathlib.Path(sdfg_build_folder),
library_path=library_path,
sdfg_json=json.dumps(inp.program_source.source_code),
binding_source_code=inp.binding_source.source_code,
bind_func_name=self.bind_func_name,
device_type=self.device_type,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from __future__ import annotations

import functools
from typing import Any, Sequence
from typing import TYPE_CHECKING, Any, Sequence

import numpy as np

Expand All @@ -18,14 +18,16 @@
from gt4py.next.instrumentation import metrics
from gt4py.next.otf import stages
from gt4py.next.program_processors.runners.dace import sdfg_callable
from gt4py.next.program_processors.runners.dace.workflow import (
common as gtx_wfdcommon,
compilation as gtx_wfdcompilation,
)
from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon


if TYPE_CHECKING:
# Type-only: a top-level import would cycle with ``compilation``.
from gt4py.next.program_processors.runners.dace.workflow.compilation import CompiledDaceProgram


def convert_args(
fun: gtx_wfdcompilation.CompiledDaceProgram,
fun: CompiledDaceProgram,
device: core_defs.DeviceType = core_defs.DeviceType.CPU,
) -> stages.ExecutableProgram:
# Retieve metrics level from GT4Py environment variable.
Expand Down
Loading