From c03dea29ecfd31aa21355a884eec531761481b76 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 8 Jun 2026 10:06:31 +0200 Subject: [PATCH] fix[next]: accept unused connectivities in offset provider for gtfn A field operator called with an offset provider that contains connectivities it does not use is lowered to a cartesian grid type (using a connectivity would force unstructured). The gtfn backend then asserted `grid_type == UNSTRUCTURED` while collecting offset definitions and crashed. In cartesian mode, ignore such connectivities: they are still declared as wrapper parameters so the runtime call signature stays stable, but are passed through the bindings without building a SID, since the generated program neither defines their dimension tags nor accesses them. Fixes #863. --- src/gt4py/next/otf/binding/interface.py | 4 +++ src/gt4py/next/otf/binding/nanobind.py | 2 +- .../codegens/gtfn/gtfn_module.py | 15 +++++++++-- .../codegens/gtfn/itir_to_gtfn_ir.py | 6 ++++- .../ffront_tests/test_execution.py | 27 +++++++++++++++++++ 5 files changed, 50 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/otf/binding/interface.py b/src/gt4py/next/otf/binding/interface.py index 96cab18c8a..7304c0445b 100644 --- a/src/gt4py/next/otf/binding/interface.py +++ b/src/gt4py/next/otf/binding/interface.py @@ -28,6 +28,10 @@ def format_source(source_code_spec: code_specs.SourceCodeSpec, source: str) -> s class Parameter: name: str type_: ts.TypeSpec + # If set, the parameter is accepted by the generated wrapper (to keep the call signature + # stable) but forwarded to the callee without constructing a buffer SID, i.e. the program + # does not actually access it. + pass_through: bool = False @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 416ac6a7c0..66255a06d8 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -274,7 +274,7 @@ def create_bindings( expr=FunctionCall( target=program_source.entry_point, args=[ - make_argument(param.name, param.type_) + param.name if param.pass_through else make_argument(param.name, param.type_) for param in program_source.entry_point.parameters ], ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 2135af7fbb..243108eeb6 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -20,6 +20,7 @@ from gt4py.next import common from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import misc as ir_utils_misc from gt4py.next.iterator.transforms import pass_manager from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.binding import cpp_interface, interface @@ -100,11 +101,17 @@ def _process_regular_arguments( return parameters, arg_exprs def _process_connectivity_args( - self, offset_provider_type: common.OffsetProviderType + self, offset_provider_type: common.OffsetProviderType, grid_type: common.GridType ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] + # A cartesian program never uses connectivities, but the offset provider may still contain + # them (e.g. when shared across multiple programs). They are declared as parameters (to keep + # the runtime call signature stable) but passed through the bindings without building a SID, + # since the generated program neither defines their dimension tags nor accesses them. + is_cartesian = grid_type != common.GridType.UNSTRUCTURED + for name, connectivity_type in offset_provider_type.items(): if isinstance(connectivity_type, common.NeighborConnectivityType): if connectivity_type.dtype.scalar_type not in [np.int32, np.int64]: @@ -120,9 +127,13 @@ def _process_connectivity_args( dims=list(connectivity_type.domain), dtype=type_translation.from_dtype(connectivity_type.dtype), ), + pass_through=is_cartesian, ) ) + if is_cartesian: + continue + # connectivity argument expression nbtbl = ( f"gridtools::fn::sid_neighbor_table::as_neighbor_table<" @@ -213,7 +224,7 @@ def __call__( # handle connectivity parameters and arguments (i.e. what the user provided in the offset # provider) connectivity_parameters, connectivity_args_expr = self._process_connectivity_args( - inp.args.offset_provider_type + inp.args.offset_provider_type, ir_utils_misc.grid_type_from_program(program) ) # combine into a format that is aligned with what the backend expects diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index ecd8ed88ed..bb0447dc71 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -169,7 +169,11 @@ def _collect_offset_definitions( elif isinstance( connectivity_type := dim_or_connectivity_type, common.NeighborConnectivityType ): - assert grid_type == common.GridType.UNSTRUCTURED + if grid_type != common.GridType.UNSTRUCTURED: + # A cartesian program never uses connectivities (using one would make it + # unstructured), but the offset provider may still contain them, e.g. when shared + # across multiple programs. Such connectivities are ignored. + continue offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) if offset_name != connectivity_type.neighbor_dim.value: offset_definitions[connectivity_type.neighbor_dim.value] = TagDefinition( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f410dadc95..c610749fb8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -64,6 +64,33 @@ def testee(a: cases.IJKField) -> cases.IJKField: cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) +def test_unused_connectivity_in_offset_provider(cartesian_case): + # A program may be called with an offset provider that contains connectivities it does not + # use (e.g. a single offset provider shared across many programs). The connectivity must be + # accepted even though the program has a cartesian grid type. See GH #863. + @gtx.field_operator + def testee(a: cases.IFloatField) -> cases.IFloatField: + return a + + a = cases.allocate(cartesian_case, testee, "a")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + unused_conn = gtx.as_connectivity( + [Vertex, V2EDim], + codomain=Edge, + data=np.array([[0, 1], [1, 2], [2, 0]], dtype=np.int64), + allocator=cartesian_case.allocator, + ) + + cases.verify( + cartesian_case, + testee, + a, + out=out, + offset_provider={"V2E": unused_conn}, + ref=a.asnumpy(), + ) + + @pytest.mark.uses_tuple_returns def test_multicopy(cartesian_case): @gtx.field_operator