Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions src/gt4py/cartesian/gtc/dace/oir_to_treeir.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,18 @@
"""Default dace residency types per device type."""


def _resolve_default_map_schedule(
device_type: dtypes.DeviceType,
def _resolve_loop_schedule(
device_type: dtypes.DeviceType, loop_order: common.LoopOrder
) -> dtypes.ScheduleType:
"""Default kernel target per device type."""
"""Optimal kernel schedule type based on syntax and target device.

Current strategy:
- respect OIR syntax: sequential on all non-parallel keyword
- maximize local parallel usage per target
"""
if loop_order != common.LoopOrder.PARALLEL:
return dtypes.ScheduleType.Sequential

if device_type == dtypes.DeviceType.GPU:
return dtypes.ScheduleType.GPU_Device

Expand All @@ -44,7 +52,7 @@ def _resolve_default_map_schedule(
if not gt_config.build_settings["openmp"]["use_openmp"]:
return dtypes.ScheduleType.Sequential

return dtypes.ScheduleType.Default
return dtypes.ScheduleType.CPU_Multicore


class OIRToTreeIR(eve.NodeVisitor):
Expand Down Expand Up @@ -148,7 +156,9 @@ def visit_HorizontalExecution(self, node: oir.HorizontalExecution, ctx: tir.Cont
loop = tir.HorizontalLoop(
bounds_i=tir.Bounds(start=axis_start_i, end=axis_end_i),
bounds_j=tir.Bounds(start=axis_start_j, end=axis_end_j),
schedule=_resolve_default_map_schedule(self._device_type),
schedule=_resolve_loop_schedule(
self._device_type, common.LoopOrder.PARALLEL
), # Horizontal is always parallel
children=[],
parent=ctx.current_scope,
)
Expand Down Expand Up @@ -258,19 +268,6 @@ def visit_Interval(

return tir.Bounds(start=start, end=end)

def _vertical_loop_schedule(self) -> dtypes.ScheduleType:
"""
Defines the vertical loop schedule.

Current strategy is to
- keep the vertical loop on the host for both, CPU and GPU targets
- and run it in parallel on CPU and sequential on GPU.
"""
if self._device_type == dtypes.DeviceType.GPU:
return dtypes.ScheduleType.Sequential

return _resolve_default_map_schedule(self._device_type)

def visit_VerticalLoopSection(
self, node: oir.VerticalLoopSection, ctx: tir.Context, loop_order: common.LoopOrder
) -> None:
Expand All @@ -285,7 +282,7 @@ def visit_VerticalLoopSection(
iteration_variable=eve.SymbolRef(f"{tir.Axis.K.iteration_symbol()}_{id(node)}"),
loop_order=loop_order,
bounds_k=bounds,
schedule=self._vertical_loop_schedule(),
schedule=_resolve_loop_schedule(self._device_type, loop_order),
children=[],
parent=ctx.current_scope,
)
Expand Down