diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index d71f0e7f4f..84a753eb03 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -86,10 +86,10 @@ def __str__(self) -> str: def __call__(self, val: int) -> NamedIndex: return NamedIndex(self, val) - def __add__(self, offset: int) -> Connectivity: - return CartesianConnectivity(self, offset) + def __add__(self, offset: int | float) -> Connectivity: + return connectivity_for_cartesian_shift(self, offset) - def __sub__(self, offset: int) -> Connectivity: + def __sub__(self, offset: int | float) -> Connectivity: return self + (-offset) def __gt__(self, value: core_defs.IntegralScalar) -> Domain: @@ -1336,7 +1336,13 @@ def order_dimensions(dims: Iterable[Dimension]) -> list[Dimension]: """Find the canonical ordering of the dimensions in `dims`.""" if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1: raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.") - return sorted(dims, key=lambda dim: (_DIM_KIND_ORDER[dim.kind], dim.value)) + return sorted( + dims, + key=lambda dim: ( + _DIM_KIND_ORDER[dim.kind], + as_non_staggered(dim).value, + ), + ) def check_dims(dims: Sequence[Dimension]) -> None: @@ -1424,3 +1430,32 @@ def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Call #: Equivalent to the `_FillValue` attribute in the UGRID Conventions #: (see: http://ugrid-conventions.github.io/ugrid-conventions/). _DEFAULT_SKIP_VALUE: Final[int] = -1 +_STAGGERED_PREFIX = "_Staggered" + + +def is_staggered(dim: Dimension) -> bool: + return dim.value.startswith(_STAGGERED_PREFIX) + + +def flip_staggered(dim: Dimension) -> Dimension: + if is_staggered(dim): + return Dimension(dim.value[len(_STAGGERED_PREFIX) :], dim.kind) + else: + return Dimension(f"{_STAGGERED_PREFIX}{dim.value}", dim.kind) + + +def as_non_staggered(dim: Dimension) -> Dimension: + if is_staggered(dim): + return flip_staggered(dim) + return dim + + +def connectivity_for_cartesian_shift(dim: Dimension, offset: int | float) -> CartesianConnectivity: + if isinstance(offset, float): + integral_offset, half = divmod(offset, 1) + assert half == 0.5 + if not is_staggered(dim): + integral_offset += 1 + return CartesianConnectivity(dim, int(integral_offset), codomain=flip_staggered(dim)) + else: + return CartesianConnectivity(dim, offset, codomain=dim) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 8531ccacf9..c6df764eff 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -478,9 +478,7 @@ def __getitem__(self, offset: int) -> common.Connectivity: offset_definition = common.get_offset(current_offset_provider, self.value) connectivity: common.Connectivity - if isinstance(offset_definition, common.Dimension): - connectivity = common.CartesianConnectivity(offset_definition, offset) - elif isinstance(offset_definition, common.Connectivity): + if isinstance(offset_definition, common.Connectivity): assert common.is_neighbor_table(offset_definition) named_index = common.NamedIndex(self.target[-1], offset) connectivity = offset_definition[named_index] diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 1b090489ff..a2f59eaed1 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -12,7 +12,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py import eve from gt4py.eve import NodeTranslator, NodeVisitor, traits -from gt4py.next import errors +from gt4py.next import common, errors from gt4py.next.common import Dimension, DimensionKind, promote_dims from gt4py.next.ffront import ( dialect_ast_enums, @@ -607,13 +607,6 @@ def _deduce_compare_type( def _deduce_binop_type( self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: - # e.g. `IDim+1` - if ( - isinstance(left.type, ts.DimensionType) - and isinstance(right.type, ts.ScalarType) - and type_info.is_integral(right.type) - ): - return ts.OffsetType(source=left.type.dim, target=(left.type.dim,)) if isinstance(left.type, ts.OffsetType): raise errors.DSLError( node.location, f"Type '{left.type}' can not be used in operator '{node.op}'." @@ -679,6 +672,22 @@ def _deduce_binop_type( f"must be one of {', '.join((str(op) for op in logical_ops))}.", ) return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims)) + elif ( + node.op in (dialect_ast_enums.BinaryOperator.ADD, dialect_ast_enums.BinaryOperator.SUB) + and isinstance(left.type, ts.DimensionType) + and isinstance(right.type, ts.ScalarType) + and type_info.is_arithmetic(right.type) + ): + # e.g. `IDim+1` or `IDim+0.5` + if not isinstance(right, foast.Constant): + raise NotImplementedError( + "Cartesian offsets are only supported with literal rhs, e.g. `IDim + 1`, but not `IDim + expr`." + ) + offset_index = right.value + if node.op == dialect_ast_enums.BinaryOperator.SUB: + offset_index *= -1 + conn = common.connectivity_for_cartesian_shift(left.type.dim, offset_index) + return ts.OffsetType(source=conn.codomain, target=(conn.domain_dim,)) else: raise errors.DSLError(node.location, err_msg) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 86dd7fb2db..8cd82b3b41 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -13,7 +13,7 @@ from gt4py import eve from gt4py.eve.extended_typing import Never, cast -from gt4py.next import utils +from gt4py.next import common, utils from gt4py.next.ffront import ( dialect_ast_enums, experimental as experimental_builtins, @@ -318,19 +318,23 @@ def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: current_expr = im.as_fieldop( im.lambda_("__it")(im.deref(im.shift(shift_offset, new_index)("__it"))) )(current_expr) - # `field(Dim + idx)` + # `field(Dim + idx)` (and the staggered/relocation `field(Dim ± 0.5)`) case foast.BinOp( op=dialect_ast_enums.BinaryOperator.ADD | dialect_ast_enums.BinaryOperator.SUB, - left=foast.Name() as dim_name, + left=foast.Name(), # TODO(tehrengruber): use type instead right=foast.Constant(value=offset_index), ): if arg.op == dialect_ast_enums.BinaryOperator.SUB: offset_index *= -1 - assert isinstance(dim_name.type, ts.DimensionType) - dim = dim_name.type.dim + assert isinstance(arg.left.type, ts.DimensionType) + conn = common.connectivity_for_cartesian_shift(arg.left.type.dim, offset_index) current_expr = im.as_fieldop( im.lambda_("__it")( - im.deref(im.shift(im.cartesian_offset(dim), offset_index)("__it")) + im.deref( + im.shift( + im.cartesian_offset(conn.domain_dim, conn.codomain), conn.offset + )("__it") + ) ) )(current_expr) # `field(Off)` diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index e01ecc181e..76d25c1770 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -587,16 +587,11 @@ def execute_shift( # the assertions above confirm pos is incomplete casting here to avoid duplicating work in a type guard return cast(IncompletePosition, pos) | {tag: new_entry} - # a `CartesianConnectivity` tag is a self-describing cartesian shift (e.g. from a - # `CartesianOffset` IR node); named offsets are resolved through the offset provider if isinstance(tag, common.CartesianConnectivity): - assert tag.domain_dim == tag.codomain # relocation (staggering) is not supported here new_pos = copy.copy(pos) - key = tag.domain_dim.value - if common.is_int_index(value := new_pos[key]): - new_pos[key] = value + index + tag.offset - else: - raise AssertionError() + value = new_pos.pop(tag.domain_dim.value) + assert common.is_int_index(value) + new_pos[tag.codomain.value] = value + index + tag.offset return new_pos offset_implementation = common.get_offset(offset_provider, tag) if common.is_neighbor_table(offset_implementation): diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 6321a2487c..3f8f40ccec 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -98,6 +98,7 @@ def __str__(self): InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE") +# TODO(tehrengruber): allow int only and create OffsetRef for str instead class OffsetLiteral(Expr): value: Union[int, str] diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 49b98e0877..28bfb56a9c 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -17,7 +17,7 @@ from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im, misc +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -173,27 +173,45 @@ def translate( #: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details. symbolic_domain_sizes: Optional[dict[str, itir.Expr]] = None, ) -> SymbolicDomain: - offset_provider_type = common.offset_provider_to_type(offset_provider) - dims = list(self.ranges.keys()) new_ranges = {dim: self.ranges[dim] for dim in dims} if len(shift) == 0: return self if len(shift) == 2: off, val = shift + + connectivity: common.Connectivity if isinstance(off, itir.CartesianOffset): + domain = common.Dimension(value=off.domain.value, kind=off.domain.kind) + codomain = common.Dimension(value=off.codomain.value, kind=off.codomain.kind) + connectivity = common.CartesianConnectivity(domain, codomain=codomain) + elif isinstance(off, itir.OffsetLiteral): + assert isinstance(off.value, str) + # `get_offset` accepts both `OffsetProvider` and `OffsetProviderType`, but is typed + # for the former only (the two overlap, so an overload is not possible). + connectivity = common.get_offset(offset_provider, off.value) # type: ignore[arg-type] + else: + raise AssertionError() + + if isinstance(connectivity, common.CartesianConnectivity): if val is trace_shifts.Sentinel.VALUE: raise NotImplementedError("Dynamic offsets not supported.") assert isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) - dom = misc.dim_from_axis_literal(off.domain) - cod = misc.dim_from_axis_literal(off.codomain) - assert dom == cod # relocation (staggering) is not supported here - new_ranges[dom] = SymbolicRange.translate(self.ranges[dom], val.value) - return SymbolicDomain(self.grid_type, new_ranges) - assert isinstance(off, itir.OffsetLiteral) and isinstance(off.value, str) - connectivity_type = common.get_offset_type(offset_provider_type, off.value) - - if isinstance(connectivity_type, common.NeighborConnectivityType): + assert len(connectivity.domain.dims) == 1 + + old_dim = connectivity.domain_dim + new_dim = connectivity.codomain + + assert new_dim not in new_ranges or old_dim == new_dim + + new_range = SymbolicRange.translate( + self.ranges[old_dim], connectivity.offset + val.value + ) + new_ranges = dict( + (dim, range_) if dim != old_dim else (new_dim, new_range) + for dim, range_ in new_ranges.items() + ) + elif common.is_neighbor_table(connectivity): # unstructured shift assert ( isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) @@ -201,8 +219,8 @@ def translate( trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE, ] - old_dim = connectivity_type.source_dim - new_dim = connectivity_type.codomain + old_dim = connectivity.domain.dims[0] + new_dim = connectivity.codomain assert new_dim not in new_ranges or old_dim == new_dim if symbolic_domain_sizes is not None and new_dim.value in symbolic_domain_sizes: new_range = SymbolicRange( @@ -212,6 +230,7 @@ def translate( else: assert common.is_offset_provider(offset_provider) assert not isinstance(val, itir.CartesianOffset) # offset value, never a node + assert isinstance(off, itir.OffsetLiteral) and isinstance(off.value, str) new_range = _unstructured_translate_range_statically( new_ranges[old_dim], off.value, val, offset_provider, self.as_expr() ) @@ -222,6 +241,7 @@ def translate( ) else: raise AssertionError() + return SymbolicDomain(self.grid_type, new_ranges) elif len(shift) > 2: return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate( diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index d9a6e85f57..87af68bec6 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -235,6 +235,16 @@ def dim_from_axis_literal(axis_literal: itir.AxisLiteral) -> common.Dimension: return common.Dimension(value=axis_literal.value, kind=axis_literal.kind) +def connectivity_from_cartesian_offset( + cart_offset: itir.CartesianOffset, +) -> common.CartesianConnectivity: + return common.CartesianConnectivity( + domain_dim=dim_from_axis_literal(cart_offset.domain), + codomain=dim_from_axis_literal(cart_offset.codomain), + offset=0, + ) + + def _flatten_tuple_expr(expr: itir.Expr) -> tuple[itir.Expr]: if cpm.is_call_to(expr, "make_tuple"): return sum( diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index e204377bc0..6587eb437e 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -16,8 +16,8 @@ from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union from gt4py.next import common, utils -from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, ir as itir +from gt4py.next.iterator.ir_utils import misc as ir_misc from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -482,14 +482,25 @@ def _resolve_dimensions( >>> Edge = common.Dimension(value="Edge") >>> Vertex = common.Dimension(value="Vertex") + >>> Cell = common.Dimension(value="Cell") >>> K = common.Dimension(value="K", kind=common.DimensionKind.VERTICAL) >>> V2E = common.Dimension(value="V2E") + >>> C2V = common.Dimension(value="C2V") >>> input_dims = [Edge, K] >>> shift_tuple = ( + ... itir.OffsetLiteral(value="C2V"), + ... itir.OffsetLiteral(value=0), ... itir.OffsetLiteral(value="V2E"), ... itir.OffsetLiteral(value=0), ... ) >>> offset_provider_type = { + ... "C2V": common.NeighborConnectivityType( + ... domain=(Cell, C2V), + ... codomain=Vertex, + ... skip_value=None, + ... dtype=None, + ... max_neighbors=3, + ... ), ... "V2E": common.NeighborConnectivityType( ... domain=(Vertex, V2E), ... codomain=Edge, @@ -497,26 +508,54 @@ def _resolve_dimensions( ... dtype=None, ... max_neighbors=4, ... ), - ... "KOff": K, ... } >>> _resolve_dimensions(input_dims, shift_tuple, offset_provider_type) - [Dimension(value='Vertex', kind=), Dimension(value='K', kind=)] + [Dimension(value='Cell', kind=), Dimension(value='K', kind=)] + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> IDim = common.Dimension(value="IDim") + >>> IHalfDim = common.flip_staggered(IDim) + >>> JDim = common.Dimension(value="JDim") + >>> JHalfDim = common.flip_staggered(JDim) + >>> input_dims = [IDim, JDim] + >>> shift_tuple = ( + ... itir.CartesianOffset( + ... domain=im.axis_literal(IDim), codomain=im.axis_literal(IHalfDim) + ... ), + ... itir.OffsetLiteral(value=0), + ... itir.CartesianOffset(domain=im.axis_literal(JDim), codomain=im.axis_literal(IDim)), + ... itir.OffsetLiteral(value=0), + ... itir.CartesianOffset( + ... domain=im.axis_literal(IHalfDim), codomain=im.axis_literal(JDim) + ... ), + ... itir.OffsetLiteral(value=0), + ... ) + >>> _resolve_dimensions(input_dims, shift_tuple, offset_provider_type) + [Dimension(value='JDim', kind=), Dimension(value='IDim', kind=)] + """ resolved_dims = [] for input_dim in input_dims: - for off_literal in reversed( - shift_tuple[::2] - ): # Only OffsetLiterals are processed, located at even indices in shift_tuple. Shifts are applied in reverse order: the last shift in the tuple is applied first. + resolved_dim = input_dim + for off_literal in reversed(shift_tuple[::2]): + # Only OffsetLiterals/CartesianOffsets are processed, located at even indices in + # shift_tuple. Shifts are applied in reverse order: the last shift in the tuple is + # applied first. if isinstance(off_literal, itir.CartesianOffset): - if off_literal.domain != off_literal.codomain: - raise NotImplementedError("Relocation (staggering) is not supported.") - continue # translation does not change the dimension - assert isinstance(off_literal.value, str) - offset_type = common.get_offset_type(offset_provider_type, off_literal.value) - if isinstance(offset_type, (fbuiltins.FieldOffset, common.NeighborConnectivityType)): - if input_dim == offset_type.codomain: # Check if input fits to offset - input_dim = offset_type.domain[0] # Update input_dim for next iteration - resolved_dims.append(input_dim) + if resolved_dim == ir_misc.dim_from_axis_literal(off_literal.codomain): + resolved_dim = ir_misc.dim_from_axis_literal(off_literal.domain) + else: + assert isinstance(off_literal, itir.OffsetLiteral) and isinstance( + off_literal.value, str + ) + offset_type = common.get_offset_type(offset_provider_type, off_literal.value) + if isinstance(offset_type, common.NeighborConnectivityType): + if resolved_dim == offset_type.codomain: # Check if input fits to offset + resolved_dim = offset_type.domain[0] # Update input_dim for next iteration + else: + raise NotImplementedError( + f"'{offset_type}' is not a supported connectivity type." + ) + resolved_dims.append(resolved_dim) return resolved_dims @@ -660,24 +699,24 @@ def apply_shift( new_position_dims = [*it.position_dims] assert len(offset_literals) % 2 == 0 for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True): + source_dim: common.Dimension + target_dim: common.Dimension if isinstance(offset_axis, it_ts.CartesianOffsetType): - if offset_axis.domain != offset_axis.codomain: - raise NotImplementedError("Relocation (staggering) is not supported.") - continue # translation leaves position dims unchanged - assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( - offset_axis.value, str - ) - type_ = common.get_offset_type(offset_provider_type, offset_axis.value) - if isinstance(type_, common.NeighborConnectivityType): - found = False - for i, dim in enumerate(new_position_dims): - if dim.value == type_.source_dim.value: - assert not found - new_position_dims[i] = type_.codomain - found = True - assert found + source_dim, target_dim = offset_axis.domain, offset_axis.codomain else: - raise NotImplementedError(f"{type_} is not a supported Connectivity type.") + assert isinstance(offset_axis, it_ts.OffsetLiteralType) + assert isinstance(offset_axis.value, str) + type_ = common.get_offset_type(offset_provider_type, offset_axis.value) + assert isinstance(type_, common.NeighborConnectivityType) + source_dim, target_dim = type_.domain[0], type_.codomain + + found = False + for i, dim in enumerate(new_position_dims): + if dim == source_dim: + assert not found + new_position_dims[i] = target_dim + found = True + assert found else: # during re-inference we don't have an offset provider type new_position_dims = "unknown" diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 416ac6a7c0..47c4ddb1a5 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -206,7 +206,7 @@ def make_argument(name: str, type_: ts.TypeSpec) -> str | BufferSID | Tuple: source_buffer=name, dimensions=[ DimensionSpec( - name=dim.value, + name=common.as_non_staggered(dim).value, static_stride=1 if ( config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index cc9dbd1d12..4c1ff0f092 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -91,10 +91,18 @@ def _name_from_named_range(named_range_call: itir.FunCall) -> str: return named_range_call.args[0].value +class FlipStaggeredDims(eve.NodeTranslator): + def visit_AxisLiteral(self, node: itir.AxisLiteral) -> itir.AxisLiteral: + dim = ir_utils_misc.dim_from_axis_literal(node) + if common.is_staggered(dim): + return im.axis_literal(common.as_non_staggered(dim)) + return node + + def _collect_dimensions_from_domain( body: Iterable[itir.Stmt], ) -> dict[str, TagDefinition]: - domains = _get_domains(body) + domains = FlipStaggeredDims().visit(_get_domains(body)) offset_definitions = {} for domain in domains: if domain.fun == itir.SymRef(id="cartesian_domain"): @@ -151,37 +159,20 @@ def _collect_offset_definitions( grid_type: common.GridType, offset_provider_type: common.OffsetProviderType, ) -> dict[str, TagDefinition]: - used_offset_tags: set[str] = ( - node.walk_values() - .if_isinstance(itir.OffsetLiteral) - .filter(lambda offset_literal: isinstance(offset_literal.value, str)) - .getattr("value") - ).to_set() - # implicit offsets don't occur in the `offset_provider_type`, get them from the used offset tags - offset_provider_type = { - offset_name: common.get_offset_type(offset_provider_type, offset_name) - for offset_name in used_offset_tags - } | {**offset_provider_type} offset_definitions = {} + offset_provider_type = {**offset_provider_type} - # cartesian shifts (`field(Dim + n)`) are encoded as `CartesianOffset` nodes and don't - # occur in the `offset_provider_type`; define a tag for each of their dimensions cartesian_offsets: set[itir.CartesianOffset] = ( node.walk_values().if_isinstance(itir.CartesianOffset) ).to_set() for cart_offset in cartesian_offsets: - for axis in (cart_offset.domain, cart_offset.codomain): - if grid_type == common.GridType.CARTESIAN: - offset_definitions[axis.value] = TagDefinition(name=Sym(id=axis.value)) - else: - assert grid_type == common.GridType.UNSTRUCTURED - if axis.kind != common.DimensionKind.VERTICAL: - raise ValueError( - "Mapping an offset to a horizontal dimension in unstructured is not allowed." - ) - offset_definitions[axis.value] = TagDefinition( - name=Sym(id=axis.value), alias=_vertical_dimension - ) + dims = [ + common.Dimension(value=v.value, kind=v.kind) + for v in (cart_offset.domain, cart_offset.codomain) + ] + for dim in dims: + dim = common.as_non_staggered(dim) + offset_definitions[dim.value] = TagDefinition(name=Sym(id=dim.value)) for offset_name, connectivity_type in offset_provider_type.items(): if isinstance(connectivity_type, common.NeighborConnectivityType): @@ -387,12 +378,12 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs: Any) -> Offset return OffsetLiteral(value=node.value) def visit_CartesianOffset(self, node: itir.CartesianOffset, **kwargs: Any) -> Literal: - # render as the (shared) dimension tag - assert node.domain == node.codomain, "relocation (staggering) is not supported" return self.visit(node.codomain, **kwargs) def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs: Any) -> Literal: - return Literal(value=node.value, type="axis_literal") + assert isinstance(node.type, ts.DimensionType) + dim = common.as_non_staggered(node.type.dim) + return Literal(value=dim.value, type="axis_literal") def _make_domain(self, node: itir.FunCall) -> tuple[TaggedValues, TaggedValues]: tags = [] diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 2e8983ccb8..b6b270b2b6 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -1497,12 +1497,12 @@ def _visit_shift_multidim( return offset_provider_arg, offset_value_arg, it def _make_cartesian_shift( - self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: DataExpr + self, it: IteratorExpr, conn: gtx_common.CartesianConnectivity, offset_expr: DataExpr ) -> IteratorExpr: """Implements cartesian shift along one dimension.""" - assert any(dim == offset_dim for dim, _ in it.field_domain) + old_dim, new_dim = conn.domain_dim, conn.codomain new_index: SymbolExpr | ValueExpr - index_expr = it.indices[offset_dim] + index_expr = it.indices[old_dim] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): # purely symbolic expression which can be interpreted at compile time new_index = SymbolExpr( @@ -1565,9 +1565,10 @@ def _make_cartesian_shift( ) # a new iterator with a shifted index along one dimension - shifted_indices = { - dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items() - } + shifted_indices = dict( + (new_dim, new_index) if dim == old_dim else (dim, index) + for dim, index in it.indices.items() + ) return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _make_dynamic_neighbor_offset( @@ -1673,31 +1674,27 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: ) if isinstance(offset_provider_arg, gtir.CartesianOffset): - # cartesian shift; the dimension (incl. kind) is encoded in the node - assert offset_provider_arg.domain == offset_provider_arg.codomain, ( - "relocation (staggering) is not supported" + conn = itir_misc.connectivity_from_cartesian_offset(offset_provider_arg) + return self._make_cartesian_shift(it, conn, offset_expr) + else: + assert isinstance(offset_provider_arg, gtir.OffsetLiteral) + assert isinstance(offset_provider_arg.value, str) + offset_provider_type = self.subgraph_builder.get_offset_provider_type( + offset_provider_arg.value + ) + assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + # a named offset → unstructured shift; the offset value may be a static + # `OffsetLiteral` or a dynamic offset (handled by `_make_unstructured_shift`). + # initially, the storage for the connectivity tables is created as transient; + # when the tables are used, the storage is changed to non-transient, + # so the corresponding arrays are supposed to be allocated by the SDFG caller + offset_table = gtx_dace_args.connectivity_identifier(offset_provider_arg.value) + self.sdfg.arrays[offset_table].transient = False + offset_table_node = self.state.add_access(offset_table) + + return self._make_unstructured_shift( + it, offset_provider_type, offset_table_node, offset_expr ) - offset_dim = itir_misc.dim_from_axis_literal(offset_provider_arg.codomain) - return self._make_cartesian_shift(it, offset_dim, offset_expr) - - # first argument of the shift node is the offset provider - assert isinstance(offset_provider_arg, gtir.OffsetLiteral) - offset = offset_provider_arg.value - assert isinstance(offset, str) - offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset) - - # reaching here means a named offset → unstructured shift (cartesian shifts took the - # CartesianOffset branch above) - # initially, the storage for the connectivity tables is created as transient; - # when the tables are used, the storage is changed to non-transient, - # so the corresponding arrays are supposed to be allocated by the SDFG caller - offset_table = gtx_dace_args.connectivity_identifier(offset) - self.sdfg.arrays[offset_table].transient = False - offset_table_node = self.state.add_access(offset_table) - - return self._make_unstructured_shift( - it, offset_provider_type, offset_table_node, offset_expr - ) def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: """ diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 678c81cfcf..746fc2f9ad 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -15,7 +15,7 @@ import gt4py._core.definitions as core_defs import gt4py.next.custom_layout_allocators as next_allocators from gt4py._core import filecache -from gt4py.next import backend, common, config, field_utils +from gt4py.next import backend, common, config from gt4py.next.embedded import nd_array_field from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow @@ -90,15 +90,10 @@ def extract_connectivity_args( # the keys' order is taken into account. Any modification to the hashing # of offset providers may break this assumption here. args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [ - (ndarray, zero_origin) + (conn.ndarray, zero_origin) for conn in offset_provider.values() - if (ndarray := getattr(conn, "ndarray", None)) is not None + if common.is_neighbor_table(conn) ] - assert all( - common.is_neighbor_table(conn) and field_utils.verify_device_field_type(conn, device) - for conn in offset_provider.values() - if hasattr(conn, "ndarray") - ) return args diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index d2efee1fc5..08fb817856 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -58,7 +58,9 @@ E2VDim, Edge, IDim, + IHalfDim, JDim, + JHalfDim, KDim, KHalfDim, V2EDim, @@ -74,6 +76,7 @@ # mypy does not accept [IDim, ...] as a type IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] +IHalfField: TypeAlias = gtx.Field[[IHalfDim], np.int32] # type: ignore [valid-type] JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] @@ -501,6 +504,7 @@ def verify_with_default_data( case: Case, fieldop: decorator.FieldOperator, ref: Callable, + offset_provider: Optional[OffsetProvider] = None, comparison: Callable[[Any, Any], bool] = tree_mapped_np_allclose, ) -> None: """ @@ -515,6 +519,8 @@ def verify_with_default_data( fieldview_prog: The field operator or program to be verified. ref: A callable which will be called with all the input arguments of the fieldview code, after applying ``.ndarray`` on the fields. + offset_provider: An override for the test case's offset_provider. + Use with care! comparison: A comparison function, which will be called as ``comparison(ref, )`` and should return a boolean. """ @@ -528,7 +534,7 @@ def verify_with_default_data( *inps, **kwfields, ref=ref(*ref_args), - offset_provider=case.offset_provider, + offset_provider=offset_provider, comparison=comparison, ) @@ -731,7 +737,9 @@ def from_cartesian_grid_descriptor( IDim: grid_descriptor.sizes[0], JDim: grid_descriptor.sizes[1], KDim: grid_descriptor.sizes[2], - KHalfDim: grid_descriptor.sizes[3], + IHalfDim: grid_descriptor.sizes[0] - 1, + JHalfDim: grid_descriptor.sizes[1] - 1, + KHalfDim: grid_descriptor.sizes[2] - 1, }, grid_type=common.GridType.CARTESIAN, allocator=allocator, diff --git a/tests/next_tests/integration_tests/cases_utils.py b/tests/next_tests/integration_tests/cases_utils.py index 7afb134ef2..59f5cb43c5 100644 --- a/tests/next_tests/integration_tests/cases_utils.py +++ b/tests/next_tests/integration_tests/cases_utils.py @@ -146,9 +146,11 @@ def debug_itir(tree): DType = TypeVar("DType") IDim = gtx.Dimension("IDim") +IHalfDim = common.flip_staggered(IDim) JDim = gtx.Dimension("JDim") +JHalfDim = common.flip_staggered(JDim) KDim = gtx.Dimension("KDim", kind=gtx.DimensionKind.VERTICAL) -KHalfDim = gtx.Dimension("KHalf", kind=gtx.DimensionKind.VERTICAL) +KHalfDim = common.flip_staggered(KDim) Vertex = gtx.Dimension("Vertex") Edge = gtx.Dimension("Edge") @@ -182,11 +184,11 @@ def offset_provider_type(self) -> common.OffsetProviderType: ... def simple_cartesian_grid( - sizes: int | tuple[int, int, int, int] = (5, 7, 9, 11), + sizes: int | tuple[int, int, int, int] = (5, 7, 9), ) -> CartesianGridDescriptor: if isinstance(sizes, int): - sizes = (sizes,) * 4 - assert len(sizes) == 4, "sizes must be a tuple of four integers" + sizes = (sizes,) * 3 + assert len(sizes) == 3, "sizes must be a tuple of three integers" offset_provider = {} diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_staggered.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_staggered.py new file mode 100644 index 0000000000..1f7398ad57 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_staggered.py @@ -0,0 +1,147 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import functools +import math +from functools import reduce +from typing import TypeAlias + +import numpy as np +import pytest + +import gt4py.next as gtx +from gt4py.next import ( + astype, + broadcast, + common, + errors, + float32, + float64, + int32, + int64, + minimum, + neighbor_sum, + utils as gt_utils, +) +from gt4py.next.ffront.experimental import as_offset + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import ( + C2E, + E2V, + V2E, + E2VDim, + Edge, + IDim, + IHalfDim, + JDim, + KDim, + V2EDim, + Vertex, + cartesian_case, + unstructured_case, + unstructured_case_3d, +) +from next_tests.integration_tests.cases_utils import ( + exec_alloc_descriptor, + mesh_descriptor, +) + + +@pytest.mark.uses_cartesian_shift +def test_copy_half_field(cartesian_case): + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IHalfField: + field_tuple = (a, a) + field_0 = field_tuple[0] + field_1 = field_tuple[1] + return field_0 + + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a, offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_shift_plus(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IField) -> cases.IField: + return a(IDim + 1) # always pass an I-index to an IField + + size = cartesian_case.default_sizes[IDim] + a = cases.allocate(cartesian_case, testee, "a", domain={IDim: (1, size + 1)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN, domain={IDim: (0, size)})() + + cases.verify(cartesian_case, testee, a, out=out, ref=a[:], offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_plus(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IField) -> cases.IHalfField: + return a(IHalfDim + 0.5) # always pass an I-index to an IField + + size = cartesian_case.default_sizes[IDim] + a = cases.allocate(cartesian_case, testee, "a", sizes={IDim: size})() + out = cases.allocate(cartesian_case, testee, cases.RETURN, sizes={IHalfDim: size})() + + cases.verify(cartesian_case, testee, a, out=out, ref=a, offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_back(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IHalfField: + return a(IDim + 0.5)(IHalfDim - 0.5) # always pass an I-index to an IField + + a = cases.allocate(cartesian_case, testee, "a")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out, ref=a, offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_plus1(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IHalfField: + return a(IHalfDim + 1) # always pass an IHalf-index to an IHalfField + + size = cartesian_case.default_sizes[IDim] + a = cases.allocate(cartesian_case, testee, "a", domain={IHalfDim: (1, size + 1)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN, domain={IHalfDim: (0, size)})() + + cases.verify(cartesian_case, testee, a, out=out[:], ref=a[:], offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_minus(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IField) -> cases.IHalfField: + return a(IHalfDim - 0.5) # always pass an I-index to an IField + + size = cartesian_case.default_sizes[IDim] + a = cases.allocate(cartesian_case, testee, "a", domain={IDim: (-1, size - 1)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN, domain={IHalfDim: (0, size)})() + + cases.verify(cartesian_case, testee, a, out=out, ref=a[:], offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_half2center(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IField: + return 2 * a(IDim + 0.5) # always pass an IHalf-index to an IHalfField + + size = cartesian_case.default_sizes[IDim] + a = cases.allocate(cartesian_case, testee, "a", domain={IHalfDim: (1, size + 1)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN, sizes={IDim: size})() + + cases.verify(cartesian_case, testee, a, out=out, ref=2 * a[:], offset_provider={}) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index d4b27e78ee..0ddb42a521 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -17,6 +17,7 @@ IDim, JDim, KDim, + KHalfDim, C2E, E2V, V2E, @@ -37,7 +38,7 @@ mesh_descriptor, ) -KHalfDim = gtx.Dimension("KHalf", kind=gtx.DimensionKind.VERTICAL) +pytestmark = pytest.mark.uses_cartesian_shift @gtx.field_operator