diff --git a/AGENTS.md b/AGENTS.md index 7083185fbb..85321d3e5d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -65,6 +65,13 @@ If a command above is wrong for your environment, fix `pyproject.toml`, ## Do - Run `uv run pre-commit run` on staged files before claiming a task done. +- Write docstrings in **Google style** (rendered via Sphinx-napoleon). Use + `Args:` / `Returns:` / `Raises:` / `Examples:` sections — not the reST + `:param:` / `:returns:` / `:raises:` variants. **Do not** use Sphinx + cross-reference roles like `` :class:`Foo` ``, `` :meth:`bar` ``, + `` :func:`baz` `` inside docstrings; the only allowed inline markup is + plain `` `literal` `` backticks and bulleted lists (see + [`CODING_GUIDELINES.md`](CODING_GUIDELINES.md) §3.8). - When touching `gt4py.cartesian`, `gt4py.next`, `gt4py.eve`, `gt4py.storage`, or `gt4py._core`, run the matching `nox -s test_` session before opening the PR. diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 6123da97e0..ae599ece6d 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -147,16 +147,17 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]: @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): name: str - executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram] + executor: workflow.Workflow[definitions.CompilableProgramDef, stages.CompilationArtifact] allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] transforms: workflow.Workflow[definitions.ConcreteProgramDef, definitions.CompilableProgramDef] def compile( self, program: definitions.IRDefinitionT, compile_time_args: arguments.CompileTimeArgs ) -> stages.ExecutableProgram: - return self.executor( + artifact = self.executor( self.transforms(definitions.ConcreteProgramDef(data=program, args=compile_time_args)) ) + return artifact.load() @property def __gt_allocator__( diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 3748d95192..98f247fb72 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -12,17 +12,12 @@ import pathlib from typing import Protocol, TypeVar -import factory - -from gt4py._core import locking +from gt4py._core import definitions as core_defs, locking from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import build_data, cache, importer -T = TypeVar("T") - - def is_compiled(data: build_data.BuildData) -> bool: return data.status >= build_data.BuildStatus.COMPILED @@ -45,27 +40,58 @@ def __call__( @dataclasses.dataclass(frozen=True) -class Compiler( +class CPPCompilationArtifact: + """On-disk result of a CPP-style compilation: a Python extension module. + + The default ``load`` is an ``importlib`` import + entry-point lookup; + backends override to apply their own calling convention. + """ + + src_dir: pathlib.Path + module: pathlib.Path + entry_point_name: str + device_type: core_defs.DeviceType + + def load(self) -> stages.ExecutableProgram: + """Import the .so and return the raw entry point. + + Must run in the process that will call the returned program: the + module is registered in that process's ``sys.modules`` under the + ``gt4py.__compiled_programs__.`` prefix. + """ + m = importer.import_from_path( + self.src_dir / self.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return getattr(m, self.entry_point_name) + + +@dataclasses.dataclass(frozen=True) +class CPPCompiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.ExecutableProgram, + CPPCompilationArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.ExecutableProgram, + CPPCompilationArtifact, ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" + """Drive a CPP-style build system into a ``CPPCompilationArtifact``. + + Backends override ``_make_artifact`` to use their own artifact subclass. + """ cache_lifetime: config.BuildCacheLifetime builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + device_type: core_defs.DeviceType force_recompile: bool = False def __call__( self, inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - ) -> stages.ExecutableProgram: + ) -> CPPCompilationArtifact: src_dir = cache.get_cache_folder(inp, self.cache_lifetime) # If we are compiling the same program at the same time (e.g. multiple MPI ranks), @@ -83,17 +109,17 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) - m = importer.import_from_path( - src_dir / new_data.module, sys_modules_prefix="gt4py.__compiled_programs__." - ) - func = getattr(m, new_data.entry_point_name) - - return func + return self._make_artifact(src_dir, new_data.module, new_data.entry_point_name) - -class CompilerFactory(factory.Factory): - class Meta: - model = Compiler + def _make_artifact( + self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str + ) -> CPPCompilationArtifact: + return CPPCompilationArtifact( + src_dir=src_dir, + module=module, + entry_point_name=entry_point_name, + device_type=self.device_type, + ) class CompilationError(RuntimeError): ... diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index 11b42dc6ce..e85e372ebd 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -57,12 +57,17 @@ def __call__( class CompilationStep( workflow.Workflow[ - stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.ExecutableProgram + stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.CompilationArtifact ], Protocol[CodeSpecT, TargetCodeSpecT], ): - """Compile program source code and bindings into a python callable (CompilableSource -> CompiledProgram).""" + """Run the build system and produce a ``stages.CompilationArtifact``. + + Each backend defines its own concrete artifact dataclass (frozen, + picklable, with a ``load`` method); they all satisfy the + ``stages.CompilationArtifact`` Protocol structurally. + """ def __call__( self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT] - ) -> stages.ExecutableProgram: ... + ) -> stages.CompilationArtifact: ... diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 79cd17162b..0b809e4731 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -14,10 +14,11 @@ @dataclasses.dataclass(frozen=True) -class OTFCompileWorkflow(workflow.NamedStepSequence): +class OTFCompileWorkflow( + workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.CompilationArtifact] +): """The typical compiled backend steps composed into a workflow.""" translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] - compilation: workflow.Workflow[stages.CompilableProject, stages.ExecutableProgram] - decoration: workflow.Workflow[stages.ExecutableProgram, stages.ExecutableProgram] + compilation: workflow.Workflow[stages.CompilableProject, stages.CompilationArtifact] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index c0bdddee1c..cc631da4a8 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -10,7 +10,16 @@ import dataclasses from collections.abc import Callable -from typing import TYPE_CHECKING, Final, Generic, Optional, Protocol, TypeAlias, TypeVar +from typing import ( + TYPE_CHECKING, + Final, + Generic, + Optional, + Protocol, + TypeAlias, + TypeVar, + runtime_checkable, +) from gt4py.next import common, fingerprinting from gt4py.next.iterator import ir as itir @@ -129,6 +138,25 @@ def build(self) -> None: ... ExecutableProgram: TypeAlias = Callable +@runtime_checkable +class CompilationArtifact(Protocol): + """The output of an ``OTFCompileWorkflow``. + + Each backend defines its own concrete artifact dataclass; all share this + Protocol. Implementations are frozen dataclasses, picklable, and carry no + live process-bound state — that is reconstructed by ``load``, which + returns a directly-callable ``ExecutableProgram`` taking gt4py-shaped + arguments. + + The one current exception is ``RoundtripArtifact`` when it is configured + with a ``dispatch_backend``: that field holds a ``Backend`` reference + whose role belongs at the runner / load-time seam, not in the artifact + itself. + """ + + def load(self) -> ExecutableProgram: ... + + def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: """ Filter out multiple occurrences of the same ``interface.LibraryDependency``. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index b8e18382d8..1c9daf6af2 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -9,19 +9,25 @@ from __future__ import annotations import dataclasses +import json import os +import pathlib import warnings from collections.abc import Callable, MutableSequence, Sequence from typing import Any import dace +import dace.codegen.compiler as dace_compiler import factory from gt4py._core import definitions as core_defs, locking from gt4py.next import common, config from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache -from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon +from gt4py.next.program_processors.runners.dace.workflow import ( + common as gtx_wfdcommon, + decoration as gtx_wfddecoration, +) def _add_tx_markers(sdfg: dace.SDFG) -> None: @@ -69,7 +75,7 @@ def __init__( self, program: dace.CompiledSDFG, bind_func_name: str, - binding_source: stages.BindingSource[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], + binding_source_code: str, ): self.sdfg_program = program @@ -78,9 +84,10 @@ def __init__( # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. self.sdfg_argtypes = list(program.sdfg.arglist().values()) - # Note that `binding_source` contains Python code tailored to this specific SDFG. - # Here we dinamically compile this function and add it to the compiled program. - exec(binding_source.source_code, global_namespace := {}) # type: ignore[var-annotated] + # The binding source code is Python tailored to this specific SDFG. + # We dynamically compile that function and add it to the compiled program. + global_namespace: dict[str, Any] = {} + exec(binding_source_code, global_namespace) self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] # For debug purpose, we set a unique module name on the compiled function. self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) @@ -128,19 +135,46 @@ def __call__(self, **kwargs: Any) -> None: assert result is None +@dataclasses.dataclass(frozen=True) +class DaCeCompilationArtifact: + """Result of a DaCe compilation: build folder + library path + SDFG bindings + the SDFG itself. + + The SDFG is carried inline as JSON because dace's load path + (``get_program_handle``) needs an SDFG instance to wrap into the + returned ``CompiledSDFG``, and the build folder may not contain a + ``program.sdfg(z)`` dump under the upcoming minimal-build-dir mode. + """ + + build_folder: pathlib.Path + library_path: pathlib.Path + sdfg_json: str + binding_source_code: str + bind_func_name: str + device_type: core_defs.DeviceType + + def load(self) -> stages.ExecutableProgram: + # TODO(phimuell): Drop ``sdfg_json`` from the artifact once dace + # exposes a load path that doesn't require an SDFG instance to wrap + # into the returned ``CompiledSDFG``. + sdfg = dace.SDFG.from_json(json.loads(self.sdfg_json)) + sdfg_program = dace_compiler.get_program_handle(self.library_path, sdfg) + program = CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) + return gtx_wfddecoration.convert_args(program, device=self.device_type) + + @dataclasses.dataclass(frozen=True) class DaCeCompiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - CompiledDaceProgram, + DaCeCompilationArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - CompiledDaceProgram, + DaCeCompilationArtifact, ], definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], ): - """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" + """Run the DaCe build system and produce an on-disk ``DaCeCompilationArtifact``.""" bind_func_name: str cache_lifetime: config.BuildCacheLifetime @@ -155,13 +189,13 @@ class DaCeCompiler( def __call__( self, inp: stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - ) -> CompiledDaceProgram: + ) -> DaCeCompilationArtifact: with gtx_wfdcommon.dace_context( device_type=self.device_type, cmake_build_type=self.cmake_build_type, ): sdfg_build_folder = gtx_cache.get_cache_folder(inp, self.cache_lifetime) - sdfg_build_folder.mkdir(parents=True, exist_ok=True) + pathlib.Path(sdfg_build_folder).mkdir(parents=True, exist_ok=True) sdfg = dace.SDFG.from_json(inp.program_source.source_code) @@ -171,13 +205,21 @@ def __call__( sdfg.build_folder = sdfg_build_folder with locking.lock(sdfg_build_folder): - sdfg_program = sdfg.compile(validate=False) + sdfg.compile(validate=False, return_program_handle=False) + # ``build_folder_mode`` is set by ``dace_context``; resolve the library + # path here so ``get_binary_name`` sees the same mode dace built under. + library_path = dace_compiler.get_binary_name( + object_folder=sdfg_build_folder, sdfg_name=sdfg.name + ) assert inp.binding_source is not None - return CompiledDaceProgram( - sdfg_program, - self.bind_func_name, - inp.binding_source, + return DaCeCompilationArtifact( + build_folder=pathlib.Path(sdfg_build_folder), + library_path=library_path, + sdfg_json=json.dumps(inp.program_source.source_code), + binding_source_code=inp.binding_source.source_code, + bind_func_name=self.bind_func_name, + device_type=self.device_type, ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 103e7af33b..f9e9f7181b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -9,7 +9,7 @@ from __future__ import annotations import functools -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any, Sequence import numpy as np @@ -18,14 +18,16 @@ from gt4py.next.instrumentation import metrics from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable -from gt4py.next.program_processors.runners.dace.workflow import ( - common as gtx_wfdcommon, - compilation as gtx_wfdcompilation, -) +from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon + + +if TYPE_CHECKING: + # Type-only: a top-level import would cycle with ``compilation``. + from gt4py.next.program_processors.runners.dace.workflow.compilation import CompiledDaceProgram def convert_args( - fun: gtx_wfdcompilation.CompiledDaceProgram, + fun: CompiledDaceProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU, ) -> stages.ExecutableProgram: # Retieve metrics level from GT4Py environment variable. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index e473e1493c..6238871b8f 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -17,10 +17,7 @@ from gt4py.next import config from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.compilation import cache -from gt4py.next.program_processors.runners.dace.workflow import ( - bindings as bindings_step, - decoration as decoration_step, -) +from gt4py.next.program_processors.runners.dace.workflow import bindings as bindings_step from gt4py.next.program_processors.runners.dace.workflow.compilation import ( DaCeCompilationStepFactory, ) @@ -78,9 +75,3 @@ class Params: device_type=factory.SelfAttribute("..device_type"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - decoration_step.convert_args, - device=o.device_type, - ) - ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c4b4d3d698..ffcd181b3d 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import functools +import dataclasses +import pathlib from typing import Any import factory @@ -103,6 +104,30 @@ def extract_connectivity_args( return args +@dataclasses.dataclass(frozen=True) +class GTFNCompilationArtifact(compiler.CPPCompilationArtifact): + def load(self) -> stages.ExecutableProgram: + return convert_args(super().load(), device=self.device_type) + + +@dataclasses.dataclass(frozen=True) +class GTFNCompiler(compiler.CPPCompiler): + def _make_artifact( + self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str + ) -> GTFNCompilationArtifact: + return GTFNCompilationArtifact( + src_dir=src_dir, + module=module, + entry_point_name=entry_point_name, + device_type=self.device_type, + ) + + +class GTFNCompilerFactory(factory.Factory): + class Meta: + model = GTFNCompiler + + class GTFNCompileWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFCompileWorkflow @@ -138,12 +163,10 @@ class Params: nanobind.bind_source ) compilation = factory.SubFactory( - compiler.CompilerFactory, + GTFNCompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), - ) - decoration = factory.LazyAttribute( - lambda o: functools.partial(convert_args, device=o.device_type) + device_type=factory.SelfAttribute("..device_type"), ) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 2d8f45f631..3b7161c3bd 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -11,11 +11,10 @@ import dataclasses import functools import importlib.util -import pathlib import tempfile import textwrap -import typing -from collections.abc import Callable, Iterable +import types +from collections.abc import Iterable from typing import Any, Optional from gt4py.eve import codegen @@ -110,28 +109,20 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: return f"{node.id} = {_create_tmp(axes, origin, shape, node.dtype)}" -_FENCIL_CACHE: dict[int, Callable] = {} +# Caches the generated source by IR hash so re-codegen is skipped within a process. +_SOURCE_CACHE: dict[int, tuple[str, str]] = {} +# Caches the loaded module by source string so re-exec is skipped within a process. +_MODULE_CACHE: dict[str, types.ModuleType] = {} -def fencil_generator( +def _generate_source( ir: itir.Program, debug: bool, use_embedded: bool, offset_provider: common.OffsetProvider, transforms: itir_transforms.GTIRTransform, -) -> stages.ExecutableProgram: - """ - Generate a directly executable fencil from an ITIR node. - - Arguments: - ir: The iterator IR (ITIR) node. - debug: Keep module source containing fencil implementation. - extract_temporaries: Extract intermediate field values into temporaries. - use_embedded: Directly use builtins from embedded backend instead of - generic dispatcher. Gives faster performance and is easier - to debug. - offset_provider: A mapping from offset names to offset providers. - """ +) -> tuple[str, str]: + """Generate the Python source for an ITIR program. Returns ``(source_code, entry_point_name)``.""" # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism cache_key = hash( @@ -143,10 +134,10 @@ def fencil_generator( tuple(common.offset_provider_to_type(offset_provider).items()), ) ) - if cache_key in _FENCIL_CACHE: + if cache_key in _SOURCE_CACHE: if debug: - print(f"Using cached fencil for key {cache_key}") - return _FENCIL_CACHE[cache_key] # A CompiledProgram is just a Callable + print(f"Using cached source for key {cache_key}") + return _SOURCE_CACHE[cache_key] ir = transforms(ir, offset_provider=offset_provider) @@ -182,80 +173,113 @@ def fencil_generator( """ ) - with tempfile.NamedTemporaryFile( - mode="w", suffix=".py", encoding="utf-8", delete=False - ) as source_file: - source_file_name = source_file.name - if debug: - print(source_file_name) - offset_literals = [f'{o} = offset("{o}")' for o in offset_literals] - axis_literals = [ - f'{o.value} = gtx.Dimension("{o.value}", kind=gtx.DimensionKind("{o.kind}"))' - for o in axis_literals_set - ] - source_file.write(header) - source_file.write("\n".join(offset_literals)) - source_file.write("\n") - source_file.write("\n".join(axis_literals)) - source_file.write("\n") - source_file.write(program) - try: - spec = importlib.util.spec_from_file_location("module.name", source_file_name) - mod = importlib.util.module_from_spec(spec) # type: ignore - spec.loader.exec_module(mod) # type: ignore - finally: - if not debug: - pathlib.Path(source_file_name).unlink(missing_ok=True) + offset_literals_src = "\n".join(f'{o} = offset("{o}")' for o in offset_literals) + axis_literals_src = "\n".join( + f'{o.value} = gtx.Dimension("{o.value}", kind=gtx.DimensionKind("{o.kind}"))' + for o in axis_literals_set + ) + source_code = f"{header}{offset_literals_src}\n{axis_literals_src}\n{program}" assert isinstance(ir, itir.Program) - fencil_name = ir.id - fencil = getattr(mod, fencil_name) + entry_point_name = ir.id + + _SOURCE_CACHE[cache_key] = (source_code, entry_point_name) + return source_code, entry_point_name - _FENCIL_CACHE[cache_key] = fencil - return typing.cast(stages.ExecutableProgram, fencil) +def _load_module(source_code: str, debug: bool) -> types.ModuleType: + if source_code in _MODULE_CACHE: + return _MODULE_CACHE[source_code] + + if debug: + # Write to a real .py so debuggers/tracebacks have file/line info. + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", encoding="utf-8", delete=False + ) as source_file: + source_file.write(source_code) + source_file_name = source_file.name + print(source_file_name) + spec = importlib.util.spec_from_file_location("module.name", source_file_name) + mod = importlib.util.module_from_spec(spec) # type: ignore[arg-type] + spec.loader.exec_module(mod) # type: ignore[union-attr] + else: + mod = types.ModuleType("roundtrip_module") + exec(compile(source_code, "", "exec"), mod.__dict__) + + _MODULE_CACHE[source_code] = mod + return mod @dataclasses.dataclass(frozen=True) -class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram]): - debug: Optional[bool] = None - use_embedded: bool = True - dispatch_backend: Optional[next_backend.Backend] = None - transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` +class RoundtripArtifact: + """Source-string artifact for the roundtrip backend. - def __call__(self, inp: definitions.CompilableProgramDef) -> stages.ExecutableProgram: - debug = config.DEBUG if self.debug is None else self.debug + The generated Python source is the artifact: picklable, re-execed on + ``load``. When ``debug`` is true, ``load`` writes a temporary ``.py`` + so debuggers/tracebacks resolve to source lines. + """ - fencil = fencil_generator( - inp.data, - offset_provider=inp.args.offset_provider, - debug=debug, - use_embedded=self.use_embedded, - transforms=self.transforms, - ) + source_code: str + entry_point_name: str + column_axis: common.Dimension | None + dispatch_backend: next_backend.Backend | None + debug: bool + + def load(self) -> stages.ExecutableProgram: + mod = _load_module(self.source_code, self.debug) + fencil = getattr(mod, self.entry_point_name) + captured_column_axis = self.column_axis + dispatch_backend = self.dispatch_backend def decorated_fencil( *args: Any, - offset_provider: dict[str, common.Connectivity], + offset_provider: dict[str, common.Connectivity | common.Dimension], out: Any = None, - column_axis: Optional[common.Dimension] = None, + column_axis: Optional[ + common.Dimension + ] = None, # TODO(tehrengruber): unused, kept for signature compat **kwargs: Any, ) -> None: if out is not None: args = (*args, out) - if not column_axis: # TODO(tehrengruber): This variable is never used. Bug? - column_axis = inp.args.column_axis fencil( *args, offset_provider=offset_provider, - backend=self.dispatch_backend, - column_axis=inp.args.column_axis, + backend=dispatch_backend, + column_axis=captured_column_axis, **kwargs, ) return decorated_fencil +@dataclasses.dataclass(frozen=True) +class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, RoundtripArtifact]): + debug: Optional[bool] = None + use_embedded: bool = True + dispatch_backend: Optional[next_backend.Backend] = None + transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` + + def __call__(self, inp: definitions.CompilableProgramDef) -> RoundtripArtifact: + debug = config.DEBUG if self.debug is None else self.debug + + source_code, entry_point_name = _generate_source( + inp.data, + offset_provider=inp.args.offset_provider, + debug=debug, + use_embedded=self.use_embedded, + transforms=self.transforms, + ) + + return RoundtripArtifact( + source_code=source_code, + entry_point_name=entry_point_name, + column_axis=inp.args.column_axis, + dispatch_backend=self.dispatch_backend, + debug=debug, + ) + + # TODO(tehrengruber): introduce factory default = next_backend.Backend( name="roundtrip", diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 49bd7b8f87..84226c4e03 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -10,6 +10,7 @@ import numpy as np +from gt4py._core import definitions as core_defs from gt4py.next import config from gt4py.next.otf import workflow from gt4py.next.otf.binding import nanobind @@ -24,11 +25,13 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_cmake") build_the_program = workflow.make_step(nanobind.bind_source).chain( - compiler.Compiler( - cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=cmake.CMakeFactory() + compiler.CPPCompiler( + cache_lifetime=config.BuildCacheLifetime.SESSION, + builder_factory=cmake.CMakeFactory(), + device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source) + compiled_program = build_the_program(example_program_source).load() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), @@ -42,12 +45,13 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): def test_gtfn_cpp_with_compiledb(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_compiledb") build_the_program = workflow.make_step(nanobind.bind_source).chain( - compiler.Compiler( + compiler.CPPCompiler( cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=compiledb.CompiledbFactory(), + device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source) + compiled_program = build_the_program(example_program_source).load() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py new file mode 100644 index 0000000000..45abaf86c3 --- /dev/null +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py @@ -0,0 +1,26 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Minimal contract tests for ``CPPCompilationArtifact``.""" + +import pathlib +import pickle + +from gt4py._core import definitions as core_defs +from gt4py.next.otf.compilation import compiler + + +def test_cpp_compilation_artifact_pickle_round_trip(tmp_path: pathlib.Path): + artifact = compiler.CPPCompilationArtifact( + src_dir=tmp_path, + module=pathlib.Path("entry.so"), + entry_point_name="entry", + device_type=core_defs.DeviceType.CPU, + ) + restored = pickle.loads(pickle.dumps(artifact)) + assert restored == artifact diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index def8800c98..ed881c9495 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -114,12 +114,19 @@ def test_inlining_of_scalar_works_integration(testee_prog): hijacked_program = None + @dataclasses.dataclass(frozen=True) + class _NoOpArtifact: + """A trivial CompilationArtifact that loads to a no-op callable.""" + + def load(self): + return lambda *args, **kwargs: None + def pirate(program: toolchain.ConcreteArtifact): - # Replaces the gtfn otf_workflow: and steals the compilable program, - # then returns a dummy "CompiledProgram" that does nothing. + # Replaces the gtfn otf_workflow: steals the compilable program, then + # returns a dummy artifact whose materialization is a no-op callable. nonlocal hijacked_program hijacked_program = program - return lambda *args, **kwargs: None + return _NoOpArtifact() hacked_gtfn_backend = gtfn.GTFNBackendFactory(name_postfix="_custom", executor=pirate) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py index 1cbf9d3c2e..488c371193 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -6,9 +6,15 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -"""Test the compilation stage of the dace backend workflow.""" +"""Tests for the compilation stage of the dace backend workflow. + +Covers the GPU TX-marker instrumentation and the picklability of +``DaCeCompilationArtifact``. +""" import contextlib +import pathlib +import pickle import unittest.mock as mock import pytest @@ -91,7 +97,6 @@ def _run_compiler( wraps=dace_wf_compilation._add_tx_markers, ) as spy, mock.patch.object(dace.SDFG, "compile", autospec=True) as compile_mock, - mock.patch.object(dace_wf_compilation, "CompiledDaceProgram"), mock.patch.object( dace_wf_compilation.gtx_wfdcommon, "dace_context", @@ -147,3 +152,18 @@ def test_compiler_skips_tx_markers_for_non_gpu_device(tmp_path): spy.assert_not_called() assert compiled_sdfg.instrument == _NONE + + +def test_dace_compilation_artifact_pickle_round_trip(tmp_path: pathlib.Path): + artifact = dace_wf_compilation.DaCeCompilationArtifact( + build_folder=tmp_path, + library_path=tmp_path / "build" / "libprogram.so", + sdfg_json="{}", + binding_source_code="def update_sdfg_args(*a, **k): ...", + bind_func_name="update_sdfg_args", + device_type=core_defs.DeviceType.CPU, + ) + + restored = pickle.loads(pickle.dumps(artifact)) + + assert restored == artifact diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index a2b2502b19..00c47102c4 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -38,8 +38,9 @@ def test_backend_factory_trait_device(): assert isinstance(gpu_version.executor.translation, workflow.CachedStep) assert gpu_version.executor.translation.step.device_type is core_defs.DeviceType.CUDA - assert cpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CPU - assert gpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CUDA + # The compilation step now also carries device_type so it can stamp the artifact. + assert cpu_version.executor.compilation.device_type is core_defs.DeviceType.CPU + assert gpu_version.executor.compilation.device_type is core_defs.DeviceType.CUDA assert custom_layout_allocators.is_field_allocator_for( cpu_version.allocator, core_defs.DeviceType.CPU