Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ def __str__(self) -> str:
def __call__(self, val: int) -> NamedIndex:
return NamedIndex(self, val)

def __add__(self, offset: int) -> Connectivity:
return CartesianConnectivity(self, offset)
def __add__(self, offset: int | float) -> Connectivity:
return connectivity_for_cartesian_shift(self, offset)

def __sub__(self, offset: int) -> Connectivity:
def __sub__(self, offset: int | float) -> Connectivity:
return self + (-offset)

def __gt__(self, value: core_defs.IntegralScalar) -> Domain:
Expand Down Expand Up @@ -1336,7 +1336,13 @@ def order_dimensions(dims: Iterable[Dimension]) -> list[Dimension]:
"""Find the canonical ordering of the dimensions in `dims`."""
if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1:
raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.")
return sorted(dims, key=lambda dim: (_DIM_KIND_ORDER[dim.kind], dim.value))
return sorted(
dims,
key=lambda dim: (
_DIM_KIND_ORDER[dim.kind],
as_non_staggered(dim).value,
),
)


def check_dims(dims: Sequence[Dimension]) -> None:
Expand Down Expand Up @@ -1424,3 +1430,32 @@ def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Call
#: Equivalent to the `_FillValue` attribute in the UGRID Conventions
#: (see: http://ugrid-conventions.github.io/ugrid-conventions/).
_DEFAULT_SKIP_VALUE: Final[int] = -1
_STAGGERED_PREFIX = "_Staggered"


def is_staggered(dim: Dimension) -> bool:
return dim.value.startswith(_STAGGERED_PREFIX)


def flip_staggered(dim: Dimension) -> Dimension:
if is_staggered(dim):
return Dimension(dim.value[len(_STAGGERED_PREFIX) :], dim.kind)
else:
return Dimension(f"{_STAGGERED_PREFIX}{dim.value}", dim.kind)


def as_non_staggered(dim: Dimension) -> Dimension:
Comment thread
tehrengruber marked this conversation as resolved.
if is_staggered(dim):
return flip_staggered(dim)
return dim


def connectivity_for_cartesian_shift(dim: Dimension, offset: int | float) -> CartesianConnectivity:
if isinstance(offset, float):
integral_offset, half = divmod(offset, 1)
assert half == 0.5
if not is_staggered(dim):
integral_offset += 1
return CartesianConnectivity(dim, int(integral_offset), codomain=flip_staggered(dim))
else:
return CartesianConnectivity(dim, offset, codomain=dim)
4 changes: 1 addition & 3 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,9 +478,7 @@ def __getitem__(self, offset: int) -> common.Connectivity:
offset_definition = common.get_offset(current_offset_provider, self.value)

connectivity: common.Connectivity
if isinstance(offset_definition, common.Dimension):
connectivity = common.CartesianConnectivity(offset_definition, offset)
elif isinstance(offset_definition, common.Connectivity):
if isinstance(offset_definition, common.Connectivity):
assert common.is_neighbor_table(offset_definition)
named_index = common.NamedIndex(self.target[-1], offset)
connectivity = offset_definition[named_index]
Expand Down
25 changes: 17 additions & 8 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import gt4py.next.ffront.field_operator_ast as foast
from gt4py import eve
from gt4py.eve import NodeTranslator, NodeVisitor, traits
from gt4py.next import errors
from gt4py.next import common, errors
from gt4py.next.common import Dimension, DimensionKind, promote_dims
from gt4py.next.ffront import (
dialect_ast_enums,
Expand Down Expand Up @@ -607,13 +607,6 @@ def _deduce_compare_type(
def _deduce_binop_type(
self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
) -> Optional[ts.TypeSpec]:
# e.g. `IDim+1`
if (
isinstance(left.type, ts.DimensionType)
and isinstance(right.type, ts.ScalarType)
and type_info.is_integral(right.type)
):
return ts.OffsetType(source=left.type.dim, target=(left.type.dim,))
if isinstance(left.type, ts.OffsetType):
raise errors.DSLError(
node.location, f"Type '{left.type}' can not be used in operator '{node.op}'."
Expand Down Expand Up @@ -679,6 +672,22 @@ def _deduce_binop_type(
f"must be one of {', '.join((str(op) for op in logical_ops))}.",
)
return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims))
elif (
node.op in (dialect_ast_enums.BinaryOperator.ADD, dialect_ast_enums.BinaryOperator.SUB)
and isinstance(left.type, ts.DimensionType)
and isinstance(right.type, ts.ScalarType)
and type_info.is_arithmetic(right.type)
):
# e.g. `IDim+1` or `IDim+0.5`
if not isinstance(right, foast.Constant):
raise NotImplementedError(
"Cartesian offsets are only supported with literal rhs, e.g. `IDim + 1`, but not `IDim + expr`."
)
offset_index = right.value
if node.op == dialect_ast_enums.BinaryOperator.SUB:
offset_index *= -1
conn = common.connectivity_for_cartesian_shift(left.type.dim, offset_index)
return ts.OffsetType(source=conn.codomain, target=(conn.domain_dim,))
else:
raise errors.DSLError(node.location, err_msg)

Expand Down
16 changes: 10 additions & 6 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from gt4py import eve
from gt4py.eve.extended_typing import Never, cast
from gt4py.next import utils
from gt4py.next import common, utils
from gt4py.next.ffront import (
dialect_ast_enums,
experimental as experimental_builtins,
Expand Down Expand Up @@ -318,19 +318,23 @@ def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
current_expr = im.as_fieldop(
im.lambda_("__it")(im.deref(im.shift(shift_offset, new_index)("__it")))
)(current_expr)
# `field(Dim + idx)`
# `field(Dim + idx)` (and the staggered/relocation `field(Dim ± 0.5)`)
case foast.BinOp(
op=dialect_ast_enums.BinaryOperator.ADD | dialect_ast_enums.BinaryOperator.SUB,
left=foast.Name() as dim_name,
left=foast.Name(), # TODO(tehrengruber): use type instead
right=foast.Constant(value=offset_index),
):
if arg.op == dialect_ast_enums.BinaryOperator.SUB:
offset_index *= -1
assert isinstance(dim_name.type, ts.DimensionType)
dim = dim_name.type.dim
assert isinstance(arg.left.type, ts.DimensionType)
conn = common.connectivity_for_cartesian_shift(arg.left.type.dim, offset_index)
current_expr = im.as_fieldop(
im.lambda_("__it")(
im.deref(im.shift(im.cartesian_offset(dim), offset_index)("__it"))
im.deref(
im.shift(
im.cartesian_offset(conn.domain_dim, conn.codomain), conn.offset
)("__it")
)
)
)(current_expr)
# `field(Off)`
Expand Down
11 changes: 3 additions & 8 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,16 +587,11 @@ def execute_shift(
# the assertions above confirm pos is incomplete casting here to avoid duplicating work in a type guard
return cast(IncompletePosition, pos) | {tag: new_entry}

# a `CartesianConnectivity` tag is a self-describing cartesian shift (e.g. from a
# `CartesianOffset` IR node); named offsets are resolved through the offset provider
if isinstance(tag, common.CartesianConnectivity):
assert tag.domain_dim == tag.codomain # relocation (staggering) is not supported here
new_pos = copy.copy(pos)
key = tag.domain_dim.value
if common.is_int_index(value := new_pos[key]):
new_pos[key] = value + index + tag.offset
else:
raise AssertionError()
value = new_pos.pop(tag.domain_dim.value)
assert common.is_int_index(value)
new_pos[tag.codomain.value] = value + index + tag.offset
return new_pos
offset_implementation = common.get_offset(offset_provider, tag)
if common.is_neighbor_table(offset_implementation):
Expand Down
1 change: 1 addition & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __str__(self):
InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE")


# TODO(tehrengruber): allow int only and create OffsetRef for str instead
class OffsetLiteral(Expr):
value: Union[int, str]

Expand Down
48 changes: 34 additions & 14 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from gt4py.next import common
from gt4py.next.iterator import builtins, ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im, misc
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.transforms import trace_shifts
from gt4py.next.iterator.transforms.constant_folding import ConstantFolding

Expand Down Expand Up @@ -173,36 +173,54 @@ def translate(
#: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details.
symbolic_domain_sizes: Optional[dict[str, itir.Expr]] = None,
) -> SymbolicDomain:
offset_provider_type = common.offset_provider_to_type(offset_provider)

dims = list(self.ranges.keys())
new_ranges = {dim: self.ranges[dim] for dim in dims}
if len(shift) == 0:
return self
if len(shift) == 2:
off, val = shift

connectivity: common.Connectivity
if isinstance(off, itir.CartesianOffset):
domain = common.Dimension(value=off.domain.value, kind=off.domain.kind)
codomain = common.Dimension(value=off.codomain.value, kind=off.codomain.kind)
connectivity = common.CartesianConnectivity(domain, codomain=codomain)
elif isinstance(off, itir.OffsetLiteral):
assert isinstance(off.value, str)
# `get_offset` accepts both `OffsetProvider` and `OffsetProviderType`, but is typed
# for the former only (the two overlap, so an overload is not possible).
connectivity = common.get_offset(offset_provider, off.value) # type: ignore[arg-type]
else:
raise AssertionError()

if isinstance(connectivity, common.CartesianConnectivity):
if val is trace_shifts.Sentinel.VALUE:
raise NotImplementedError("Dynamic offsets not supported.")
assert isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int)
dom = misc.dim_from_axis_literal(off.domain)
cod = misc.dim_from_axis_literal(off.codomain)
assert dom == cod # relocation (staggering) is not supported here
new_ranges[dom] = SymbolicRange.translate(self.ranges[dom], val.value)
return SymbolicDomain(self.grid_type, new_ranges)
assert isinstance(off, itir.OffsetLiteral) and isinstance(off.value, str)
connectivity_type = common.get_offset_type(offset_provider_type, off.value)

if isinstance(connectivity_type, common.NeighborConnectivityType):
assert len(connectivity.domain.dims) == 1

old_dim = connectivity.domain_dim
new_dim = connectivity.codomain

assert new_dim not in new_ranges or old_dim == new_dim

new_range = SymbolicRange.translate(
self.ranges[old_dim], connectivity.offset + val.value
)
new_ranges = dict(
(dim, range_) if dim != old_dim else (new_dim, new_range)
for dim, range_ in new_ranges.items()
)
elif common.is_neighbor_table(connectivity):
# unstructured shift
assert (
isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int)
) or val in [
trace_shifts.Sentinel.ALL_NEIGHBORS,
trace_shifts.Sentinel.VALUE,
]
old_dim = connectivity_type.source_dim
new_dim = connectivity_type.codomain
old_dim = connectivity.domain.dims[0]
new_dim = connectivity.codomain
assert new_dim not in new_ranges or old_dim == new_dim
if symbolic_domain_sizes is not None and new_dim.value in symbolic_domain_sizes:
new_range = SymbolicRange(
Expand All @@ -212,6 +230,7 @@ def translate(
else:
assert common.is_offset_provider(offset_provider)
assert not isinstance(val, itir.CartesianOffset) # offset value, never a node
assert isinstance(off, itir.OffsetLiteral) and isinstance(off.value, str)
new_range = _unstructured_translate_range_statically(
new_ranges[old_dim], off.value, val, offset_provider, self.as_expr()
)
Expand All @@ -222,6 +241,7 @@ def translate(
)
else:
raise AssertionError()

return SymbolicDomain(self.grid_type, new_ranges)
elif len(shift) > 2:
return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate(
Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/iterator/ir_utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,16 @@ def dim_from_axis_literal(axis_literal: itir.AxisLiteral) -> common.Dimension:
return common.Dimension(value=axis_literal.value, kind=axis_literal.kind)


def connectivity_from_cartesian_offset(
cart_offset: itir.CartesianOffset,
) -> common.CartesianConnectivity:
return common.CartesianConnectivity(
domain_dim=dim_from_axis_literal(cart_offset.domain),
codomain=dim_from_axis_literal(cart_offset.codomain),
offset=0,
)


def _flatten_tuple_expr(expr: itir.Expr) -> tuple[itir.Expr]:
if cpm.is_call_to(expr, "make_tuple"):
return sum(
Expand Down
Loading
Loading