From 3d9e0f152d3ee3d8932402931b8aa0e0118b2a4a Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 23 Jun 2026 20:00:05 +0200 Subject: [PATCH 1/6] fix[next]: correct domain inference for never-selected concat_where branches --- .../next/iterator/ir_utils/domain_utils.py | 68 ++++++++--- .../ir_utils_tests/test_domain_utils.py | 100 ++++++++++++++--- .../transforms_tests/test_domain_inference.py | 106 ++++++++++++++++++ 3 files changed, 242 insertions(+), 32 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 49b98e0877..4d1fd5476e 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -81,6 +81,12 @@ def _unstructured_translate_range_statically( assert isinstance(start_expr, itir.Literal) and isinstance(stop_expr, itir.Literal) start, stop = int(start_expr.value), int(stop_expr.value) + if range_.empty(): + return SymbolicRange( + im.literal(str("0"), builtins.INTEGER_INDEX_BUILTIN), + im.literal(str("0"), builtins.INTEGER_INDEX_BUILTIN), + ) + nb_index: slice | int if val in [trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE]: nb_index = slice(None) @@ -235,26 +241,60 @@ def _reduce_ranges( *ranges: SymbolicRange, start_reduce_op: Callable[[itir.Expr, itir.Expr], itir.Expr], stop_reduce_op: Callable[[itir.Expr, itir.Expr], itir.Expr], + neutral_start_reduce_val: itir.InfinityLiteral, + neutral_stop_reduce_val: itir.InfinityLiteral, ) -> SymbolicRange: - """Uses start_op and stop_op to fold the start and stop of a list of ranges.""" - start = functools.reduce( - lambda current_expr, el_expr: start_reduce_op(current_expr, el_expr), - [range_.start for range_ in ranges], - ) - stop = functools.reduce( - lambda current_expr, el_expr: stop_reduce_op(current_expr, el_expr), - [range_.stop for range_ in ranges], - ) - # constant fold expression to keep the tree small - start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr - return SymbolicRange(start, stop) + """ + Fold the start and stop of a list of ranges with `start_reduce_op` / `stop_reduce_op`. + + This function only computes the correct value if the ranges are either disjoint or empty. + Calculation is by means of the convex hull and some special handling for empty ranges. + """ + # symbolic ranges, i.e., `empty()` is `None` must not be dropped; they are guarded below + non_empty_ranges = [range_ for range_ in ranges if range_.empty() is not True] + + result: SymbolicRange + if len(non_empty_ranges) == 0: + result = ranges[0] # all empty -> empty + elif len(non_empty_ranges) == 1: + result = non_empty_ranges[0] # nothing to combine + else: + + def guarded_bound(bound: itir.Expr, range_: SymbolicRange, neutral: itir.Expr) -> itir.Expr: + # guard a symbolically-empty range so it contributes `neutral` instead of `bound` + if range_.empty() is None: + return im.if_(im.greater_equal(range_.start, range_.stop), neutral, bound) + return bound + + result = SymbolicRange( + functools.reduce( + start_reduce_op, + [guarded_bound(r.start, r, neutral_start_reduce_val) for r in non_empty_ranges], + ), + functools.reduce( + stop_reduce_op, + [guarded_bound(r.stop, r, neutral_stop_reduce_val) for r in non_empty_ranges], + ), + ) + + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr + return SymbolicRange(start, stop) _range_union = functools.partial( - _reduce_ranges, start_reduce_op=im.minimum, stop_reduce_op=im.maximum + _reduce_ranges, + start_reduce_op=im.minimum, + neutral_start_reduce_val=itir.InfinityLiteral.POSITIVE, + stop_reduce_op=im.maximum, + neutral_stop_reduce_val=itir.InfinityLiteral.NEGATIVE, ) _range_intersection = functools.partial( - _reduce_ranges, start_reduce_op=im.maximum, stop_reduce_op=im.minimum + _reduce_ranges, + start_reduce_op=im.maximum, + neutral_start_reduce_val=itir.InfinityLiteral.NEGATIVE, + stop_reduce_op=im.minimum, + neutral_stop_reduce_val=itir.InfinityLiteral.POSITIVE, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py index aa68528372..1497bbe6d5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py @@ -43,6 +43,17 @@ def _make_domain(i: int): ) +def _guard(bound: str, neutral: itir.InfinityLiteral): + # Symbolic emptiness guard `_reduce_ranges` wraps undecidable ranges in: an empty range + # contributes `neutral` instead of `bound` (e.g. "start_I_0"); start/stop are deduced from it. + _, dim, i = bound.split("_") + return im.if_( + im.greater_equal(im.ref(f"start_{dim}_{i}"), im.ref(f"end_{dim}_{i}")), + neutral, + im.ref(bound), + ) + + def test_symbolic_range(): with pytest.raises(AssertionError): domain_utils.SymbolicRange(itir.InfinityLiteral.POSITIVE, 0) @@ -73,46 +84,99 @@ def test_domain_union(): domain1 = _make_domain(1) domain2 = _make_domain(2) + pos, neg = itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE expected = domain_utils.SymbolicDomain( grid_type=common.GridType.CARTESIAN, ranges={ - I: domain_utils.SymbolicRange( + dim: domain_utils.SymbolicRange( im.minimum( - im.minimum(im.ref("start_I_0"), im.ref("start_I_1")), im.ref("start_I_2") + im.minimum( + _guard(f"start_{dim_name}_0", pos), _guard(f"start_{dim_name}_1", pos) + ), + _guard(f"start_{dim_name}_2", pos), ), - im.maximum(im.maximum(im.ref("end_I_0"), im.ref("end_I_1")), im.ref("end_I_2")), - ), - J: domain_utils.SymbolicRange( - im.minimum( - im.minimum(im.ref("start_J_0"), im.ref("start_J_1")), im.ref("start_J_2") + im.maximum( + im.maximum(_guard(f"end_{dim_name}_0", neg), _guard(f"end_{dim_name}_1", neg)), + _guard(f"end_{dim_name}_2", neg), ), - im.maximum(im.maximum(im.ref("end_J_0"), im.ref("end_J_1")), im.ref("end_J_2")), - ), + ) + for dim, dim_name in ((I, "I"), (J, "J")) }, ) assert expected == domain_utils.domain_union(domain0, domain1, domain2) +def test_domain_union_drops_empty_domains(): + # Empty domains are the union's identity element; keeping them would over-approximate (#2205). + non_empty = domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, ranges={I: domain_utils.SymbolicRange(0, 10)} + ) + empty_a = domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, ranges={I: domain_utils.SymbolicRange(10, 10)} + ) + empty_b = domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, ranges={I: domain_utils.SymbolicRange(11, 11)} + ) + + # a single non-empty domain among empty ones is returned unchanged + assert domain_utils.domain_union(empty_a, non_empty, empty_b) == non_empty + assert domain_utils.domain_union(empty_a) == empty_a + + +def test_domain_union_all_empty(): + # A union of only empty domains stays empty (not the convex hull `[10, 11)`). + empty_a = domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, ranges={I: domain_utils.SymbolicRange(10, 10)} + ) + empty_b = domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, ranges={I: domain_utils.SymbolicRange(11, 11)} + ) + result = domain_utils.domain_union(empty_a, empty_b) + assert result.empty() + + +def test_unstructured_translate_empty_range(): + # Translating an empty range must not reduce over the empty connectivity slice (which would + # raise); it yields an empty range in the codomain dimension. + offset_provider = { + "V2E": constructors.as_connectivity( + domain={Vertex: (0, 4), V2EDim: 1}, + codomain=Edge, + data=np.asarray([0, 1, 2, 3], dtype=fbuiltins.IndexType).reshape((4, 1)), + ) + } + shift_chain = [im.ensure_offset(o) for o in ("V2E", 0)] + domain = domain_utils.SymbolicDomain.from_expr( + im.domain(common.GridType.UNSTRUCTURED, {Vertex: (2, 2)}) # empty + ) + translated = domain.translate(shift_chain, offset_provider) + assert translated.empty() + assert set(translated.ranges.keys()) == {Edge} + + def test_domain_intersection(): domain0 = _make_domain(0) domain1 = _make_domain(1) domain2 = _make_domain(2) + pos, neg = itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE + # symbolic ranges are guarded with the intersection's neutral element (the whole universe) expected = domain_utils.SymbolicDomain( grid_type=common.GridType.CARTESIAN, ranges={ - I: domain_utils.SymbolicRange( + dim: domain_utils.SymbolicRange( im.maximum( - im.maximum(im.ref("start_I_0"), im.ref("start_I_1")), im.ref("start_I_2") + im.maximum( + _guard(f"start_{dim_name}_0", neg), _guard(f"start_{dim_name}_1", neg) + ), + _guard(f"start_{dim_name}_2", neg), ), - im.minimum(im.minimum(im.ref("end_I_0"), im.ref("end_I_1")), im.ref("end_I_2")), - ), - J: domain_utils.SymbolicRange( - im.maximum( - im.maximum(im.ref("start_J_0"), im.ref("start_J_1")), im.ref("start_J_2") + im.minimum( + im.minimum(_guard(f"end_{dim_name}_0", pos), _guard(f"end_{dim_name}_1", pos)), + _guard(f"end_{dim_name}_2", pos), ), - im.minimum(im.minimum(im.ref("end_J_0"), im.ref("end_J_1")), im.ref("end_J_2")), - ), + ) + for dim, dim_name in ((I, "I"), (J, "J")) }, ) assert expected == domain_utils.domain_intersection(domain0, domain1, domain2) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index cbfee0c920..93f0191e42 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1331,6 +1331,112 @@ def test_nested_concat_where_two_dimensions(): assert expected_domains == constant_fold_accessed_domains(actual_domains) +def test_concat_where_shift_in_never_selected_branch(): + # Regression test for #2205: the never-selected `K >= 10` branch has an empty domain `[10, 10)`; + # its shifted access `a[K+1]` must not over-extend `a` from `[0, 10)` to `[0, 11)`. + domain = im.domain(common.GridType.CARTESIAN, {KDim: (0, 10)}) + cond_lt = im.domain(common.GridType.CARTESIAN, {KDim: (itir.InfinityLiteral.NEGATIVE, 9)}) + cond_ge = im.domain(common.GridType.CARTESIAN, {KDim: (10, itir.InfinityLiteral.POSITIVE)}) + diff = im.lambda_("t")(im.minus(im.deref("t"), im.deref(im.shift(Koff, 1)("t")))) + + domain_lt = im.domain(common.GridType.CARTESIAN, {KDim: (0, 9)}) + domain_never = im.domain(common.GridType.CARTESIAN, {KDim: (10, 10)}) # empty + domain_eq = im.domain(common.GridType.CARTESIAN, {KDim: (9, 10)}) + + testee = im.concat_where( + cond_lt, + im.as_fieldop(diff)("a"), + im.concat_where(cond_ge, im.as_fieldop(diff)("a"), im.as_fieldop("deref")("a")), + ) + expected = im.concat_where( + cond_lt, + im.as_fieldop(diff, domain_lt)("a"), + im.concat_where( + cond_ge, + im.as_fieldop(diff, domain_never)("a"), + im.as_fieldop("deref", domain_eq)("a"), + ), + ) + expected_domains = {"a": domain} # `[0, 10)`, not the over-extended `[0, 11)` + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider={} + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_concat_where_shift_in_never_selected_branch_shared(): + # Regression test for #2205, `let`-bound variant: the shifted value is shared via a `let` and + # referenced from the never-selected `K >= 10` branch (empty domain `[10, 10)`), which must not + # over-extend `a` from `[0, 10)` to `[0, 11)`. + domain = im.domain(common.GridType.CARTESIAN, {KDim: (0, 10)}) + cond_lt = im.domain(common.GridType.CARTESIAN, {KDim: (itir.InfinityLiteral.NEGATIVE, 9)}) + cond_ge = im.domain(common.GridType.CARTESIAN, {KDim: (10, itir.InfinityLiteral.POSITIVE)}) + diff = im.lambda_("t")(im.minus(im.deref("t"), im.deref(im.shift(Koff, 1)("t")))) + + domain_diff = im.domain(common.GridType.CARTESIAN, {KDim: (0, 9)}) + domain_eq = im.domain(common.GridType.CARTESIAN, {KDim: (9, 10)}) + + testee = im.let("diff", im.as_fieldop(diff)("a"))( + im.concat_where( + cond_lt, + "diff", + im.concat_where(cond_ge, "diff", im.as_fieldop("deref")("a")), + ) + ) + expected = im.let("diff", im.as_fieldop(diff, domain_diff)("a"))( + im.concat_where( + cond_lt, + "diff", + im.concat_where(cond_ge, "diff", im.as_fieldop("deref", domain_eq)("a")), + ) + ) + expected_domains = {"a": domain} # `[0, 10)`, not the over-extended `[0, 11)` + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider={} + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_concat_where_unstructured_shift_in_never_selected_branch(unstructured_offset_provider): + # Regression test for #2205, unstructured variant: shifting the never-selected branch's empty + # `Edge: [1, 1)` range must not reduce over the empty connectivity slice (which would raise); + # it yields an empty `Vertex` range, so `a` is accessed on an empty domain. + domain = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) + cond = im.domain(common.GridType.UNSTRUCTURED, {Edge: (1, itir.InfinityLiteral.POSITIVE)}) + e2v = im.lambda_("it")(im.deref(im.shift("E2V", 0)("it"))) + + domain_never = im.domain(common.GridType.UNSTRUCTURED, {Edge: (1, 1)}) # empty + domain_false = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) + domain_a = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (0, 0)}) # empty (never accessed) + + testee = im.concat_where(cond, im.as_fieldop(e2v)("a"), im.as_fieldop("deref")("b")) + expected = im.concat_where( + cond, im.as_fieldop(e2v, domain_never)("a"), im.as_fieldop("deref", domain_false)("b") + ) + expected_domains = { + "a": domain_a, + "b": domain_false, + } + + actual_call, actual_domains = infer_domain.infer_expr( + testee, + domain_utils.SymbolicDomain.from_expr(domain), + offset_provider=unstructured_offset_provider, + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + def test_broadcast(): testee = im.call("broadcast")("in_field", im.make_tuple(itir.AxisLiteral(value="IDim"))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) From 50a0c0eca5f2ffe6b86074e139ae6e187a48f106 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 24 Jun 2026 10:25:44 +0200 Subject: [PATCH 2/6] Fix non-sense --- .../next/iterator/ir_utils/domain_utils.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 4d1fd5476e..0b7fa12578 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -247,17 +247,17 @@ def _reduce_ranges( """ Fold the start and stop of a list of ranges with `start_reduce_op` / `stop_reduce_op`. - This function only computes the correct value if the ranges are either disjoint or empty. - Calculation is by means of the convex hull and some special handling for empty ranges. + This function only computes the correct value if the ranges are either overlapping / adjacent + or empty as calculation is by means of the convex hull (and some special handling for empty + ranges). """ # symbolic ranges, i.e., `empty()` is `None` must not be dropped; they are guarded below non_empty_ranges = [range_ for range_ in ranges if range_.empty() is not True] - result: SymbolicRange if len(non_empty_ranges) == 0: - result = ranges[0] # all empty -> empty + return ranges[0] elif len(non_empty_ranges) == 1: - result = non_empty_ranges[0] # nothing to combine + return non_empty_ranges[0] else: def guarded_bound(bound: itir.Expr, range_: SymbolicRange, neutral: itir.Expr) -> itir.Expr: @@ -266,19 +266,18 @@ def guarded_bound(bound: itir.Expr, range_: SymbolicRange, neutral: itir.Expr) - return im.if_(im.greater_equal(range_.start, range_.stop), neutral, bound) return bound - result = SymbolicRange( - functools.reduce( - start_reduce_op, - [guarded_bound(r.start, r, neutral_start_reduce_val) for r in non_empty_ranges], - ), - functools.reduce( - stop_reduce_op, - [guarded_bound(r.stop, r, neutral_stop_reduce_val) for r in non_empty_ranges], - ), + start = functools.reduce( + start_reduce_op, + [guarded_bound(r.start, r, neutral_start_reduce_val) for r in non_empty_ranges], + ) + stop = functools.reduce( + stop_reduce_op, + [guarded_bound(r.stop, r, neutral_stop_reduce_val) for r in non_empty_ranges], ) # constant fold expression to keep the tree small start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr + return SymbolicRange(start, stop) From ebfed07698470feebd7eab87edf59a936920e945 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 24 Jun 2026 11:00:25 +0200 Subject: [PATCH 3/6] Fix non-sense --- .../ir_utils_tests/test_domain_utils.py | 10 ++++----- .../transforms_tests/test_domain_inference.py | 21 ++++++++++--------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py index 1497bbe6d5..6f7e239ea9 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py @@ -44,8 +44,7 @@ def _make_domain(i: int): def _guard(bound: str, neutral: itir.InfinityLiteral): - # Symbolic emptiness guard `_reduce_ranges` wraps undecidable ranges in: an empty range - # contributes `neutral` instead of `bound` (e.g. "start_I_0"); start/stop are deduced from it. + # Guard range bound by neutral element of the respective bound _, dim, i = bound.split("_") return im.if_( im.greater_equal(im.ref(f"start_{dim}_{i}"), im.ref(f"end_{dim}_{i}")), @@ -136,8 +135,6 @@ def test_domain_union_all_empty(): def test_unstructured_translate_empty_range(): - # Translating an empty range must not reduce over the empty connectivity slice (which would - # raise); it yields an empty range in the codomain dimension. offset_provider = { "V2E": constructors.as_connectivity( domain={Vertex: (0, 4), V2EDim: 1}, @@ -145,11 +142,12 @@ def test_unstructured_translate_empty_range(): data=np.asarray([0, 1, 2, 3], dtype=fbuiltins.IndexType).reshape((4, 1)), ) } - shift_chain = [im.ensure_offset(o) for o in ("V2E", 0)] domain = domain_utils.SymbolicDomain.from_expr( im.domain(common.GridType.UNSTRUCTURED, {Vertex: (2, 2)}) # empty ) - translated = domain.translate(shift_chain, offset_provider) + translated = domain.translate( + [itir.OffsetLiteral(value="V2E"), itir.OffsetLiteral(value="0")], offset_provider + ) assert translated.empty() assert set(translated.ranges.keys()) == {Edge} diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 93f0191e42..317ce27774 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1333,7 +1333,9 @@ def test_nested_concat_where_two_dimensions(): def test_concat_where_shift_in_never_selected_branch(): # Regression test for #2205: the never-selected `K >= 10` branch has an empty domain `[10, 10)`; - # its shifted access `a[K+1]` must not over-extend `a` from `[0, 10)` to `[0, 11)`. + # its shifted access `a[K+1]` must not over-extend `a` from `[0, 10)` to `[0, 11)`, which can + # happen if the accessed domain is computed naively by deriving the domain union of `[10, 10)` + # `[11, 11)` as ``[min(10, 11), max(10, 11))``. domain = im.domain(common.GridType.CARTESIAN, {KDim: (0, 10)}) cond_lt = im.domain(common.GridType.CARTESIAN, {KDim: (itir.InfinityLiteral.NEGATIVE, 9)}) cond_ge = im.domain(common.GridType.CARTESIAN, {KDim: (10, itir.InfinityLiteral.POSITIVE)}) @@ -1407,23 +1409,22 @@ def test_concat_where_shift_in_never_selected_branch_shared(): def test_concat_where_unstructured_shift_in_never_selected_branch(unstructured_offset_provider): # Regression test for #2205, unstructured variant: shifting the never-selected branch's empty - # `Edge: [1, 1)` range must not reduce over the empty connectivity slice (which would raise); - # it yields an empty `Vertex` range, so `a` is accessed on an empty domain. + # `Edge: [1, 1)` range must not raise. domain = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) cond = im.domain(common.GridType.UNSTRUCTURED, {Edge: (1, itir.InfinityLiteral.POSITIVE)}) - e2v = im.lambda_("it")(im.deref(im.shift("E2V", 0)("it"))) + stencil_e2v = im.lambda_("it")(im.deref(im.shift("E2V", 0)("it"))) - domain_never = im.domain(common.GridType.UNSTRUCTURED, {Edge: (1, 1)}) # empty - domain_false = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) - domain_a = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (0, 0)}) # empty (never accessed) + domain_empty = im.domain(common.GridType.UNSTRUCTURED, {Edge: (1, 1)}) + domain_a = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (0, 0)}) + domain_b = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) - testee = im.concat_where(cond, im.as_fieldop(e2v)("a"), im.as_fieldop("deref")("b")) + testee = im.concat_where(cond, im.as_fieldop(stencil_e2v)("a"), im.as_fieldop("deref")("b")) expected = im.concat_where( - cond, im.as_fieldop(e2v, domain_never)("a"), im.as_fieldop("deref", domain_false)("b") + cond, im.as_fieldop(stencil_e2v, domain_empty)("a"), im.as_fieldop("deref", domain_b)("b") ) expected_domains = { "a": domain_a, - "b": domain_false, + "b": domain_b, } actual_call, actual_domains = infer_domain.infer_expr( From a7e73e5cedc42a785c4cb8f6b2c5779e945334ee Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 24 Jun 2026 11:01:49 +0200 Subject: [PATCH 4/6] Cleanup --- .../iterator_tests/ir_utils_tests/test_domain_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py index 6f7e239ea9..09bb057afe 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py @@ -165,16 +165,16 @@ def test_domain_intersection(): dim: domain_utils.SymbolicRange( im.maximum( im.maximum( - _guard(f"start_{dim_name}_0", neg), _guard(f"start_{dim_name}_1", neg) + _guard(f"start_{dim.value}_0", neg), _guard(f"start_{dim.value}_1", neg) ), - _guard(f"start_{dim_name}_2", neg), + _guard(f"start_{dim.value}_2", neg), ), im.minimum( - im.minimum(_guard(f"end_{dim_name}_0", pos), _guard(f"end_{dim_name}_1", pos)), - _guard(f"end_{dim_name}_2", pos), + im.minimum(_guard(f"end_{dim.value}_0", pos), _guard(f"end_{dim.value}_1", pos)), + _guard(f"end_{dim.value}_2", pos), ), ) - for dim, dim_name in ((I, "I"), (J, "J")) + for dim in (I, J) }, ) assert expected == domain_utils.domain_intersection(domain0, domain1, domain2) From 5dbe8946580e7bff47ed15e9df20506e9644bab9 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 24 Jun 2026 13:29:01 +0200 Subject: [PATCH 5/6] Cleanup --- src/gt4py/next/iterator/ir.py | 5 ++++ .../next/iterator/ir_utils/domain_utils.py | 13 ++++---- src/gt4py/next/iterator/tracing.py | 6 ++++ .../next/iterator/transforms/global_tmps.py | 11 ++++++- .../next/iterator/type_system/inference.py | 5 +--- .../codegens/gtfn/codegen.py | 10 +++++++ .../codegens/gtfn/gtfn_ir.py | 17 +++++++++++ .../codegens/gtfn/itir_to_gtfn_ir.py | 8 +++++ .../program_processors/runners/roundtrip.py | 6 ++++ .../ir_utils_tests/test_domain_utils.py | 6 ++-- .../transforms_tests/test_global_tmps.py | 30 +++++++++++++++++++ 11 files changed, 104 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 5a3fc1713f..62c2ea8bad 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -73,6 +73,11 @@ class NoneLiteral(Expr): class InfinityLiteral(Expr): + """ + Infinity value appearing in concat_where domain selectors and domain bounds (to guard domain + operators on empty ranges). + """ + # TODO(tehrengruber): self referential `ClassVar` not supported in eve. if TYPE_CHECKING: POSITIVE: ClassVar[InfinityLiteral] diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 0b7fa12578..a3c77689d3 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -254,10 +254,12 @@ def _reduce_ranges( # symbolic ranges, i.e., `empty()` is `None` must not be dropped; they are guarded below non_empty_ranges = [range_ for range_ in ranges if range_.empty() is not True] + start: itir.Expr + stop: itir.Expr if len(non_empty_ranges) == 0: - return ranges[0] + start, stop = ranges[0].start, ranges[0].stop elif len(non_empty_ranges) == 1: - return non_empty_ranges[0] + start, stop = non_empty_ranges[0].start, non_empty_ranges[0].stop else: def guarded_bound(bound: itir.Expr, range_: SymbolicRange, neutral: itir.Expr) -> itir.Expr: @@ -275,10 +277,9 @@ def guarded_bound(bound: itir.Expr, range_: SymbolicRange, neutral: itir.Expr) - [guarded_bound(r.stop, r, neutral_stop_reduce_val) for r in non_empty_ranges], ) - # constant fold expression to keep the tree small - start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr - - return SymbolicRange(start, stop) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr + return SymbolicRange(start, stop) _range_union = functools.partial( diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index d1cca102da..c1a3352c59 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -143,6 +143,12 @@ def make_node(o): return o if isinstance(o, common.Dimension): return AxisLiteral(value=o.value, kind=o.kind) + if isinstance(o, common.Infinity): + if o is common.Infinity.POSITIVE: + return itir.InfinityLiteral.POSITIVE + else: + assert o is common.Infinity.NEGATIVE + return itir.InfinityLiteral.NEGATIVE if isinstance(o, common.CartesianConnectivity): # TODO(havogt): `itir.CartesianOffset` cannot represent `offset != 0` (embedded honors # it, see `execute_shift`); decide whether to fold it into the shift value or forbid it. diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 49d2f3d5b4..85b9176e09 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -156,9 +156,18 @@ def _transform_by_pattern( # hide projector from extraction projector, expr = ir_utils_misc.extract_projector(stmt.expr) + def wrapped_predicate(expr: itir.Expr, num_occurences: int) -> bool: + assert isinstance(expr.type, ts.TypeSpec) or isinstance(expr, itir.Lambda) + if isinstance(expr.type, ts.TypeSpec) and not type_info.is_type_or_tuple_of_type( + expr.type, ts.FieldType + ): + return False + + return predicate(expr, num_occurences) + new_expr, extracted_fields, _ = cse.extract_subexpression( expr, - predicate=predicate, + predicate=wrapped_predicate, prefixed_uids=uids["__tmp_subexpr"], # TODO(tehrengruber): extracting the deepest expression first would allow us to fuse # the extracted expressions resulting in fewer kernel calls & better data-locality. diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 4ab6668247..3711d1926a 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -492,10 +492,7 @@ def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType: return node.type def visit_InfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType: - return ts.ScalarType(kind=ts.ScalarKind.INT32) - - def visit_NegInfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType: - return ts.ScalarType(kind=ts.ScalarKind.INT32) + return ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) def visit_SymRef( self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 5f1cfcc516..059011d91d 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -11,6 +11,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import common +from gt4py.next.iterator import builtins as itir_builtins from gt4py.next.otf import cpp_utils from gt4py.next.program_processors.codegens.gtfn import gtfn_im_ir, gtfn_ir, gtfn_ir_common @@ -92,6 +93,15 @@ def visit_SymRef(self, node: gtfn_ir_common.SymRef, **kwargs: Any) -> str: return node.id + def visit_InfinityLiteral(self, node: gtfn_ir.InfinityLiteral, **kwargs: Any) -> str: + # `±∞` only appears in (integer) domain bounds; use the limit of the index type. + cpptype = cpp_utils.pytype_to_cpptype(itir_builtins.INTEGER_INDEX_BUILTIN) + if node == gtfn_ir.InfinityLiteral.POSITIVE: + return f"std::numeric_limits<{cpptype}>::max()" + else: + assert node == gtfn_ir.InfinityLiteral.NEGATIVE + return f"std::numeric_limits<{cpptype}>::min()" + def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: # TODO(tehrengruber): isn't this wrong and int32 should be casted to an actual int32? match node.type: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index c0a5943e03..02ccd15dea 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -8,6 +8,7 @@ from __future__ import annotations +import typing from typing import Callable, ClassVar, Optional, Union from gt4py.eve import Coerced, SymbolName, datamodels @@ -45,6 +46,22 @@ class Literal(Expr): type: str +class InfinityLiteral(Expr): + # TODO(tehrengruber): self referential `ClassVar` not supported in eve. + if typing.TYPE_CHECKING: + POSITIVE: ClassVar[InfinityLiteral] + NEGATIVE: ClassVar[InfinityLiteral] + + name: typing.Literal["POSITIVE", "NEGATIVE"] + + def __str__(self) -> str: + return f"{type(self).__name__}.{self.name}" + + +InfinityLiteral.NEGATIVE = InfinityLiteral(name="NEGATIVE") +InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE") + + class IntegralConstant(Expr): value: int # generalize to other types if needed diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index cc9dbd1d12..0b475e46ff 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -29,6 +29,7 @@ FunCall, FunctionDefinition, IfStmt, + InfinityLiteral, IntegralConstant, Lambda, Literal, @@ -383,6 +384,13 @@ def visit_Lambda( def visit_Literal(self, node: itir.Literal, **kwargs: Any) -> Literal: return Literal(value=node.value, type=node.type.kind.name.lower()) + def visit_InfinityLiteral(self, node: itir.InfinityLiteral, **kwargs: Any) -> InfinityLiteral: + return ( + InfinityLiteral.POSITIVE + if node == itir.InfinityLiteral.POSITIVE + else InfinityLiteral.NEGATIVE + ) + def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs: Any) -> OffsetLiteral: return OffsetLiteral(value=node.value) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 2d8f45f631..07fb038bcc 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -61,6 +61,12 @@ def visit_Literal(self, node: itir.Literal, **kwargs: Any) -> str: return f"np.{dtype}(np.nan)" return node.value + def visit_InfinityLiteral(self, node: itir.InfinityLiteral, **kwargs: Any) -> str: + if node == itir.InfinityLiteral.POSITIVE: + return "gtx.common.Infinity.POSITIVE" + assert node == itir.InfinityLiteral.NEGATIVE + return "gtx.common.Infinity.NEGATIVE" + NoneLiteral = as_fmt("None") OffsetLiteral = as_fmt("{value}") AxisLiteral = as_fmt("{value}") diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py index 09bb057afe..17bebccdf5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py @@ -146,7 +146,7 @@ def test_unstructured_translate_empty_range(): im.domain(common.GridType.UNSTRUCTURED, {Vertex: (2, 2)}) # empty ) translated = domain.translate( - [itir.OffsetLiteral(value="V2E"), itir.OffsetLiteral(value="0")], offset_provider + [itir.OffsetLiteral(value="V2E"), itir.OffsetLiteral(value=0)], offset_provider ) assert translated.empty() assert set(translated.ranges.keys()) == {Edge} @@ -170,7 +170,9 @@ def test_domain_intersection(): _guard(f"start_{dim.value}_2", neg), ), im.minimum( - im.minimum(_guard(f"end_{dim.value}_0", pos), _guard(f"end_{dim.value}_1", pos)), + im.minimum( + _guard(f"end_{dim.value}_0", pos), _guard(f"end_{dim.value}_1", pos) + ), _guard(f"end_{dim.value}_2", pos), ), ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index f00b536dd4..1e332d225a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -548,6 +548,36 @@ def test_domain_preservation(uids: utils.IDGeneratorPool): assert actual == expected +def test_no_extraction_from_domain_bound(uids: utils.IDGeneratorPool): + # Regression test: a scalar `if_` inside a domain bound (e.g. an emptiness guard produced by + # domain inference) must not be extracted into a temporary -- only field-typed expressions may + # become temporaries (see `_transform_by_pattern`). Previously the `if_` was extracted and + # crashed because, being a scalar, it has no `annex.domain`. + offset_provider = {} + # symbolic condition so the `if_` is not constant-folded away + guard = im.if_(im.greater_equal(im.ref("n", index_type), 0), 0, 10) + domain = im.domain("cartesian_domain", {IDim: (0, guard)}) + testee = program_factory( + params=[ + im.sym("inp", i_field_type), + im.sym("out", i_field_type), + im.sym("n", index_type), + ], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)("inp"), + domain=domain, + ) + ], + ) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) + + actual = global_tmps.create_global_tmps(testee, offset_provider, uids=uids) + + assert testee == actual + + def test_non_scan_projector(uids: utils.IDGeneratorPool): domain = im.domain("cartesian_domain", {IDim: (0, 2)}) offset_provider = {} From 01e7cb809ce9db1ae9f52892fe5605556de92546 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 25 Jun 2026 03:34:06 +0200 Subject: [PATCH 6/6] Fix symbolic case tree size explosion --- Fix_concat_where_start_stop_invariant.md | 86 +++++++++++ .../next/iterator/ir_utils/domain_utils.py | 144 +++++++++++++----- src/gt4py/next/iterator/ir_utils/ir_makers.py | 4 +- .../next/iterator/transforms/global_tmps.py | 2 + .../ir_utils_tests/test_domain_utils.py | 60 ++++++-- 5 files changed, 250 insertions(+), 46 deletions(-) create mode 100644 Fix_concat_where_start_stop_invariant.md diff --git a/Fix_concat_where_start_stop_invariant.md b/Fix_concat_where_start_stop_invariant.md new file mode 100644 index 0000000000..bf893b391b --- /dev/null +++ b/Fix_concat_where_start_stop_invariant.md @@ -0,0 +1,86 @@ +# `SymbolicRange` invariant: `start ≠ +inf`, `stop ≠ -inf` + +`SymbolicRange.__post_init__` (in +`src/gt4py/next/iterator/ir_utils/domain_utils.py`) asserts: + +```python +assert self.start is not itir.InfinityLiteral.POSITIVE +assert self.stop is not itir.InfinityLiteral.NEGATIVE +``` + +## What the invariant means + +Infinities are only allowed **outward**. Inward infinities are forbidden: + +| form | allowed? | meaning | +|---|---|---| +| `[a, b)`, `[-inf, b)`, `[a, +inf)`, `[-inf, +inf)` | ✅ | finite / left-/right-/both-unbounded | +| `[+inf, b)`, `[a, -inf)`, `[+inf, -inf)` | ❌ | degenerate "empty from infinity" | + +The union-neutral range `[+inf, -inf)` is exactly the forbidden inward form. +Empty domains are conventionally represented elsewhere with **finite** bounds +(`start >= stop`, e.g. `[10, 10)` / `[0, 0)`), never with infinities — that is +the assumption the check enforces. + +## Where the assumption is relied upon + +### Silent-wrong (no crash) + +- **`empty()`** (`domain_utils.py`) — now classifies infinities explicitly: an + "inward" infinity (`start is POSITIVE` or `stop is NEGATIVE`) → `True` + (degenerate empty), an "outward" infinity (`start is NEGATIVE` or + `stop is POSITIVE`) → `False` (always non-empty), otherwise the literal / + equality checks as before. So the neutral `[+inf, -inf)` *is* detected as + empty, and half-infinite ranges are no longer treated as + statically-undecidable (which previously left `let`/`if` guard cruft in + inferred domains). +- **`is_finite`** — reports `[+inf, -inf)` as "not finite", but it is really + *empty*, not *unbounded*; callers gating on finiteness would mis-handle it. + +### Assertion crashes (encode the invariant; in practice only hit on half-infinite condition-domains) + +- **`domain_complement`** (`domain_utils.py`) — + `assert (lb == NEGATIVE) != (ub == POSITIVE)`. For `[+inf, -inf)` both sides + are `False` → `False != False` fails. +- **concat_where `_range_complement`** (canonicalize domain argument) — + `assert not any(isinstance(b, itir.InfinityLiteral) for b in (start, stop))` + requires finite bounds outright. + +### Hard breakage if an inward-infinity domain reaches lowering / execution + +This is the important class: a real (materialized) domain must never carry +`[+inf, -inf)`. + +- **Size computed as `stop - start`**: + - dace `gtir_domain.py:143` — `max(0, stop - start)` + - dace `gtir_to_sdfg_scan.py:380` — `scan_domain.stop - scan_domain.start` + - `Domain.size` / `shape` (`common.py:491`) + + With `stop = -inf`, `start = +inf` this is `-oo - (+oo)` → fragile/garbage + sympy, and `origin = +inf` is nonsensical. +- **`assert ...is_finite(...)` on the materialized domain**: embedded + `nd_array_field.py:189,1187` / `embedded/common.py:56`, dace + `sdfg_callable.py:22` → assertion failure at runtime. +- **Codegen of bounds**: gtfn / dace / roundtrip now emit `InfinityLiteral` + (`std::numeric_limits<...>::min()/max()`, `±sympy.oo`, `common.Infinity`). + Fine for an *outward* infinity in a genuine concat_where domain, but an + inward-infinity empty domain (`origin = +inf`, `size = -inf - +inf`) generates + nonsensical C++ / SDFG. + +## Bottom line + +`ConstantFolding` itself is largely tolerant of `[+inf, -inf)`: the `plus`-only +`assert not both-infinity` never sees two infinities; `minimum` / `maximum` +fold it correctly; `greater_equal(+inf, -inf)` folds to `True`. + +The blockers to allowing the inward/neutral form are therefore: +1. `empty()` won't detect it (silent), +2. the two complement assertions (crash, but not on neutral ranges in practice), +3. and — critically — it must **never escape to lowering/execution**, where + `stop - start` and the `is_finite` asserts would break. + +Today this cannot happen because the all-empty reduction returns a *finite* +empty range. A neutral-seed reduction whose inputs are all empty would instead +yield `[+inf, -inf)` and could flow downstream, so that path would need to +re-normalize all-empty results back to a finite empty range (or guarantee they +are dropped before codegen) before the invariant could be relaxed. diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index a3c77689d3..c8d473c69d 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -34,16 +34,44 @@ class SymbolicRange: start: itir.Expr stop: itir.Expr + #: Groups of `let` bindings that are in scope for both `start` and `stop`. Each group is a + #: simultaneous `let` (its bindings are mutually independent); the groups are nested, with the + #: first tuple element the outermost `let`, so a later group's values may reference an earlier + #: group's symbols. They let chained reductions (see `_reduce_ranges`) reference previously + #: computed bounds by a unique symbol (`__sd_start_0`, `__sd_start_1`, ...) instead of + #: duplicating them, which would otherwise blow up the expression size exponentially. Sharing one + #: group between `start` and `stop` keeps each bound stored exactly once. Materialized by + #: `as_expr`; the unique names make the nesting unambiguous (no shadowing). + bindings: tuple[dict[str, itir.Expr], ...] = () + + # See: Fix_concat_where_start_stop_invariant.md + #def __post_init__(self) -> None: + # # TODO(havogt): added this defensive checks as code seems to make this reasonable assumption + # assert self.start is not itir.InfinityLiteral.POSITIVE + # assert self.stop is not itir.InfinityLiteral.NEGATIVE - def __post_init__(self) -> None: - # TODO(havogt): added this defensive checks as code seems to make this reasonable assumption - assert self.start is not itir.InfinityLiteral.POSITIVE - assert self.stop is not itir.InfinityLiteral.NEGATIVE + def __hash__(self) -> int: + # `bindings` holds (mutable, unhashable) dicts; hash their items instead so `SymbolicRange` + # stays hashable (it is hashed e.g. via `frozenset` in `SymbolicDomain.__hash__`). + return hash((self.start, self.stop, tuple(tuple(g.items()) for g in self.bindings))) def translate(self, distance: int) -> SymbolicRange: - return SymbolicRange(im.plus(self.start, distance), im.plus(self.stop, distance)) + # constant fold so that translated literal bounds stay literal (otherwise `empty()` would + # treat e.g. `0 + 1` as symbolic and `_reduce_ranges` would needlessly guard them) + return SymbolicRange( + ConstantFolding.apply(im.plus(self.start, distance)), # type: ignore[arg-type] # always an itir.Expr + ConstantFolding.apply(im.plus(self.stop, distance)), # type: ignore[arg-type] # always an itir.Expr + self.bindings, + ) def empty(self) -> bool | None: + # an "inward" infinity (`start == +inf` or `stop == -inf`) is the degenerate empty range + if self.start is itir.InfinityLiteral.POSITIVE or self.stop is itir.InfinityLiteral.NEGATIVE: + return True + # an "outward" infinity (`start == -inf` or `stop == +inf`) is always non-empty as the + # opposite bound is finite (or the opposite outward infinity) + if self.start is itir.InfinityLiteral.NEGATIVE or self.stop is itir.InfinityLiteral.POSITIVE: + return False if isinstance(self.start, itir.Literal) and isinstance(self.stop, itir.Literal): start, stop = int(self.start.value), int(self.stop.value) return start >= stop @@ -51,6 +79,14 @@ def empty(self) -> bool | None: return True return None + def as_expr(self) -> tuple[itir.Expr, itir.Expr]: + """Materialize `start` and `stop`, wrapping the shared `bindings` groups as nested `let`s.""" + start, stop = self.start, self.stop + # groups are outermost-first; wrap the innermost (last) group first + for group in reversed(self.bindings): + start, stop = im.let(*group.items())(start), im.let(*group.items())(stop) + return start, stop + _GRID_TYPE_MAPPING = { "unstructured_domain": common.GridType.UNSTRUCTURED, @@ -162,7 +198,7 @@ def from_expr(cls, node: itir.Node) -> SymbolicDomain: def as_expr(self) -> itir.FunCall: converted_ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] = { - key: (value.start, value.stop) for key, value in self.ranges.items() + key: value.as_expr() for key, value in self.ranges.items() } return im.domain(self.grid_type, converted_ranges) @@ -241,60 +277,96 @@ def _reduce_ranges( *ranges: SymbolicRange, start_reduce_op: Callable[[itir.Expr, itir.Expr], itir.Expr], stop_reduce_op: Callable[[itir.Expr, itir.Expr], itir.Expr], - neutral_start_reduce_val: itir.InfinityLiteral, - neutral_stop_reduce_val: itir.InfinityLiteral, + neutral_reduce_val: SymbolicRange, ) -> SymbolicRange: """ Fold the start and stop of a list of ranges with `start_reduce_op` / `stop_reduce_op`. + The reduction is seeded with `neutral_reduce_val`, the operation's identity range (the empty + range for union, the universe range for intersection); an empty input range therefore folds to + that same neutral and leaves the result unchanged. + This function only computes the correct value if the ranges are either overlapping / adjacent or empty as calculation is by means of the convex hull (and some special handling for empty ranges). """ # symbolic ranges, i.e., `empty()` is `None` must not be dropped; they are guarded below non_empty_ranges = [range_ for range_ in ranges if range_.empty() is not True] - - start: itir.Expr - stop: itir.Expr if len(non_empty_ranges) == 0: - start, stop = ranges[0].start, ranges[0].stop - elif len(non_empty_ranges) == 1: - start, stop = non_empty_ranges[0].start, non_empty_ranges[0].stop - else: - - def guarded_bound(bound: itir.Expr, range_: SymbolicRange, neutral: itir.Expr) -> itir.Expr: - # guard a symbolically-empty range so it contributes `neutral` instead of `bound` - if range_.empty() is None: - return im.if_(im.greater_equal(range_.start, range_.stop), neutral, bound) - return bound + return ranges[0] + if len(non_empty_ranges) == 1: + return non_empty_ranges[0] # the reduction of a single range is the range itself + + def guarded(start_expr: itir.Expr, stop_expr: itir.Expr, bound: itir.Expr, neutral: itir.Expr): + # an empty range contributes the reduction's `neutral` element instead of `bound` + return im.if_(im.greater_equal(start_expr, stop_expr), neutral, bound) + + def next_binding_index(groups: tuple[dict[str, itir.Expr], ...]) -> int: + # A fresh index above the highest existing one never collides with a symbol in scope. + indices = [ + int(name.removeprefix("__sd_start_")) + for group in groups + for name in group + if name.startswith("__sd_start_") + ] + return max(indices, default=-1) + 1 + + # Carry all inputs' binding groups as the outer (nested) `let`s; the bounds bound here go into + # one fresh innermost group, so `start` and `stop` share them (stored once) and the guards + # reference cheap symbols instead of duplicating the (chained, possibly large) sub-expressions -- + # keeping the result size linear. By contract we are the only allocator of `__sd_*` symbols, so + # equal names carry equal values and merging the (outermost-first aligned) groups is safe. + depth = max(len(range_.bindings) for range_ in non_empty_ranges) + outer_groups = tuple( + {name: value for range_ in non_empty_ranges if d < len(range_.bindings) + for name, value in range_.bindings[d].items()} + for d in range(depth) + ) - start = functools.reduce( - start_reduce_op, - [guarded_bound(r.start, r, neutral_start_reduce_val) for r in non_empty_ranges], - ) - stop = functools.reduce( - stop_reduce_op, - [guarded_bound(r.stop, r, neutral_stop_reduce_val) for r in non_empty_ranges], - ) + new_group: dict[str, itir.Expr] = {} + i = next_binding_index(outer_groups) + acc_start, acc_stop = neutral_reduce_val.start, neutral_reduce_val.stop + for range_ in non_empty_ranges: + if range_.empty() is None: + start_name, stop_name = f"__sd_start_{i}", f"__sd_stop_{i}" + i += 1 + # `range_.start`/`range_.stop` reference the range's own (outer) groups, which are in + # scope as this new group is the innermost one + new_group[start_name], new_group[stop_name] = range_.start, range_.stop + start_ref, stop_ref = im.ref(start_name), im.ref(stop_name) + r_start = guarded(start_ref, stop_ref, start_ref, neutral_reduce_val.start) + r_stop = guarded(start_ref, stop_ref, stop_ref, neutral_reduce_val.stop) + else: + r_start, r_stop = range_.start, range_.stop + acc_start = start_reduce_op(acc_start, r_start) + acc_stop = stop_reduce_op(acc_stop, r_stop) - # constant fold expression to keep the tree small - start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr - return SymbolicRange(start, stop) + groups = (*outer_groups, new_group) if new_group else outer_groups + # constant fold only the final result (binding values come from inputs that were already folded) + return SymbolicRange( + ConstantFolding.apply(acc_start), # type: ignore[arg-type] # always an itir.Expr + ConstantFolding.apply(acc_stop), # type: ignore[arg-type] # always an itir.Expr + groups, + ) _range_union = functools.partial( _reduce_ranges, start_reduce_op=im.minimum, - neutral_start_reduce_val=itir.InfinityLiteral.POSITIVE, stop_reduce_op=im.maximum, - neutral_stop_reduce_val=itir.InfinityLiteral.NEGATIVE, + # neutral element of union is the empty range `[+inf, -inf[` + neutral_reduce_val=SymbolicRange( + itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE + ), ) _range_intersection = functools.partial( _reduce_ranges, start_reduce_op=im.maximum, - neutral_start_reduce_val=itir.InfinityLiteral.NEGATIVE, stop_reduce_op=im.minimum, - neutral_stop_reduce_val=itir.InfinityLiteral.POSITIVE, + # neutral element of intersection is the universe range `]-inf, +inf[` + neutral_reduce_val=SymbolicRange( + itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE + ), ) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 561db992d8..c2a086dfae 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -291,7 +291,9 @@ def __init__(self, *args): "Invalid arguments: expected a variable name and an init form or a list thereof." ) - def __call__(self, form) -> itir.FunCall: + def __call__(self, form) -> itir.Expr: + if not self.vars: # no bindings: the `let` is a no-op + return ensure_expr(form) return call(lambda_(*self.vars)(form))(*self.init_forms) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 85b9176e09..2d0ad59fc7 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -157,6 +157,8 @@ def _transform_by_pattern( projector, expr = ir_utils_misc.extract_projector(stmt.expr) def wrapped_predicate(expr: itir.Expr, num_occurences: int) -> bool: + if not isinstance(expr, itir.Lambda): # TODO: e.g. test_tuple_different_domain + type_inference.reinfer(expr) assert isinstance(expr.type, ts.TypeSpec) or isinstance(expr, itir.Lambda) if isinstance(expr.type, ts.TypeSpec) and not type_info.is_type_or_tuple_of_type( expr.type, ts.FieldType diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py index 17bebccdf5..f51f30327b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_tests/test_domain_utils.py @@ -12,6 +12,7 @@ from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import domain_utils, ir_makers as im +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas from gt4py.next import common, constructors I = common.Dimension("I") @@ -53,13 +54,6 @@ def _guard(bound: str, neutral: itir.InfinityLiteral): ) -def test_symbolic_range(): - with pytest.raises(AssertionError): - domain_utils.SymbolicRange(itir.InfinityLiteral.POSITIVE, 0) - with pytest.raises(AssertionError): - domain_utils.SymbolicRange(0, itir.InfinityLiteral.NEGATIVE) - - def test_domain_op_preconditions(): domain_a = domain_utils.SymbolicDomain( grid_type=common.GridType.CARTESIAN, @@ -102,7 +96,10 @@ def test_domain_union(): for dim, dim_name in ((I, "I"), (J, "J")) }, ) - assert expected == domain_utils.domain_union(domain0, domain1, domain2) + # the result shares the guarded bounds via `let` bindings; inline them once to compare against + # the (binding-free) expected form + result = domain_utils.domain_union(domain0, domain1, domain2) + assert InlineLambdas.apply(result.as_expr()) == expected.as_expr() def test_domain_union_drops_empty_domains(): @@ -179,7 +176,52 @@ def test_domain_intersection(): for dim in (I, J) }, ) - assert expected == domain_utils.domain_intersection(domain0, domain1, domain2) + result = domain_utils.domain_intersection(domain0, domain1, domain2) + assert InlineLambdas.apply(result.as_expr()) == expected.as_expr() + + +def test_domain_union_then_intersection(): + # Cross-operation case: the (symbolic) union result is re-guarded when it flows into the + # intersection, so an empty union folds to the intersection's neutral (the universe) instead of + # surviving `max`/`min` and wrongly constraining the result. + union_result = domain_utils.domain_union(_make_domain(0), _make_domain(1)) + result = domain_utils.domain_intersection(union_result, _make_domain(2)) + + pos, neg = itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE + + def union_start(d: str): + return im.minimum(_guard(f"start_{d}_0", pos), _guard(f"start_{d}_1", pos)) + + def union_stop(d: str): + return im.maximum(_guard(f"end_{d}_0", neg), _guard(f"end_{d}_1", neg)) + + expected = domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, + ranges={ + dim: domain_utils.SymbolicRange( + im.maximum( + # union result re-guarded with the intersection's start neutral + im.if_( + im.greater_equal(union_start(dim_name), union_stop(dim_name)), + neg, + union_start(dim_name), + ), + _guard(f"start_{dim_name}_2", neg), + ), + im.minimum( + # union result re-guarded with the intersection's stop neutral + im.if_( + im.greater_equal(union_start(dim_name), union_stop(dim_name)), + pos, + union_stop(dim_name), + ), + _guard(f"end_{dim_name}_2", pos), + ), + ) + for dim, dim_name in ((I, "I"), (J, "J")) + }, + ) + assert InlineLambdas.apply(result.as_expr()) == expected.as_expr() @pytest.mark.parametrize(