Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
1b4707c
Tracer prototype
SF-N Feb 20, 2026
902f8a3
Merge branch 'main' into tracer_support
SF-N Apr 27, 2026
02f881f
Introduce GTIR tree_map builtin and transform to make_tuple, also sup…
SF-N Apr 27, 2026
0ec4692
Run pre-commit and fix some tests
SF-N Apr 27, 2026
ab84ecc
Run CollapseTuple after UnrollTreeMap
SF-N Apr 28, 2026
36d6956
Merge branch 'main' into tracer_support_tree_map
SF-N Apr 28, 2026
152300e
Address review comments
SF-N Apr 28, 2026
d459b0e
Address further review comments
SF-N Apr 28, 2026
97af81e
Apply review comments
SF-N Apr 29, 2026
8d75708
Merge branch 'main' into tracer_support_tree_map
SF-N Apr 29, 2026
067bc29
Merge branch 'main' into tracer_support_tree_map
SF-N May 4, 2026
32e5b2d
Rename map_ -> map_list
SF-N May 28, 2026
a7175d7
Run pre-commit
SF-N May 28, 2026
4f89818
Merge branch 'main' into tracer_support_tree_map
SF-N May 28, 2026
2779fd0
Refactor tree_map_tuple and add map_tuple with unrolling support
SF-N May 28, 2026
80f3273
Rename
SF-N May 28, 2026
454e15f
Minor fix
SF-N May 28, 2026
8a1febd
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 2, 2026
31b969a
Remove unnecessary CollapseTuple loop
SF-N Jun 2, 2026
c7fc102
Reposition UnrollTupleMaps and simplify CollapseTuple usage
SF-N Jun 2, 2026
7993b9c
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 2, 2026
b7f8ba9
Refactor tree_map unrolling
SF-N Jun 16, 2026
3d38868
Cleanup
SF-N Jun 16, 2026
7d5c86c
Revert "Cleanup"
SF-N Jun 17, 2026
747f36e
Revert "Refactor tree_map unrolling"
SF-N Jun 17, 2026
b767700
Cleanup
SF-N Jun 17, 2026
e91f1f1
Merge branch 'origin-main' into tracer_support_tree_map
SF-N Jun 17, 2026
7b270a3
Address review comment
SF-N Jun 18, 2026
d3d4e46
Remove CollapseTuple pass after UnrollTupleMaps
SF-N Jun 19, 2026
d0272df
Remove program wrapper in tests
SF-N Jun 19, 2026
56f234e
Merge branch 'tracer_support_tree_map' of github.com:SF-N/gt4py into …
SF-N Jun 19, 2026
b7bb0b2
Fix test
SF-N Jun 19, 2026
158d540
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 24, 2026
7d8f56c
Also allow itir.Expr in UnrollTupleMaps and run tye_inference when ne…
SF-N Jun 24, 2026
7808b0f
Merge branch 'tracer_support_tree_map' of github.com:SF-N/gt4py into …
SF-N Jun 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,23 +429,15 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self._lower_and_map("if_", *node.args)

cond_ = self.visit(node.args[0])
true_ = self.visit(node.args[1])
false_ = self.visit(node.args[2])
cond_symref_name = f"__cond_{itir.lenient_ir_fingerprinter(cond_)}"

def create_if(
true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec]
) -> itir.FunCall:
return _map(
"if_",
(im.ref(cond_symref_name), true_, false_),
(node.args[0].type, *arg_types),
result = im.tree_map_tuple(
im.lambda_("__a", "__b")(
im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b"))
)

result = lowering_utils.process_elements(
create_if,
(self.visit(node.args[1]), self.visit(node.args[2])),
node.type,
arg_types=(node.args[1].type, node.args[2].type),
)
)(true_, false_)

return im.let(cond_symref_name, cond_)(result)

Expand Down Expand Up @@ -551,7 +543,7 @@ def _map(
original_arg_types: tuple[ts.TypeSpec, ...],
) -> itir.FunCall:
"""
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_list`ing lists.
"""
if all(
isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType))
Expand All @@ -564,7 +556,7 @@ def _map(
promote_to_list(arg_type)(larg)
for arg_type, larg in zip(original_arg_types, lowered_args)
)
op = im.map_(op)
op = im.map_list(op)

return im.op_as_fieldop(op)(*lowered_args)

Expand Down
16 changes: 14 additions & 2 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,17 @@ def neighbors(*args):


@builtin_dispatch
def map_(*args):
def map_list(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def tree_map_tuple(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def map_tuple(*args):
raise BackendNotSelectedError()


Expand Down Expand Up @@ -498,7 +508,9 @@ def get_domain_range(*args):
"lift",
"make_const_list",
"make_tuple",
"map_",
"tree_map_tuple",
"map_tuple",
"map_list",
"named_range",
"neighbors",
"reduce",
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,8 +1471,8 @@ def _get_offset(*lists: _List | _ConstList) -> Optional[runtime.Offset]:
raise AssertionError("All lists must have the same offset.")


@builtins.map_.register(EMBEDDED)
def map_(op):
@builtins.map_list.register(EMBEDDED)
def map_list(op):
def impl_(*lists):
offset = _get_offset(*lists)
if offset is None:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def is_applied_map(arg: itir.Node) -> TypeGuard[_FunCallToFunCallToRef]:
isinstance(arg, itir.FunCall)
and isinstance(arg.fun, itir.FunCall)
and isinstance(arg.fun.fun, itir.SymRef)
and arg.fun.fun.id == "map_"
and arg.fun.fun.id == "map_list"
)


Expand Down
16 changes: 13 additions & 3 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,19 @@ def index(dim: common.Dimension) -> itir.FunCall:
return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind))


def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))
def map_list(op):
"""Create a `map_list` call."""
return call(call("map_list")(op))


def tree_map_tuple(op):
"""Create a `tree_map_tuple` call: tree_map_tuple(op)(tup1, tup2, ...)."""
return call(call("tree_map_tuple")(op))


def map_tuple(op):
"""Create a `map_tuple` call: map_tuple(op)(tup)."""
return call(call("map_tuple")(op))


def reduce(op, expr):
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/collapse_list_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node:
if cpm.is_call_to(node.args[1], "make_const_list"):
return node.args[1].args[0]
if cpm.is_applied_map(node.args[1]):
# list_get(0, map_(λ(val_) → foo(val_, int64))(·__sym_1))
# list_get(0, map_list(λ(val_) → foo(val_, int64))(·__sym_1))
# -> (λ(val_) → foo(val_, int64))(list_get(0, ·__sym_1))
lsts = node.args[1].args
assert len(node.args[1].fun.args) == 1 # a single lambda in the map
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _is_collectable_expr(node: itir.Node) -> bool:
if isinstance(node, itir.FunCall):
# do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be
# visited, to ensure symbol dependencies are recognized correctly.
# do also not collect reduce, map_ and neighbors nodes if they are left in the IR at this point, this may lead to
# do also not collect reduce, map_list and neighbors nodes if they are left in the IR at this point, this may lead to
# conceptual problems (other parts of the tool chain rely on the arguments being present directly
# on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend
# backend (single pass eager depth first visit approach), see also https://github.com/GridTools/gt4py/issues/1795
Expand All @@ -104,7 +104,7 @@ def _is_collectable_expr(node: itir.Node) -> bool:
# do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement
# instead of an as_fieldop
if cpm.is_call_to(
node, ("lift", "shift", "neighbors", "reduce", "map_", "index")
node, ("lift", "shift", "neighbors", "reduce", "map_list", "index")
) or cpm.is_applied_lift(node):
return False
return True
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
@dataclasses.dataclass(frozen=True)
class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Fuses nested `map_`s.
Fuses nested `map_list`s.

Preconditions:
- `FunctionDefinitions` are inlined
Expand All @@ -29,7 +29,7 @@ class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrai
to
map(λ(a, b, c) → f(a, g(b, c)))(a, b, c)

reduce(λ(x, y) → f(x, y), init)(map_(g(z, w))(a, b))
reduce(λ(x, y) → f(x, y), init)(map_list(g(z, w))(a, b))
to
reduce(λ(x, y, z) → f(x, g(y, z)), init)(a, b)
"""
Expand Down Expand Up @@ -93,7 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
new_op = ir.Lambda(params=new_params, expr=new_body)
if cpm.is_applied_map(node):
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args
fun=ir.FunCall(fun=ir.SymRef(id="map_list"), args=[new_op]), args=new_args
)
else: # is_applied_reduce(node)
return ir.FunCall(
Expand Down
12 changes: 12 additions & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
prune_empty_concat_where,
remove_broadcast,
symbol_ref_utils,
unroll_tuple_maps,
)
from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple
Expand Down Expand Up @@ -169,6 +170,11 @@ def apply_common_transforms(
ir = inline_lifts.InlineLifts().visit(ir)

ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
# `UnrollTupleMaps` requires fully-inferred tuple types (relies on `reinfer` to see
# nested `TupleType` chains), so the offset_provider is passed for on-demand inference.
ir = unroll_tuple_maps.UnrollTupleMaps.apply(
ir, uids=uids, offset_provider_type=offset_provider_type
)
ir = dead_code_elimination.dead_code_elimination(
ir, uids=uids, offset_provider_type=offset_provider_type
) # domain inference does not support dead-code
Expand Down Expand Up @@ -282,6 +288,12 @@ def apply_fieldview_transforms(
ir = inline_fundefs.prune_unreferenced_fundefs(ir)
# required for dead-code-elimination and `prune_empty_concat_where` pass
ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
# `UnrollTupleMaps` requires fully-inferred tuple types, so the offset_provider is passed for
# on-demand inference.
ir = unroll_tuple_maps.UnrollTupleMaps.apply(
ir, uids=uids, offset_provider_type=offset_provider_type
)

ir = dead_code_elimination.dead_code_elimination(
ir, offset_provider_type=offset_provider_type, uids=uids
)
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/trace_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def applied_as_fieldop(*args):
"scan": _scan,
"reduce": _reduce,
"neighbors": _neighbors,
"map_": _map,
"map_list": _map,
"if_": _if,
"make_tuple": _make_tuple,
}
Expand Down
140 changes: 140 additions & 0 deletions src/gt4py/next/iterator/transforms/unroll_tuple_maps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
import functools
from typing import TypeVar

from gt4py import eve
from gt4py.next import common, utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.type_system import inference as itir_inference
from gt4py.next.type_system import type_specifications as ts


def _collapsing_tuple_get(expr: itir.Expr, i: int) -> itir.Expr:
"""Like `im.tuple_get`, but collapses immediately when `expr` is a `make_tuple` call.

Note: argument order is `(expr, i)` to allow use as a `functools.reduce` reducer.
"""
if cpm.is_call_to(expr, "make_tuple"):
return expr.args[i]
return im.tuple_get(i, expr)


def _tree_map_tuple_body(
f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType]
) -> itir.Expr:
"""Recursively unroll `tree_map_tuple(f)(t1, ..., tN)` into `make_tuple` calls."""

def tuple_structure(type_: ts.TypeSpec) -> tuple[object, ...] | None:
if isinstance(type_, ts.TupleType):
return tuple(tuple_structure(el_type) for el_type in type_.types)
return None

expected_structure = tuple_structure(tup_types[0])
if any(tuple_structure(tup_type) != expected_structure for tup_type in tup_types[1:]):
raise TypeError("'tree_map_tuple' requires all arguments to have the same tuple structure.")

@utils.tree_map(
collection_type=ts.TupleType,
result_collection_constructor=lambda _, elts: im.make_tuple(*elts),
with_path_arg=True,
)
def mapper(*args):
*_el_types, path = args
return im.call(f)(
*(functools.reduce(_collapsing_tuple_get, path, tup_expr) for tup_expr in tup_exprs)
)

return mapper(*tup_types)


def _map_tuple_body(
f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType]
) -> itir.Expr:
"""Unroll `map_tuple(f)(t)` over top-level elements only (no recursion)."""
(tup_expr,) = tup_exprs
(tup_type,) = tup_types
return im.make_tuple(
*(im.call(f)(_collapsing_tuple_get(tup_expr, i)) for i in range(len(tup_type.types)))
)


_UNROLLERS = {
"tree_map_tuple": _tree_map_tuple_body,
"map_tuple": _map_tuple_body,
}


ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.Expr)


@dataclasses.dataclass
class UnrollTupleMaps(eve.NodeTranslator):
"""Unroll tuple-map ITIR builtins (`tree_map_tuple`, `map_tuple`) into `make_tuple`."""

PRESERVED_ANNEX_ATTRS = ("domain",)

uids: utils.IDGeneratorPool

@classmethod
def apply(
cls,
node: ProgramOrExpr,
*,
uids: utils.IDGeneratorPool | None = None,
offset_provider_type: common.OffsetProviderType | None = None,
) -> ProgramOrExpr:
if node.type is None:
node = itir_inference.infer(
node,
offset_provider_type=offset_provider_type or {},
allow_undeclared_symbols=not isinstance(node, itir.Program),
)
if uids is None:
uids = utils.IDGeneratorPool()
return cls(uids=uids).visit(node)

def visit_FunCall(self, node: itir.FunCall):
node = self.generic_visit(node)

builtin_name = next((name for name in _UNROLLERS if cpm.is_call_to(node.fun, name)), None)
if builtin_name is None:
return node

assert isinstance(node.fun, itir.FunCall)
f = node.fun.args[0]
tup_args = node.args

tup_types: list[ts.TupleType] = []
for tup in tup_args:
itir_inference.reinfer(tup)
assert isinstance(tup.type, ts.TupleType)
tup_types.append(tup.type)

# For trivial args (those that can be duplicated without cost or side effects),
# we substitute them directly into the body. This avoids leaving behind
# `tuple_get(i, make_tuple(...))` patterns that would otherwise require a
# separate cleanup pass (CollapseTuple). For non-trivial args we still
# introduce a `let` binding to avoid duplicating expensive sub-expressions.
substituted_exprs: list[itir.Expr] = []
let_bindings: list[tuple[str, itir.Expr]] = []
for tup in tup_args:
if isinstance(tup, (itir.SymRef, itir.Literal)) or cpm.is_call_to(tup, "make_tuple"):
substituted_exprs.append(tup)
else:
ref_name = next(self.uids["_utm"])
let_bindings.append((ref_name, tup))
substituted_exprs.append(im.ref(ref_name))

body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types)

result = im.let(*let_bindings)(body) if let_bindings else body
itir_inference.reinfer(result)
return result
Loading