Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/gt4py/next/otf/binding/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def format_source(source_code_spec: code_specs.SourceCodeSpec, source: str) -> s
class Parameter:
name: str
type_: ts.TypeSpec
# If set, the parameter is accepted by the generated wrapper (to keep the call signature
# stable) but forwarded to the callee without constructing a buffer SID, i.e. the program
# does not actually access it.
pass_through: bool = False


@dataclasses.dataclass(frozen=True)
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/otf/binding/nanobind.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def create_bindings(
expr=FunctionCall(
target=program_source.entry_point,
args=[
make_argument(param.name, param.type_)
param.name if param.pass_through else make_argument(param.name, param.type_)
for param in program_source.entry_point.parameters
],
)
Expand Down
15 changes: 13 additions & 2 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gt4py.next import common
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import misc as ir_utils_misc
from gt4py.next.iterator.transforms import pass_manager
from gt4py.next.otf import code_specs, definitions, stages, workflow
from gt4py.next.otf.binding import cpp_interface, interface
Expand Down Expand Up @@ -100,11 +101,17 @@ def _process_regular_arguments(
return parameters, arg_exprs

def _process_connectivity_args(
self, offset_provider_type: common.OffsetProviderType
self, offset_provider_type: common.OffsetProviderType, grid_type: common.GridType
) -> tuple[list[interface.Parameter], list[str]]:
parameters: list[interface.Parameter] = []
arg_exprs: list[str] = []

# A cartesian program never uses connectivities, but the offset provider may still contain
# them (e.g. when shared across multiple programs). They are declared as parameters (to keep
# the runtime call signature stable) but passed through the bindings without building a SID,
# since the generated program neither defines their dimension tags nor accesses them.
is_cartesian = grid_type != common.GridType.UNSTRUCTURED

for name, connectivity_type in offset_provider_type.items():
if isinstance(connectivity_type, common.NeighborConnectivityType):
if connectivity_type.dtype.scalar_type not in [np.int32, np.int64]:
Expand All @@ -120,9 +127,13 @@ def _process_connectivity_args(
dims=list(connectivity_type.domain),
dtype=type_translation.from_dtype(connectivity_type.dtype),
),
pass_through=is_cartesian,
)
)

if is_cartesian:
continue

# connectivity argument expression
nbtbl = (
f"gridtools::fn::sid_neighbor_table::as_neighbor_table<"
Expand Down Expand Up @@ -213,7 +224,7 @@ def __call__(
# handle connectivity parameters and arguments (i.e. what the user provided in the offset
# provider)
connectivity_parameters, connectivity_args_expr = self._process_connectivity_args(
inp.args.offset_provider_type
inp.args.offset_provider_type, ir_utils_misc.grid_type_from_program(program)
)

# combine into a format that is aligned with what the backend expects
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,11 @@ def _collect_offset_definitions(
elif isinstance(
connectivity_type := dim_or_connectivity_type, common.NeighborConnectivityType
):
assert grid_type == common.GridType.UNSTRUCTURED
if grid_type != common.GridType.UNSTRUCTURED:
# A cartesian program never uses connectivities (using one would make it
# unstructured), but the offset provider may still contain them, e.g. when shared
# across multiple programs. Such connectivities are ignored.
continue
offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name))
if offset_name != connectivity_type.neighbor_dim.value:
offset_definitions[connectivity_type.neighbor_dim.value] = TagDefinition(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,33 @@ def testee(a: cases.IJKField) -> cases.IJKField:
cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a)


def test_unused_connectivity_in_offset_provider(cartesian_case):
# A program may be called with an offset provider that contains connectivities it does not
# use (e.g. a single offset provider shared across many programs). The connectivity must be
# accepted even though the program has a cartesian grid type. See GH #863.
@gtx.field_operator
def testee(a: cases.IFloatField) -> cases.IFloatField:
return a

a = cases.allocate(cartesian_case, testee, "a")()
out = cases.allocate(cartesian_case, testee, cases.RETURN)()
unused_conn = gtx.as_connectivity(
[Vertex, V2EDim],
codomain=Edge,
data=np.array([[0, 1], [1, 2], [2, 0]], dtype=np.int64),
allocator=cartesian_case.allocator,
)

cases.verify(
cartesian_case,
testee,
a,
out=out,
offset_provider={"V2E": unused_conn},
ref=a.asnumpy(),
)


@pytest.mark.uses_tuple_returns
def test_multicopy(cartesian_case):
@gtx.field_operator
Expand Down