Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions Fix_concat_where_start_stop_invariant.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# `SymbolicRange` invariant: `start ≠ +inf`, `stop ≠ -inf`

`SymbolicRange.__post_init__` (in
`src/gt4py/next/iterator/ir_utils/domain_utils.py`) asserts:

```python
assert self.start is not itir.InfinityLiteral.POSITIVE
assert self.stop is not itir.InfinityLiteral.NEGATIVE
```

## What the invariant means

Infinities are only allowed **outward**. Inward infinities are forbidden:

| form | allowed? | meaning |
|---|---|---|
| `[a, b)`, `[-inf, b)`, `[a, +inf)`, `[-inf, +inf)` | ✅ | finite / left-/right-/both-unbounded |
| `[+inf, b)`, `[a, -inf)`, `[+inf, -inf)` | ❌ | degenerate "empty from infinity" |

The union-neutral range `[+inf, -inf)` is exactly the forbidden inward form.
Empty domains are conventionally represented elsewhere with **finite** bounds
(`start >= stop`, e.g. `[10, 10)` / `[0, 0)`), never with infinities — that is
the assumption the check enforces.

## Where the assumption is relied upon

### Silent-wrong (no crash)

- **`empty()`** (`domain_utils.py`) — now classifies infinities explicitly: an
"inward" infinity (`start is POSITIVE` or `stop is NEGATIVE`) → `True`
(degenerate empty), an "outward" infinity (`start is NEGATIVE` or
`stop is POSITIVE`) → `False` (always non-empty), otherwise the literal /
equality checks as before. So the neutral `[+inf, -inf)` *is* detected as
empty, and half-infinite ranges are no longer treated as
statically-undecidable (which previously left `let`/`if` guard cruft in
inferred domains).
- **`is_finite`** — reports `[+inf, -inf)` as "not finite", but it is really
*empty*, not *unbounded*; callers gating on finiteness would mis-handle it.

### Assertion crashes (encode the invariant; in practice only hit on half-infinite condition-domains)

- **`domain_complement`** (`domain_utils.py`) —
`assert (lb == NEGATIVE) != (ub == POSITIVE)`. For `[+inf, -inf)` both sides
are `False` → `False != False` fails.
- **concat_where `_range_complement`** (canonicalize domain argument) —
`assert not any(isinstance(b, itir.InfinityLiteral) for b in (start, stop))`
requires finite bounds outright.

### Hard breakage if an inward-infinity domain reaches lowering / execution

This is the important class: a real (materialized) domain must never carry
`[+inf, -inf)`.

- **Size computed as `stop - start`**:
- dace `gtir_domain.py:143` — `max(0, stop - start)`
- dace `gtir_to_sdfg_scan.py:380` — `scan_domain.stop - scan_domain.start`
- `Domain.size` / `shape` (`common.py:491`)

With `stop = -inf`, `start = +inf` this is `-oo - (+oo)` → fragile/garbage
sympy, and `origin = +inf` is nonsensical.
- **`assert ...is_finite(...)` on the materialized domain**: embedded
`nd_array_field.py:189,1187` / `embedded/common.py:56`, dace
`sdfg_callable.py:22` → assertion failure at runtime.
- **Codegen of bounds**: gtfn / dace / roundtrip now emit `InfinityLiteral`
(`std::numeric_limits<...>::min()/max()`, `±sympy.oo`, `common.Infinity`).
Fine for an *outward* infinity in a genuine concat_where domain, but an
inward-infinity empty domain (`origin = +inf`, `size = -inf - +inf`) generates
nonsensical C++ / SDFG.

## Bottom line

`ConstantFolding` itself is largely tolerant of `[+inf, -inf)`: the `plus`-only
`assert not both-infinity` never sees two infinities; `minimum` / `maximum`
fold it correctly; `greater_equal(+inf, -inf)` folds to `True`.

The blockers to allowing the inward/neutral form are therefore:
1. `empty()` won't detect it (silent),
2. the two complement assertions (crash, but not on neutral ranges in practice),
3. and — critically — it must **never escape to lowering/execution**, where
`stop - start` and the `is_finite` asserts would break.

Today this cannot happen because the all-empty reduction returns a *finite*
empty range. A neutral-seed reduction whose inputs are all empty would instead
yield `[+inf, -inf)` and could flow downstream, so that path would need to
re-normalize all-empty results back to a finite empty range (or guarantee they
are dropped before codegen) before the invariant could be relaxed.
5 changes: 5 additions & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ class NoneLiteral(Expr):


class InfinityLiteral(Expr):
"""
Infinity value appearing in concat_where domain selectors and domain bounds (to guard domain
operators on empty ranges).
"""

# TODO(tehrengruber): self referential `ClassVar` not supported in eve.
if TYPE_CHECKING:
POSITIVE: ClassVar[InfinityLiteral]
Expand Down
148 changes: 130 additions & 18 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,59 @@
class SymbolicRange:
start: itir.Expr
stop: itir.Expr
#: Groups of `let` bindings that are in scope for both `start` and `stop`. Each group is a
#: simultaneous `let` (its bindings are mutually independent); the groups are nested, with the
#: first tuple element the outermost `let`, so a later group's values may reference an earlier
#: group's symbols. They let chained reductions (see `_reduce_ranges`) reference previously
#: computed bounds by a unique symbol (`__sd_start_0`, `__sd_start_1`, ...) instead of
#: duplicating them, which would otherwise blow up the expression size exponentially. Sharing one
#: group between `start` and `stop` keeps each bound stored exactly once. Materialized by
#: `as_expr`; the unique names make the nesting unambiguous (no shadowing).
bindings: tuple[dict[str, itir.Expr], ...] = ()

# See: Fix_concat_where_start_stop_invariant.md
#def __post_init__(self) -> None:
# # TODO(havogt): added this defensive checks as code seems to make this reasonable assumption
# assert self.start is not itir.InfinityLiteral.POSITIVE
# assert self.stop is not itir.InfinityLiteral.NEGATIVE

def __post_init__(self) -> None:
# TODO(havogt): added this defensive checks as code seems to make this reasonable assumption
assert self.start is not itir.InfinityLiteral.POSITIVE
assert self.stop is not itir.InfinityLiteral.NEGATIVE
def __hash__(self) -> int:
# `bindings` holds (mutable, unhashable) dicts; hash their items instead so `SymbolicRange`
# stays hashable (it is hashed e.g. via `frozenset` in `SymbolicDomain.__hash__`).
return hash((self.start, self.stop, tuple(tuple(g.items()) for g in self.bindings)))

def translate(self, distance: int) -> SymbolicRange:
return SymbolicRange(im.plus(self.start, distance), im.plus(self.stop, distance))
# constant fold so that translated literal bounds stay literal (otherwise `empty()` would
# treat e.g. `0 + 1` as symbolic and `_reduce_ranges` would needlessly guard them)
return SymbolicRange(
ConstantFolding.apply(im.plus(self.start, distance)), # type: ignore[arg-type] # always an itir.Expr
ConstantFolding.apply(im.plus(self.stop, distance)), # type: ignore[arg-type] # always an itir.Expr
self.bindings,
)

def empty(self) -> bool | None:
# an "inward" infinity (`start == +inf` or `stop == -inf`) is the degenerate empty range
if self.start is itir.InfinityLiteral.POSITIVE or self.stop is itir.InfinityLiteral.NEGATIVE:
return True
# an "outward" infinity (`start == -inf` or `stop == +inf`) is always non-empty as the
# opposite bound is finite (or the opposite outward infinity)
if self.start is itir.InfinityLiteral.NEGATIVE or self.stop is itir.InfinityLiteral.POSITIVE:
return False
if isinstance(self.start, itir.Literal) and isinstance(self.stop, itir.Literal):
start, stop = int(self.start.value), int(self.stop.value)
return start >= stop
elif self.start == self.stop:
return True
return None

def as_expr(self) -> tuple[itir.Expr, itir.Expr]:
"""Materialize `start` and `stop`, wrapping the shared `bindings` groups as nested `let`s."""
start, stop = self.start, self.stop
# groups are outermost-first; wrap the innermost (last) group first
for group in reversed(self.bindings):
start, stop = im.let(*group.items())(start), im.let(*group.items())(stop)
return start, stop


_GRID_TYPE_MAPPING = {
"unstructured_domain": common.GridType.UNSTRUCTURED,
Expand Down Expand Up @@ -81,6 +117,12 @@ def _unstructured_translate_range_statically(
assert isinstance(start_expr, itir.Literal) and isinstance(stop_expr, itir.Literal)
start, stop = int(start_expr.value), int(stop_expr.value)

if range_.empty():
return SymbolicRange(
im.literal(str("0"), builtins.INTEGER_INDEX_BUILTIN),
im.literal(str("0"), builtins.INTEGER_INDEX_BUILTIN),
)

nb_index: slice | int
if val in [trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE]:
nb_index = slice(None)
Expand Down Expand Up @@ -156,7 +198,7 @@ def from_expr(cls, node: itir.Node) -> SymbolicDomain:

def as_expr(self) -> itir.FunCall:
converted_ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] = {
key: (value.start, value.stop) for key, value in self.ranges.items()
key: value.as_expr() for key, value in self.ranges.items()
}
return im.domain(self.grid_type, converted_ranges)

Expand Down Expand Up @@ -235,26 +277,96 @@ def _reduce_ranges(
*ranges: SymbolicRange,
start_reduce_op: Callable[[itir.Expr, itir.Expr], itir.Expr],
stop_reduce_op: Callable[[itir.Expr, itir.Expr], itir.Expr],
neutral_reduce_val: SymbolicRange,
) -> SymbolicRange:
"""Uses start_op and stop_op to fold the start and stop of a list of ranges."""
start = functools.reduce(
lambda current_expr, el_expr: start_reduce_op(current_expr, el_expr),
[range_.start for range_ in ranges],
"""
Fold the start and stop of a list of ranges with `start_reduce_op` / `stop_reduce_op`.

The reduction is seeded with `neutral_reduce_val`, the operation's identity range (the empty
range for union, the universe range for intersection); an empty input range therefore folds to
that same neutral and leaves the result unchanged.

This function only computes the correct value if the ranges are either overlapping / adjacent
or empty as calculation is by means of the convex hull (and some special handling for empty
ranges).
"""
# symbolic ranges, i.e., `empty()` is `None` must not be dropped; they are guarded below
non_empty_ranges = [range_ for range_ in ranges if range_.empty() is not True]
if len(non_empty_ranges) == 0:
return ranges[0]
if len(non_empty_ranges) == 1:
return non_empty_ranges[0] # the reduction of a single range is the range itself

def guarded(start_expr: itir.Expr, stop_expr: itir.Expr, bound: itir.Expr, neutral: itir.Expr):
# an empty range contributes the reduction's `neutral` element instead of `bound`
return im.if_(im.greater_equal(start_expr, stop_expr), neutral, bound)

def next_binding_index(groups: tuple[dict[str, itir.Expr], ...]) -> int:
# A fresh index above the highest existing one never collides with a symbol in scope.
indices = [
int(name.removeprefix("__sd_start_"))
for group in groups
for name in group
if name.startswith("__sd_start_")
]
return max(indices, default=-1) + 1

# Carry all inputs' binding groups as the outer (nested) `let`s; the bounds bound here go into
# one fresh innermost group, so `start` and `stop` share them (stored once) and the guards
# reference cheap symbols instead of duplicating the (chained, possibly large) sub-expressions --
# keeping the result size linear. By contract we are the only allocator of `__sd_*` symbols, so
# equal names carry equal values and merging the (outermost-first aligned) groups is safe.
depth = max(len(range_.bindings) for range_ in non_empty_ranges)
outer_groups = tuple(
{name: value for range_ in non_empty_ranges if d < len(range_.bindings)
for name, value in range_.bindings[d].items()}
for d in range(depth)
)
stop = functools.reduce(
lambda current_expr, el_expr: stop_reduce_op(current_expr, el_expr),
[range_.stop for range_ in ranges],

new_group: dict[str, itir.Expr] = {}
i = next_binding_index(outer_groups)
acc_start, acc_stop = neutral_reduce_val.start, neutral_reduce_val.stop
for range_ in non_empty_ranges:
if range_.empty() is None:
start_name, stop_name = f"__sd_start_{i}", f"__sd_stop_{i}"
i += 1
# `range_.start`/`range_.stop` reference the range's own (outer) groups, which are in
# scope as this new group is the innermost one
new_group[start_name], new_group[stop_name] = range_.start, range_.stop
start_ref, stop_ref = im.ref(start_name), im.ref(stop_name)
r_start = guarded(start_ref, stop_ref, start_ref, neutral_reduce_val.start)
r_stop = guarded(start_ref, stop_ref, stop_ref, neutral_reduce_val.stop)
else:
r_start, r_stop = range_.start, range_.stop
acc_start = start_reduce_op(acc_start, r_start)
acc_stop = stop_reduce_op(acc_stop, r_stop)

groups = (*outer_groups, new_group) if new_group else outer_groups
# constant fold only the final result (binding values come from inputs that were already folded)
return SymbolicRange(
ConstantFolding.apply(acc_start), # type: ignore[arg-type] # always an itir.Expr
ConstantFolding.apply(acc_stop), # type: ignore[arg-type] # always an itir.Expr
groups,
)
# constant fold expression to keep the tree small
start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr
return SymbolicRange(start, stop)


_range_union = functools.partial(
_reduce_ranges, start_reduce_op=im.minimum, stop_reduce_op=im.maximum
_reduce_ranges,
start_reduce_op=im.minimum,
stop_reduce_op=im.maximum,
# neutral element of union is the empty range `[+inf, -inf[`
neutral_reduce_val=SymbolicRange(
itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE
),
)
_range_intersection = functools.partial(
_reduce_ranges, start_reduce_op=im.maximum, stop_reduce_op=im.minimum
_reduce_ranges,
start_reduce_op=im.maximum,
stop_reduce_op=im.minimum,
# neutral element of intersection is the universe range `]-inf, +inf[`
neutral_reduce_val=SymbolicRange(
itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE
),
)


Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,9 @@ def __init__(self, *args):
"Invalid arguments: expected a variable name and an init form or a list thereof."
)

def __call__(self, form) -> itir.FunCall:
def __call__(self, form) -> itir.Expr:
if not self.vars: # no bindings: the `let` is a no-op
return ensure_expr(form)
return call(lambda_(*self.vars)(form))(*self.init_forms)


Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ def make_node(o):
return o
if isinstance(o, common.Dimension):
return AxisLiteral(value=o.value, kind=o.kind)
if isinstance(o, common.Infinity):
if o is common.Infinity.POSITIVE:
return itir.InfinityLiteral.POSITIVE
else:
assert o is common.Infinity.NEGATIVE
return itir.InfinityLiteral.NEGATIVE
if isinstance(o, common.CartesianConnectivity):
# TODO(havogt): `itir.CartesianOffset` cannot represent `offset != 0` (embedded honors
# it, see `execute_shift`); decide whether to fold it into the shift value or forbid it.
Expand Down
13 changes: 12 additions & 1 deletion src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,20 @@ def _transform_by_pattern(
# hide projector from extraction
projector, expr = ir_utils_misc.extract_projector(stmt.expr)

def wrapped_predicate(expr: itir.Expr, num_occurences: int) -> bool:
if not isinstance(expr, itir.Lambda): # TODO: e.g. test_tuple_different_domain
type_inference.reinfer(expr)
assert isinstance(expr.type, ts.TypeSpec) or isinstance(expr, itir.Lambda)
if isinstance(expr.type, ts.TypeSpec) and not type_info.is_type_or_tuple_of_type(
expr.type, ts.FieldType
):
return False

return predicate(expr, num_occurences)

new_expr, extracted_fields, _ = cse.extract_subexpression(
expr,
predicate=predicate,
predicate=wrapped_predicate,
prefixed_uids=uids["__tmp_subexpr"],
# TODO(tehrengruber): extracting the deepest expression first would allow us to fuse
# the extracted expressions resulting in fewer kernel calls & better data-locality.
Expand Down
5 changes: 1 addition & 4 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,7 @@ def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType:
return node.type

def visit_InfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType:
return ts.ScalarType(kind=ts.ScalarKind.INT32)

def visit_NegInfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType:
return ts.ScalarType(kind=ts.ScalarKind.INT32)
return ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()))

def visit_SymRef(
self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec]
Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/program_processors/codegens/gtfn/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from gt4py.eve import codegen
from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako
from gt4py.next import common
from gt4py.next.iterator import builtins as itir_builtins
from gt4py.next.otf import cpp_utils
from gt4py.next.program_processors.codegens.gtfn import gtfn_im_ir, gtfn_ir, gtfn_ir_common

Expand Down Expand Up @@ -92,6 +93,15 @@ def visit_SymRef(self, node: gtfn_ir_common.SymRef, **kwargs: Any) -> str:

return node.id

def visit_InfinityLiteral(self, node: gtfn_ir.InfinityLiteral, **kwargs: Any) -> str:
# `±∞` only appears in (integer) domain bounds; use the limit of the index type.
cpptype = cpp_utils.pytype_to_cpptype(itir_builtins.INTEGER_INDEX_BUILTIN)
if node == gtfn_ir.InfinityLiteral.POSITIVE:
return f"std::numeric_limits<{cpptype}>::max()"
else:
assert node == gtfn_ir.InfinityLiteral.NEGATIVE
return f"std::numeric_limits<{cpptype}>::min()"

def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str:
# TODO(tehrengruber): isn't this wrong and int32 should be casted to an actual int32?
match node.type:
Expand Down
Loading
Loading