diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 0cb9802423..73c5b54685 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -429,23 +429,15 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._lower_and_map("if_", *node.args) cond_ = self.visit(node.args[0]) + true_ = self.visit(node.args[1]) + false_ = self.visit(node.args[2]) cond_symref_name = f"__cond_{itir.lenient_ir_fingerprinter(cond_)}" - def create_if( - true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec] - ) -> itir.FunCall: - return _map( - "if_", - (im.ref(cond_symref_name), true_, false_), - (node.args[0].type, *arg_types), + result = im.tree_map_tuple( + im.lambda_("__a", "__b")( + im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b")) ) - - result = lowering_utils.process_elements( - create_if, - (self.visit(node.args[1]), self.visit(node.args[2])), - node.type, - arg_types=(node.args[1].type, node.args[2].type), - ) + )(true_, false_) return im.let(cond_symref_name, cond_)(result) @@ -551,7 +543,7 @@ def _map( original_arg_types: tuple[ts.TypeSpec, ...], ) -> itir.FunCall: """ - Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. + Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_list`ing lists. """ if all( isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType)) @@ -564,7 +556,7 @@ def _map( promote_to_list(arg_type)(larg) for arg_type, larg in zip(original_arg_types, lowered_args) ) - op = im.map_(op) + op = im.map_list(op) return im.op_as_fieldop(op)(*lowered_args) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index e54c6ea3d7..d222f100df 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -53,7 +53,17 @@ def neighbors(*args): @builtin_dispatch -def map_(*args): +def map_list(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def tree_map_tuple(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def map_tuple(*args): raise BackendNotSelectedError() @@ -498,7 +508,9 @@ def get_domain_range(*args): "lift", "make_const_list", "make_tuple", - "map_", + "tree_map_tuple", + "map_tuple", + "map_list", "named_range", "neighbors", "reduce", diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index e01ecc181e..9086e4a7d4 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1471,8 +1471,8 @@ def _get_offset(*lists: _List | _ConstList) -> Optional[runtime.Offset]: raise AssertionError("All lists must have the same offset.") -@builtins.map_.register(EMBEDDED) -def map_(op): +@builtins.map_list.register(EMBEDDED) +def map_list(op): def impl_(*lists): offset = _get_offset(*lists) if offset is None: diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index da13d20bb6..c0090ed3a2 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -71,7 +71,7 @@ def is_applied_map(arg: itir.Node) -> TypeGuard[_FunCallToFunCallToRef]: isinstance(arg, itir.FunCall) and isinstance(arg.fun, itir.FunCall) and isinstance(arg.fun.fun, itir.SymRef) - and arg.fun.fun.id == "map_" + and arg.fun.fun.id == "map_list" ) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 561db992d8..5266e01ffc 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -632,9 +632,19 @@ def index(dim: common.Dimension) -> itir.FunCall: return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind)) -def map_(op): - """Create a `map_` call.""" - return call(call("map_")(op)) +def map_list(op): + """Create a `map_list` call.""" + return call(call("map_list")(op)) + + +def tree_map_tuple(op): + """Create a `tree_map_tuple` call: tree_map_tuple(op)(tup1, tup2, ...).""" + return call(call("tree_map_tuple")(op)) + + +def map_tuple(op): + """Create a `map_tuple` call: map_tuple(op)(tup).""" + return call(call("map_tuple")(op)) def reduce(op, expr): diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 4c4219bda4..c951dcfdec 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -44,7 +44,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: if cpm.is_call_to(node.args[1], "make_const_list"): return node.args[1].args[0] if cpm.is_applied_map(node.args[1]): - # list_get(0, map_(λ(val_) → foo(val_, int64))(·__sym_1)) + # list_get(0, map_list(λ(val_) → foo(val_, int64))(·__sym_1)) # -> (λ(val_) → foo(val_, int64))(list_get(0, ·__sym_1)) lsts = node.args[1].args assert len(node.args[1].fun.args) == 1 # a single lambda in the map diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index a365cb25e3..ef58e527f2 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -95,7 +95,7 @@ def _is_collectable_expr(node: itir.Node) -> bool: if isinstance(node, itir.FunCall): # do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be # visited, to ensure symbol dependencies are recognized correctly. - # do also not collect reduce, map_ and neighbors nodes if they are left in the IR at this point, this may lead to + # do also not collect reduce, map_list and neighbors nodes if they are left in the IR at this point, this may lead to # conceptual problems (other parts of the tool chain rely on the arguments being present directly # on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend # backend (single pass eager depth first visit approach), see also https://github.com/GridTools/gt4py/issues/1795 @@ -104,7 +104,7 @@ def _is_collectable_expr(node: itir.Node) -> bool: # do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement # instead of an as_fieldop if cpm.is_call_to( - node, ("lift", "shift", "neighbors", "reduce", "map_", "index") + node, ("lift", "shift", "neighbors", "reduce", "map_list", "index") ) or cpm.is_applied_lift(node): return False return True diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index 4efbbe718b..69638861ff 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -18,7 +18,7 @@ @dataclasses.dataclass(frozen=True) class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ - Fuses nested `map_`s. + Fuses nested `map_list`s. Preconditions: - `FunctionDefinitions` are inlined @@ -29,7 +29,7 @@ class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrai to map(λ(a, b, c) → f(a, g(b, c)))(a, b, c) - reduce(λ(x, y) → f(x, y), init)(map_(g(z, w))(a, b)) + reduce(λ(x, y) → f(x, y), init)(map_list(g(z, w))(a, b)) to reduce(λ(x, y, z) → f(x, g(y, z)), init)(a, b) """ @@ -93,7 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_op = ir.Lambda(params=new_params, expr=new_body) if cpm.is_applied_map(node): return ir.FunCall( - fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args + fun=ir.FunCall(fun=ir.SymRef(id="map_list"), args=[new_op]), args=new_args ) else: # is_applied_reduce(node) return ir.FunCall( diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 8bea65ed29..68beba6ec3 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -24,6 +24,7 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, + unroll_tuple_maps, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -169,6 +170,11 @@ def apply_common_transforms( ir = inline_lifts.InlineLifts().visit(ir) ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + # `UnrollTupleMaps` requires fully-inferred tuple types (relies on `reinfer` to see + # nested `TupleType` chains), so the offset_provider is passed for on-demand inference. + ir = unroll_tuple_maps.UnrollTupleMaps.apply( + ir, uids=uids, offset_provider_type=offset_provider_type + ) ir = dead_code_elimination.dead_code_elimination( ir, uids=uids, offset_provider_type=offset_provider_type ) # domain inference does not support dead-code @@ -282,6 +288,12 @@ def apply_fieldview_transforms( ir = inline_fundefs.prune_unreferenced_fundefs(ir) # required for dead-code-elimination and `prune_empty_concat_where` pass ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + # `UnrollTupleMaps` requires fully-inferred tuple types, so the offset_provider is passed for + # on-demand inference. + ir = unroll_tuple_maps.UnrollTupleMaps.apply( + ir, uids=uids, offset_provider_type=offset_provider_type + ) + ir = dead_code_elimination.dead_code_elimination( ir, offset_provider_type=offset_provider_type, uids=uids ) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index a544d511db..b62f894796 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -264,7 +264,7 @@ def applied_as_fieldop(*args): "scan": _scan, "reduce": _reduce, "neighbors": _neighbors, - "map_": _map, + "map_list": _map, "if_": _if, "make_tuple": _make_tuple, } diff --git a/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py new file mode 100644 index 0000000000..461855faab --- /dev/null +++ b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py @@ -0,0 +1,140 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses +import functools +from typing import TypeVar + +from gt4py import eve +from gt4py.next import common, utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +def _collapsing_tuple_get(expr: itir.Expr, i: int) -> itir.Expr: + """Like `im.tuple_get`, but collapses immediately when `expr` is a `make_tuple` call. + + Note: argument order is `(expr, i)` to allow use as a `functools.reduce` reducer. + """ + if cpm.is_call_to(expr, "make_tuple"): + return expr.args[i] + return im.tuple_get(i, expr) + + +def _tree_map_tuple_body( + f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] +) -> itir.Expr: + """Recursively unroll `tree_map_tuple(f)(t1, ..., tN)` into `make_tuple` calls.""" + + def tuple_structure(type_: ts.TypeSpec) -> tuple[object, ...] | None: + if isinstance(type_, ts.TupleType): + return tuple(tuple_structure(el_type) for el_type in type_.types) + return None + + expected_structure = tuple_structure(tup_types[0]) + if any(tuple_structure(tup_type) != expected_structure for tup_type in tup_types[1:]): + raise TypeError("'tree_map_tuple' requires all arguments to have the same tuple structure.") + + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: im.make_tuple(*elts), + with_path_arg=True, + ) + def mapper(*args): + *_el_types, path = args + return im.call(f)( + *(functools.reduce(_collapsing_tuple_get, path, tup_expr) for tup_expr in tup_exprs) + ) + + return mapper(*tup_types) + + +def _map_tuple_body( + f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] +) -> itir.Expr: + """Unroll `map_tuple(f)(t)` over top-level elements only (no recursion).""" + (tup_expr,) = tup_exprs + (tup_type,) = tup_types + return im.make_tuple( + *(im.call(f)(_collapsing_tuple_get(tup_expr, i)) for i in range(len(tup_type.types))) + ) + + +_UNROLLERS = { + "tree_map_tuple": _tree_map_tuple_body, + "map_tuple": _map_tuple_body, +} + + +ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.Expr) + + +@dataclasses.dataclass +class UnrollTupleMaps(eve.NodeTranslator): + """Unroll tuple-map ITIR builtins (`tree_map_tuple`, `map_tuple`) into `make_tuple`.""" + + PRESERVED_ANNEX_ATTRS = ("domain",) + + uids: utils.IDGeneratorPool + + @classmethod + def apply( + cls, + node: ProgramOrExpr, + *, + uids: utils.IDGeneratorPool | None = None, + offset_provider_type: common.OffsetProviderType | None = None, + ) -> ProgramOrExpr: + if node.type is None: + node = itir_inference.infer( + node, + offset_provider_type=offset_provider_type or {}, + allow_undeclared_symbols=not isinstance(node, itir.Program), + ) + if uids is None: + uids = utils.IDGeneratorPool() + return cls(uids=uids).visit(node) + + def visit_FunCall(self, node: itir.FunCall): + node = self.generic_visit(node) + + builtin_name = next((name for name in _UNROLLERS if cpm.is_call_to(node.fun, name)), None) + if builtin_name is None: + return node + + assert isinstance(node.fun, itir.FunCall) + f = node.fun.args[0] + tup_args = node.args + + tup_types: list[ts.TupleType] = [] + for tup in tup_args: + itir_inference.reinfer(tup) + assert isinstance(tup.type, ts.TupleType) + tup_types.append(tup.type) + + # For trivial args (those that can be duplicated without cost or side effects), + # we substitute them directly into the body. This avoids leaving behind + # `tuple_get(i, make_tuple(...))` patterns that would otherwise require a + # separate cleanup pass (CollapseTuple). For non-trivial args we still + # introduce a `let` binding to avoid duplicating expensive sub-expressions. + substituted_exprs: list[itir.Expr] = [] + let_bindings: list[tuple[str, itir.Expr]] = [] + for tup in tup_args: + if isinstance(tup, (itir.SymRef, itir.Literal)) or cpm.is_call_to(tup, "make_tuple"): + substituted_exprs.append(tup) + else: + ref_name = next(self.uids["_utm"]) + let_bindings.append((ref_name, tup)) + substituted_exprs.append(im.ref(ref_name)) + + body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types) + + result = im.let(*let_bindings)(body) if let_bindings else body + itir_inference.reinfer(result) + return result diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index e204377bc0..4969270c11 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -20,7 +20,6 @@ from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts -from gt4py.next.utils import tree_map def _type_synth_arg_cache_key(type_or_synth: TypeOrTypeSynthesizer) -> int: @@ -203,7 +202,7 @@ def if_( pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType ) -> ts.DataType: if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType): - return tree_map( + return utils.tree_map( collection_type=ts.TupleType, result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]), )(functools.partial(if_, pred))(true_branch, false_branch) @@ -611,7 +610,7 @@ def apply_scan( @_register_builtin_type_synthesizer -def map_(op: TypeSynthesizer) -> TypeSynthesizer: +def map_list(op: TypeSynthesizer) -> TypeSynthesizer: @type_synthesizer def applied_map( *args: ts.ListType, offset_provider_type: common.OffsetProviderType @@ -629,6 +628,67 @@ def applied_map( return applied_map +def _tuple_map_synthesizer( + builtin_name: str, *, recursive: bool +) -> Callable[..., TypeOrTypeSynthesizer]: + """Shared implementation for `tree_map_tuple` (recursive) and `map_tuple` (top-level).""" + + def tuple_structure(type_: ts.TypeSpec) -> tuple[object, ...] | None: + if isinstance(type_, ts.TupleType): + return tuple(tuple_structure(el_type) for el_type in type_.types) + return None + + def ensure_same_tuple_structure(args: tuple[ts.TupleType, ...]) -> None: + expected_structure = tuple_structure(args[0]) + if any(tuple_structure(arg) != expected_structure for arg in args[1:]): + raise TypeError( + f"'{builtin_name}' requires all arguments to have the same tuple structure." + ) + + def factory(op: TypeSynthesizer) -> TypeSynthesizer: + @type_synthesizer + def applied_map( + *args: ts.TupleType, offset_provider_type: common.OffsetProviderType + ) -> ts.TupleType: + if not args: + raise TypeError(f"'{builtin_name}' requires at least one argument.") + if not recursive and len(args) != 1: + raise TypeError(f"'{builtin_name}' requires exactly one argument, got {len(args)}.") + if not all(isinstance(a, ts.TupleType) for a in args): + raise TypeError( + f"'{builtin_name}' requires all top-level arguments to be TupleType, " + f"got {[type(a).__name__ for a in args]}." + ) + if recursive: + ensure_same_tuple_structure(args) + + def leaf_op(*leaf_types: ts.TypeSpec) -> ts.TypeSpec: + return op(*leaf_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] + + if recursive: + return utils.tree_map( # type: ignore[return-value] + leaf_op, + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]), + )(*args) + + # Non-recursive: apply `op` once per top-level element. + (arg,) = args + return ts.TupleType(types=[leaf_op(el) for el in arg.types]) + + return applied_map + + return factory + + +tree_map_tuple = _register_builtin_type_synthesizer( + _tuple_map_synthesizer("tree_map_tuple", recursive=True), fun_names=["tree_map_tuple"] +) +map_tuple = _register_builtin_type_synthesizer( + _tuple_map_synthesizer("map_tuple", recursive=False), fun_names=["map_tuple"] +) + + @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @type_synthesizer diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 2e8983ccb8..fb6ae808dd 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -1245,7 +1245,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: The map operation is applied on the local dimension of input fields. In the example below, the local dimension consists of a list of neighbor values as the first argument, and a list of constant values `1.0`: - `map_(plus)(neighbors(V2E, it), make_const_list(1.0))` + `map_list(plus)(neighbors(V2E, it), make_const_list(1.0))` The `plus` operation is lowered to a tasklet inside a map that computes the domain of the local dimension (in this example, max neighbors in V2E). @@ -1303,7 +1303,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: # The dataflow we build in this class has some loose connections on input edges. # These edges are described as set of nodes, that will have to be connected to # external data source nodes passing through the map entry node of the field map. - # Similarly to `neighbors` expressions, the `map_` input edges terminate on view + # Similarly to `neighbors` expressions, the `map_list` input edges terminate on view # nodes (see `_construct_local_view` in the for-loop below), because it is simpler # than representing map-to-map edges (which require memlets with 2 pass-nodes). input_memlets = {} @@ -1330,7 +1330,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: result_node = self.state.add_access(result) if conn_type.has_skip_values: - # In case the `map_` input expressions contain skip values, we use + # In case the `map_list` input expressions contain skip values, we use # the connectivity-based offset provider as mask for map computation. conn_data = gtx_dace_args.connectivity_identifier(offset_type.value) conn_desc = self.sdfg.arrays[conn_data] @@ -1750,12 +1750,12 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: if isinstance(node.type, ts.ListType): # The only builtin function (so far) handled here that returns a list - # is 'make_const_list'. There are other builtin functions (map_, neighbors) + # is 'make_const_list'. There are other builtin functions (map_list, neighbors) # that return a list but they are handled in specialized visit methods. # This method (the generic visitor for builtin functions) always returns # a single value. This is also the case of 'make_const_list' expression: # it simply broadcasts a scalar on the local domain of another expression, - # for example 'map_(plus)(neighbors(V2Eₒ, it), make_const_list(1.0))'. + # for example 'map_list(plus)(neighbors(V2Eₒ, it), make_const_list(1.0))'. # Therefore we handle `ListType` as a single-element array with shape (1,) # that will be accessed in a map expression on a local domain. assert isinstance(node.type.element_type, ts.ScalarType) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_reductions.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_reductions.py index 93bdec4df2..311b1870ed 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_reductions.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_reductions.py @@ -463,7 +463,9 @@ def testee(a: cases.VField) -> cases.VField: @pytest.mark.uses_unstructured_shift -@pytest.mark.xfail(reason="Not yet supported in lowering, requires `map_`ing of inner reduce op.") +@pytest.mark.xfail( + reason="Not yet supported in lowering, requires `map_list`ing of inner reduce op." +) def test_nested_reduction_shift_first(unstructured_case): @gtx.field_operator def testee(inp: cases.EField) -> cases.EField: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 39de7bd660..53ebfb9801 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -16,7 +16,7 @@ lift, list_get, make_const_list, - map_, + map_list, multiplies, neighbors, plus, @@ -101,7 +101,7 @@ def test_sum_edges_to_vertices(program_processor, stencil): @fundef def map_neighbors(in_edges): - return reduce(plus, 0)(map_(plus)(neighbors(V2E, in_edges), neighbors(V2E, in_edges))) + return reduce(plus, 0)(map_list(plus)(neighbors(V2E, in_edges), neighbors(V2E, in_edges))) def test_map_neighbors(program_processor): @@ -123,7 +123,7 @@ def test_map_neighbors(program_processor): @fundef def map_make_const_list(in_edges): - return reduce(plus, 0)(map_(multiplies)(neighbors(V2E, in_edges), make_const_list(2))) + return reduce(plus, 0)(map_list(multiplies)(neighbors(V2E, in_edges), make_const_list(2))) @pytest.mark.uses_constant_fields diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 584d29e1b4..792cc43d7d 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -210,10 +210,9 @@ def foo( lowered ) # we generate a let for the condition which is removed by inlining for easier testing - reference = im.make_tuple( - im.op_as_fieldop("if_")("a", im.tuple_get(0, "b"), im.tuple_get(0, "c")), - im.op_as_fieldop("if_")("a", im.tuple_get(1, "b"), im.tuple_get(1, "c")), - ) + reference = im.tree_map_tuple( + im.lambda_("__a", "__b")(im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b"))) + )("b", "c") assert lowered_inlined.expr == reference @@ -303,7 +302,7 @@ def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.cast_("val", "int32"))))("a") + reference = im.op_as_fieldop(im.map_list(im.lambda_("val")(im.cast_("val", "int32"))))("a") assert lowered.expr == reference @@ -836,9 +835,9 @@ def foo(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64] parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - mapped = im.op_as_fieldop(im.map_("multiplies"))( + mapped = im.op_as_fieldop(im.map_list("multiplies"))( im.op_as_fieldop("make_const_list")(im.literal("1.1", "float64")), - im.op_as_fieldop(im.map_("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), + im.op_as_fieldop(im.map_list("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), ) reference = im.let( diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py index 01a259fcec..3a6562e6be 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py @@ -17,7 +17,7 @@ deref, if_, make_const_list, - map_, + map_list, neighbors, plus, ) @@ -70,7 +70,7 @@ def testee(): def test_write_map_neighbors_and_const_list(): def testee(inp): domain = runtime.UnstructuredDomain({E: range(2)}) - return as_fieldop(lambda x, y: map_(plus)(deref(x), deref(y)), domain)( + return as_fieldop(lambda x, y: map_list(plus)(deref(x), deref(y)), domain)( as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), as_fieldop(lambda: make_const_list(42.0), domain)(), ) @@ -86,7 +86,7 @@ def testee(inp): def test_write_map_conditional_neighbors_and_const_list(): def testee(inp, mask): domain = runtime.UnstructuredDomain({E: range(2)}) - return as_fieldop(lambda m, x, y: map_(if_)(deref(m), deref(x), deref(y)), domain)( + return as_fieldop(lambda m, x, y: map_list(if_)(deref(m), deref(x), deref(y)), domain)( as_fieldop(lambda it: make_const_list(deref(it)), domain)(mask), as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), as_fieldop(lambda it: make_const_list(deref(it)), domain)(42.0), @@ -106,7 +106,7 @@ def testee(inp, mask): def test_write_non_mapped_conditional_neighbors_and_const_list(): """ This test-case demonstrates a non-supported pattern: - Current ITIR requires the `if_` to be `map_`ed, see `test_write_map_conditional_neighbors_and_const_list`. + Current ITIR requires the `if_` to be `map_list`ed, see `test_write_map_conditional_neighbors_and_const_list`. We keep it here for documenting corner cases of the `itir.List` implementation for future discussions. """ @@ -134,7 +134,7 @@ def testee(inp, mask): def test_write_map_const_list_and_const_list(): def testee(): domain = runtime.UnstructuredDomain({E: range(2)}) - return as_fieldop(lambda x, y: map_(plus)(deref(x), deref(y)), domain)( + return as_fieldop(lambda x, y: map_list(plus)(deref(x), deref(y)), domain)( as_fieldop(lambda: make_const_list(1.0), domain)(), as_fieldop(lambda: make_const_list(42.0), domain)(), ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index a7ccebc098..00b6f8bd86 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -138,24 +138,41 @@ def expression_test_cases(): # TODO: scan # map ( - im.map_(im.ref("plus"))(im.ref("a", int_list_type), im.ref("b", int_list_type)), + im.map_list(im.ref("plus"))(im.ref("a", int_list_type), im.ref("b", int_list_type)), int_list_type, ), ( - im.map_(im.ref("plus"))(im.call("make_const_list")(1), im.ref("b", int_list_type)), + im.map_list(im.ref("plus"))(im.call("make_const_list")(1), im.ref("b", int_list_type)), int_list_type, ), ( - im.map_(im.ref("plus"))(im.ref("a", int_list_type), im.call("make_const_list")(1)), + im.map_list(im.ref("plus"))(im.ref("a", int_list_type), im.call("make_const_list")(1)), int_list_type, ), ( - im.map_(im.ref("plus"))( + im.map_list(im.ref("plus"))( im.ref("a", int_list_type), im.ref("b", ts.ListType(element_type=int_type, offset_type=V2EDim)), ), ts.ListType(element_type=int_type, offset_type=V2EDim), ), + # tree_map_tuple + ( + im.tree_map_tuple(im.ref("plus"))( + im.ref("t1", ts.TupleType(types=[int_type, int_type])), + im.ref("t2", ts.TupleType(types=[int_type, int_type])), + ), + ts.TupleType(types=[int_type, int_type]), + ), + ( + im.tree_map_tuple(im.ref("not_"))( + im.ref( + "t", + ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]), + ), + ), + ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]), + ), # reduce (im.reduce("plus", 0)(im.ref("l", int_list_type)), int_type), ( @@ -309,6 +326,24 @@ def test_expression_type(test_case): assert result.type == expected_type +@pytest.mark.parametrize( + "testee", + [ + im.tree_map_tuple(im.ref("plus"))( + im.ref("t1", ts.TupleType(types=[int_type, int_type, int_type])), + im.ref("t2", ts.TupleType(types=[int_type, int_type])), + ), + im.tree_map_tuple(im.ref("plus"))( + im.ref("t1", ts.TupleType(types=[int_type, ts.TupleType(types=[int_type, int_type])])), + im.ref("t2", ts.TupleType(types=[int_type, int_type])), + ), + ], +) +def test_tree_map_tuple_mismatched_structure_raises_type_error(testee): + with pytest.raises(TypeError, match=r"same tuple structure"): + itir_type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) + + @pytest.mark.parametrize( "test_case", [(expr, type_) for expr, type_ in expression_test_cases() if cpm.is_applied_as_fieldop(expr)], diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py index c64ab93b6a..c54de17959 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py @@ -13,11 +13,11 @@ def _map(op: ir.Expr, *args: ir.Expr) -> ir.FunCall: - return ir.FunCall(fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[op]), args=[*args]) + return ir.FunCall(fun=ir.FunCall(fun=ir.SymRef(id="map_list"), args=[op]), args=[*args]) def _map_p(op: ir.Expr | P, *args: ir.Expr | P) -> P: - return P(ir.FunCall, fun=P(ir.FunCall, fun=ir.SymRef(id="map_"), args=[op]), args=[*args]) + return P(ir.FunCall, fun=P(ir.FunCall, fun=ir.SymRef(id="map_list"), args=[op]), args=[*args]) def _reduce(op: ir.Expr, init: ir.Expr, *args: ir.Expr) -> ir.FunCall: diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py new file mode 100644 index 0000000000..d81acb4e79 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py @@ -0,0 +1,173 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from gt4py.next import common, utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.unroll_tuple_maps import UnrollTupleMaps +from gt4py.next.type_system import type_specifications as ts + + +IDim = common.Dimension("IDim") +T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +i_field = ts.FieldType(dims=[IDim], dtype=T) +i_tuple_field = ts.TupleType(types=[i_field, i_field]) + + +def _apply(expr: itir.Expr) -> itir.Expr: + return UnrollTupleMaps.apply(expr, uids=utils.IDGeneratorPool(), offset_provider_type={}) + + +def _neg(): + return im.lambda_("__a")(im.op_as_fieldop("neg")("__a")) + + +def _plus(): + return im.lambda_("__a", "__b")(im.op_as_fieldop("plus")("__a", "__b")) + + +def test_tree_map_tuple_multi_arg(): + result = _apply( + im.call(im.call("tree_map_tuple")(_plus()))( + im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) + ) + ) + + expected = im.make_tuple( + im.call(_plus())(im.tuple_get(0, "a"), im.tuple_get(0, "b")), + im.call(_plus())(im.tuple_get(1, "a"), im.tuple_get(1, "b")), + ) + assert result == expected + + +def test_tree_map_tuple_nested(): + nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) + result = _apply(im.call(im.call("tree_map_tuple")(_neg()))(im.ref("t", nested))) + + expected = im.make_tuple( + im.make_tuple( + im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "t"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "t"))), + ), + im.make_tuple( + im.call(_neg())(im.tuple_get(0, im.tuple_get(1, "t"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(1, "t"))), + ), + ) + assert result == expected + + +def test_map_tuple_single_arg(): + result = _apply(im.call(im.call("map_tuple")(_neg()))(im.ref("t", i_tuple_field))) + + expected = im.make_tuple( + im.call(_neg())(im.tuple_get(0, "t")), + im.call(_neg())(im.tuple_get(1, "t")), + ) + assert result == expected + + +def test_apply_infers_uninferred_expr(): + expr = im.call(im.call("tree_map_tuple")(_neg()))( + im.make_tuple(im.ref("a", i_field), im.ref("b", i_field)) + ) + + result = UnrollTupleMaps.apply(expr, offset_provider_type={}) + + assert expr.type is None + assert result == im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")) + + +def test_map_tuple_does_not_recurse(): + nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) + g = im.lambda_("__p")(im.op_as_fieldop("plus")(im.tuple_get(0, "__p"), im.tuple_get(1, "__p"))) + result = _apply(im.call(im.call("map_tuple")(g))(im.ref("t", nested))) + + expected = im.make_tuple( + im.call(g)(im.tuple_get(0, "t")), + im.call(g)(im.tuple_get(1, "t")), + ) + assert result == expected + + +def test_make_tuple_arg_is_collapsed(): + """When the input tuple is a `make_tuple` literal, projection should collapse + directly to the element (no residual `tuple_get(make_tuple(...))`).""" + result = _apply( + im.call(im.call("tree_map_tuple")(_neg()))( + im.make_tuple(im.ref("a", i_field), im.ref("b", i_field)) + ) + ) + + expected = im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")) + assert result == expected + + +def test_nested_make_tuple_arg_is_collapsed(): + """A nested `make_tuple` arg should be fully collapsed at every depth: each + `tuple_get(i, make_tuple(...))` along the recursion is folded directly.""" + result = _apply( + im.call(im.call("tree_map_tuple")(_neg()))( + im.make_tuple( + im.make_tuple(im.ref("a", i_field), im.ref("b", i_field)), + im.make_tuple(im.ref("c", i_field), im.ref("d", i_field)), + ) + ) + ) + + expected = im.make_tuple( + im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), + im.make_tuple(im.call(_neg())("c"), im.call(_neg())("d")), + ) + assert result == expected + + +def test_map_tuple_with_make_tuple_arg_is_collapsed(): + """The `make_tuple` short-circuit must also apply for the `map_tuple` builtin.""" + result = _apply( + im.call(im.call("map_tuple")(_neg()))( + im.make_tuple(im.ref("a", i_field), im.ref("b", i_field)) + ) + ) + + expected = im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")) + assert result == expected + + +def test_non_trivial_arg_is_let_bound(): + """Non-trivial (potentially expensive) tuple expressions must still be + let-bound to avoid duplicating work across leaf projections.""" + # `f(t)` is a non-trivial expression returning a tuple + f = im.lambda_("__t")(im.ref("__t", i_tuple_field)) + result = _apply( + im.call(im.call("tree_map_tuple")(_neg()))(im.call(f)(im.ref("t", i_tuple_field))) + ) + + expected = im.let("_utm_0", im.call(f)("t"))( + im.make_tuple( + im.call(_neg())(im.tuple_get(0, "_utm_0")), + im.call(_neg())(im.tuple_get(1, "_utm_0")), + ) + ) + assert result == expected + + +@pytest.mark.parametrize( + "lhs_type, rhs_type", + [ + (ts.TupleType(types=[i_field, i_field, i_field]), i_tuple_field), + (ts.TupleType(types=[i_field, i_tuple_field]), i_tuple_field), + ], +) +def test_tree_map_tuple_mismatched_structure_raises_type_error(lhs_type, rhs_type): + expr = im.call(im.call("tree_map_tuple")(_plus()))(im.ref("a", lhs_type), im.ref("b", rhs_type)) + + with pytest.raises(TypeError, match=r"same tuple structure"): + _apply(expr) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 8fd3a55ef3..b9688f79b8 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1230,7 +1230,7 @@ def test_gtir_neighbors_as_input(): im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) ), inner_domain, - )(im.op_as_fieldop(im.map_("divides"), inner_domain)("v2e_field", "x")) + )(im.op_as_fieldop(im.map_list("divides"), inner_domain)("v2e_field", "x")) ), domain=outer_domain, target=gtir.SymRef(id="vertices"), @@ -1439,8 +1439,8 @@ def test_gtir_reduce_dot_product(): im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) ) )( - im.op_as_fieldop(im.map_("plus"))( - im.op_as_fieldop(im.map_("multiplies"))( + im.op_as_fieldop(im.map_list("plus"))( + im.op_as_fieldop(im.map_list("multiplies"))( im.as_fieldop_neighbors("V2E", "edges"), "v2e_field", ),