143 changes: 143 additions & 0 deletions docs/development/ADRs/next/0024-Dtype-Generic-Operators.md
@@ -0,0 +1,143 @@
---
tags: []
---

# Dtype-Generic Operators

- **Status**: valid
- **Authors**: Hannes Vogt (@havogt)
- **Created**: 2026-06-15
- **Updated**: 2026-06-15

Field operators (and programs calling them) may be **generic in the field dtype**,
spelled with a native value-constrained `typing.TypeVar` so the same annotation is
meaningful to mypy and to the DSL frontend. Each concrete call is specialized
(monomorphized) at call time.

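```python
import typing

import gt4py.next as gtx

I = gtx.Dimension("I")  # dimensions defined here to keep the snippet self-contained
J = gtx.Dimension("J")

FloatT = typing.TypeVar("FloatT", gtx.float32, gtx.float64)


@gtx.field_operator
def diff(
    a: gtx.Field[gtx.Dims[I, J], FloatT], b: gtx.Field[gtx.Dims[I, J], FloatT]
) -> gtx.Field[gtx.Dims[I, J], FloatT]:
    return a - b
```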

## Context

`common.Field` is already a runtime-introspectable generic protocol, so
`Field[Dims[I, J], T]` with a value-constrained `TypeVar` is a valid, mypy-visible
annotation today; what was missing is the DSL side. The internal type system had
`DeferredType` ("some type, maybe constrained") but no notion of *identity* — it
could not express "the *same* unknown dtype in two parameters and the return type",
the essence of generics. The runtime monomorphization machinery already existed
(grown for scan operators): `CompiledProgramsPool` keys a per-call specialization
cache on the concrete argument types.

Prior art (numpy.typing, jaxtyping, Numba, Taichi, Triton, DaCe) converges on the
two choices adopted here: a real generic annotation that static checkers can see,
and monomorphization at call time.

## Decision

### User-facing spelling

A **value-constrained** type parameter inside the real generic `Field` class,
spelled either with PEP 695 syntax, `def op[T: (float32, float64)](...)` (preferred
once Python 3.12 is the minimum supported version), or with the equivalent
module-level `TypeVar("T", float32, float64)` (accepted; it produces the same
runtime objects). The type variable is value-constrained rather than `bound=`
because each use must resolve to exactly one listed type, which makes the dtype
predicates decidable and the variant set finite, so eager precompilation is
possible. `bound=`-only and unconstrained type variables are rejected with a clear
message.
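
The acceptance rule above can be sketched with the standard `typing.TypeVar`
introspection attributes `__constraints__` and `__bound__`. This is a minimal
illustration, not the frontend's actual check: the helper name `constraints_of`,
the error messages, and the use of builtin `float`/`complex` as stand-ins for the
DSL scalar types are all assumptions.

```python
import typing


def constraints_of(tv: typing.TypeVar) -> tuple:
    """Return the constraints of a value-constrained TypeVar, reject anything else."""
    if tv.__bound__ is not None:
        raise TypeError(
            f"Type variable '{tv.__name__}' uses 'bound='; "
            "only value-constrained type variables are supported."
        )
    if not tv.__constraints__:
        raise TypeError(
            f"Type variable '{tv.__name__}' is unconstrained; "
            "list the allowed dtypes explicitly."
        )
    return tv.__constraints__


# Hypothetical stand-ins: builtin types instead of float32/float64.
FloatT = typing.TypeVar("FloatT", float, complex)
BoundT = typing.TypeVar("BoundT", bound=float)
AnyT = typing.TypeVar("AnyT")

accepted = constraints_of(FloatT)  # the finite set of variants
```

Calling `constraints_of(BoundT)` or `constraints_of(AnyT)` raises a `TypeError`
that names the offending type variable.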

### `ts.TypeVarType`

A new `DataType` subclass carrying `name` and `constraints: tuple[ScalarType, ...]`.

- Subclassing `DataType` lets it fit unchanged into `FieldType.dtype` (widened to
`ScalarType | ListType | TypeVarType`), `TupleType`/`NamedCollectionType` members,
and `foast.Symbol`.
- **Identity is the name**, scoped to one operator signature. Two *distinct*
same-named `TypeVar` objects in one signature are rejected at parse time (with PEP
695 this is impossible by construction). As a frozen eve `DataModel` it gets
deterministic `eq`/`hash`/`content_hash` for cache keys.
- `ts.DeferredType` is **not** replaced. The two mechanisms coexist: `DeferredType`
means "not yet inferred" (and currently also encodes the scan operators' *dims*
genericity); `TypeVarType` means "universally quantified over the constraint set".
A single `type_info.is_generic` predicate recognizes both.
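
As a rough illustration of why a frozen, value-compared data model is convenient
for cache keys, the sketch below uses plain frozen dataclasses in place of eve
`DataModel`s. The classes are simplified stand-ins for the `ts` types, and
`check_unique_type_var_names` is a hypothetical helper showing one way the
same-name rejection could work.

```python
import dataclasses
import typing


@dataclasses.dataclass(frozen=True)
class ScalarType:
    kind: str  # e.g. "float32"


@dataclasses.dataclass(frozen=True)
class TypeVarType:
    name: str
    constraints: typing.Tuple[ScalarType, ...]


def check_unique_type_var_names(type_vars):
    """Reject two distinct `typing.TypeVar` objects that share a name."""
    seen = {}
    for tv in type_vars:
        if seen.setdefault(tv.__name__, tv) is not tv:
            raise TypeError(f"Distinct type variables share the name '{tv.__name__}'.")


f32, f64 = ScalarType("float32"), ScalarType("float64")
a = TypeVarType("T", (f32, f64))
b = TypeVarType("T", (f32, f64))

# Equal name and constraints give equal objects with equal hashes,
# so either instance can be used to look up the same cache entry.
cache = {a: "variant"}
same_entry = cache[b]
```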

### Decisions D1–D5

- **D1 — Decoration-time body checking with opaque `TypeVarType`.** The body is
type-checked once, at decoration time, with `T` treated as an opaque scalar.
Errors are reported in the user's vocabulary (`T`), not in instantiated terms.
(Rejected: skip-until-instantiation — breaks the decoration-time-errors UX;
finite monomorph-check — duplicates compile work and reports in the wrong
vocabulary. The finite check survives only as a test-suite cross-check.)
- **D2 — Value-constrained TypeVars only.** Finite variant set ⇒ decidable dtype
predicates and eager `.compile()` of all members. `bound=` is a future extension.
- **D3 — Strict no-promotion.** `promote(T, T) = T`; mixing `T` with a concrete
scalar/dtype (including literals: `a * 2.0`) is a decoration-time error naming the
type variable. `astype(x, T)` is the designated remediation (a fast-follow). We
pre-commit to strict-first rather than inheriting Numba-style silent promotion;
this is expected to be the main ergonomics complaint and is revisited via a named
"generic literals" follow-up, not by relaxing the default.
- **D4 — Monomorphize at FOAST level; never lower generic GTIR.** Specialization is
direct type substitution over the typed FOAST, with a full re-run of type
deduction as a soundness backstop under `__debug__`. (Rejected: lowering generic
FOAST and concretizing at GTIR level — `foast_to_gtir` bakes dtypes into literals
and casts; GTIR has no syntax for "dtype of param x".)
- **D5 — Binding is a first-class `type_info` utility.**
`bind_type_vars(params, args)` (structural match, consistency + exact-match
checks) and `substitute_type_vars(type_, binding)` (recursion over every TypeSpec).
`accepts_args` keeps its boolean interface; callers needing the binding use the
new API.
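
D3 and D5 can be sketched over toy type specifications as follows. The function
names `bind_type_vars` and `substitute_type_vars` follow the ADR, but the bodies,
the simplified `ScalarType`/`TypeVarType`/`FieldType` classes, the helper `_bind`,
and the error messages are assumptions made for illustration; the real utilities
recurse over every TypeSpec.

```python
import dataclasses
from typing import Union


@dataclasses.dataclass(frozen=True)
class ScalarType:
    kind: str


@dataclasses.dataclass(frozen=True)
class TypeVarType:
    name: str
    constraints: tuple


@dataclasses.dataclass(frozen=True)
class FieldType:
    dims: tuple
    dtype: Union[ScalarType, TypeVarType]


def promote(a, b):
    """D3 sketch: a type variable promotes only with itself."""
    if a == b:
        return a
    for t in (a, b):
        if isinstance(t, TypeVarType):
            raise TypeError(f"Cannot mix type variable '{t.name}' with another dtype.")
    raise NotImplementedError("Concrete promotion rules are outside this sketch.")


def _bind(param, arg, binding):
    if isinstance(param, TypeVarType):
        if arg not in param.constraints:  # exact-match check
            raise TypeError(f"{arg} is not a constraint of '{param.name}'.")
        if binding.setdefault(param.name, arg) != arg:  # consistency check
            raise TypeError(f"Inconsistent binding for '{param.name}'.")
    elif isinstance(param, FieldType):
        if not isinstance(arg, FieldType) or arg.dims != param.dims:
            raise TypeError(f"Expected a field over {param.dims}, got {arg}.")
        _bind(param.dtype, arg.dtype, binding)
    elif param != arg:
        raise TypeError(f"Expected {param}, got {arg}.")


def bind_type_vars(params, args):
    """Structurally match parameter types against concrete argument types."""
    if len(params) != len(args):
        raise TypeError("Argument count mismatch.")
    binding = {}
    for param, arg in zip(params, args):
        _bind(param, arg, binding)
    return binding


def substitute_type_vars(type_, binding):
    """Replace every type variable in `type_` by its bound scalar type."""
    if isinstance(type_, TypeVarType):
        return binding[type_.name]
    if isinstance(type_, FieldType):
        return FieldType(type_.dims, substitute_type_vars(type_.dtype, binding))
    return type_


f32, f64 = ScalarType("float32"), ScalarType("float64")
T = TypeVarType("T", (f32, f64))
field_T = FieldType(("I", "J"), T)
field_f32 = FieldType(("I", "J"), f32)

binding = bind_type_vars([field_T, field_T], [field_f32, field_f32])
specialized_return = substitute_type_vars(field_T, binding)
```

Passing one `float32` field and one `float64` field to `bind_type_vars` fails the
consistency check, and `promote(T, f64)` raises an error that names `T`.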

### Monomorphization strategy

- **Direct operator call with a backend:** `FieldOperator.__call__` →
`CompiledProgramsPool`. The pool already detects generic signatures
(`is_generic`), keys the cache on the full concrete substitution
(`arg_specialization_key`), and forwards concrete types as `CompileTimeArgs`. A
new `foast_specialize` toolchain step (after `func_to_foast`) computes the binding
and substitutes throughout the FOAST tree; everything downstream runs on a
concrete artifact.
- **Generic operator called from a concrete program:** the binding is fully static
at program decoration. The field operator signature check binds the type variables
and substitutes them, and a PAST monomorphization pass (run in `past_to_itir`)
recomputes the binding from the typed call-site args, name-mangles the callee per
binding (e.g. `diff__float32`), and swaps in a specialized callable via a new
`GTCallable.__gt_specialize__(binding)`. Two bindings of one operator naturally
become two GTIR `FunctionDefinition`s.
- **Embedded mode:** nearly free — the original Python definition runs on real
fields once decoration tolerates generic signatures.
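
The per-binding name mangling can be pictured with the toy sketch below. Only the
`diff__float32` shape comes from this ADR; the helpers `mangle` and
`collect_specializations`, the ordering by type-variable name, and the `_`
separator for several type variables are assumptions.

```python
def mangle(name, binding):
    """Append the bound dtypes, ordered by type-variable name, to the callee name."""
    suffix = "_".join(binding[tv] for tv in sorted(binding))
    return f"{name}__{suffix}" if suffix else name


def collect_specializations(call_sites):
    """Map each mangled name to the (callee, binding) pair it stands for."""
    definitions = {}
    for callee, binding in call_sites:
        definitions.setdefault(mangle(callee, binding), (callee, dict(binding)))
    return definitions


call_sites = [
    ("diff", {"FloatT": "float32"}),
    ("diff", {"FloatT": "float64"}),
    ("diff", {"FloatT": "float32"}),  # repeated binding, reuses the first entry
]
definitions = collect_specializations(call_sites)
```

Here three call sites with two distinct bindings yield two entries,
`diff__float32` and `diff__float64`.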

### Cache-key story

The pool's `arg_specialization_key` hashes all argument types, so the full
substitution is in the key — distinct dtypes hit distinct variants. Value-constrained
TypeVars make eager precompilation of all variants possible via the existing
`.compile()` API.
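
A toy model of this cache-key story is sketched below. `VariantCache` and
`all_bindings` are hypothetical names, and dtypes are plain strings; the sketch
only shows that keying on the concrete argument types separates the variants and
that a finite constraint set can be enumerated up front.

```python
import itertools


class VariantCache:
    """Toy per-call specialization cache keyed on the concrete argument types."""

    def __init__(self, compile_variant):
        self._compile_variant = compile_variant
        self._variants = {}

    def get(self, arg_types):
        key = tuple(arg_types)  # the full substitution is part of the key
        if key not in self._variants:
            self._variants[key] = self._compile_variant(key)
        return self._variants[key]

    def __len__(self):
        return len(self._variants)


def all_bindings(type_vars):
    """Enumerate every binding of value-constrained type variables."""
    names = sorted(type_vars)
    for combo in itertools.product(*(type_vars[name] for name in names)):
        yield dict(zip(names, combo))


compiled = []  # records every compilation request
cache = VariantCache(lambda key: compiled.append(key) or f"variant{key}")

# Eager precompilation: one variant per member of the finite constraint set.
for binding in all_bindings({"T": ("float32", "float64")}):
    cache.get((binding["T"], binding["T"]))

cache.get(("float32", "float32"))  # cache hit, nothing is compiled again
```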

## Out of scope / deferred (with forward-compatibility notes)

- **Generic scan operators** — rejected with a clear message (needs `init: T`
coercion semantics). Nothing in the utilities hardcodes `FieldOperatorType`.
- **`bound=` TypeVars** — infinite constraint sets; predicates by bound; no eager
precompile.
- **`astype(x, T)` / generic scalar constructors** — the D3 remediation; requires a
`ConstructorType` over `TypeVarType`.
- **Builtin coverage** — `where`, `broadcast`, reductions, `concat_where`, neighbor
fields are audited and widened incrementally; until then a generic argument to an
un-audited builtin is a clear decoration-time error (math builtins already work).
- **Dimension genericity** — a separate effort (the true fix for the scan
`DeferredType`/fabricated-`Dimension` hack). The binding utilities are kept
**dtype-scoped** here: `bind_type_vars`/`substitute_type_vars` map names to
`ScalarType` only, and same-name rejection is specified over dtype type variables
only. Widening the binding environment to dimensions and generalizing same-name
rejection across type-parameter kinds is explicitly deferred to that work.
- **PEP 696 dtype defaults** (unparameterized `Field` means `float64`) and
**mypy-plugin un-blurring** of `float32`/`float64` — later, coordinated with the
`Field` annotation cleanup (gt4py #1415/#1416).
4 changes: 4 additions & 0 deletions docs/development/ADRs/next/README.md
@@ -26,6 +26,10 @@ Writing a new ADR is simple:
- [0010 - Domain in Field View](0010-Domain_in_Field_View.md)
- [0013 - Scalar vs 0d-Fields](0013-Scalar_vs_0d_Fields.md)

### Type System

- [0024 - Dtype-Generic Operators](0024-Dtype-Generic-Operators.md)

### Iterator IR #iterator

- [0003 - Iterator View Tuple Support for Fields](0003-Iterator_View_Tuple_Support_for_Fields.md)
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
@@ -831,7 +831,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> foast.Call:
print(f"Warning: return type of '{func_name}' might be inconsistent (not implemented).")

# deduce return type
return_type: Optional[ts.FieldType | ts.ScalarType] = None
return_type: Optional[ts.FieldType | ts.ScalarType | ts.TypeVarType] = None
if (
func_name
in fbuiltins.UNARY_MATH_NUMBER_BUILTIN_NAMES + fbuiltins.UNARY_MATH_FP_BUILTIN_NAMES
12 changes: 8 additions & 4 deletions src/gt4py/next/iterator/transforms/global_tmps.py
@@ -181,12 +181,16 @@ def _transform_by_pattern(
lambda x: next(uids["__tmp"]),
result_collection_constructor=as_tuple,
)(tmp_expr.type)
# the lowered IR is concrete, so `extract_dtype` never yields a `TypeVarType` here
tmp_dtypes: (
ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...]
) = type_info.tree_map_type(
type_info.extract_dtype,
result_collection_constructor=as_tuple,
)(tmp_expr.type)
) = cast(

[Review comment by the PR author: check if we can express this without a cast]

"ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...]",
type_info.tree_map_type(
type_info.extract_dtype,
result_collection_constructor=as_tuple,
)(tmp_expr.type),
)

tmp_domains: SymbolicDomain | tuple[SymbolicDomain | tuple, ...] = tmp_expr.annex.domain

19 changes: 6 additions & 13 deletions src/gt4py/next/otf/compiled_program.py
@@ -396,7 +396,8 @@ def __call__(
# expensive type deduction for all arguments and not include it in the key.
if enable_jit:
warnings.warn(
"Calling generic programs / direct calls to scan operators are not optimized. "
"Calling generic programs / operators (e.g. operators with generic dtype "
"or direct calls to scan operators) is not optimized. "
"Consider calling a specialized version instead.",
stacklevel=3,
)
@@ -459,19 +460,11 @@ def _is_generic(self) -> bool:
Is the operator or program generic in the sense that it can be called for different
argument types.

Right now this is only the case for scan operators.
Right now this is the case for scan operators (genericity communicated via
`DeferredType` parameters created in `type_info.type_in_program_context`) and for
operators with a generic dtype (type variable).
"""
# TODO(tehrengruber): This concept does not exist elsewhere and is not properly reflected
# in the type system. For now we just use `DeferredType` to communicate between
# here and `type_info.type_in_program_context`.
return any(
isinstance(t, ts.DeferredType)
for t in itertools.chain(
self.program_type.definition.pos_only_args,
self.program_type.definition.pos_or_kw_args.values(),
self.program_type.definition.kw_only_args.values(),
)
)
return type_info.is_generic(self.program_type.definition)

@functools.cached_property
def _args_canonicalizer(self) -> Callable[..., tuple[tuple, dict[str, Any]]]:

@@ -104,7 +104,9 @@ class MemletExpr:

@property
def gt_dtype(self) -> ts.ScalarType | ts.ListType:
return self.gt_field.dtype
dtype = self.gt_field.dtype
assert ti.is_concrete_dtype(dtype)
return dtype

def __post_init__(self) -> None:
if isinstance(self.gt_dtype, ts.ListType):

@@ -896,6 +896,7 @@ def _add_storage(
dc_dtype = gtx_dace_args.as_dace_type(gt_type.dtype)
all_dims = gt_type.dims
else: # for 'ts.ListType' use 'offset_type' as local dimension
assert isinstance(gt_type.dtype, ts.ListType)
assert gt_type.dtype.offset_type is not None
assert gt_type.dtype.offset_type.kind == gtx_common.DimensionKind.LOCAL
assert isinstance(gt_type.dtype.element_type, ts.ScalarType)

@@ -143,6 +143,7 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField:
if isinstance(output_type.dtype, ts.ScalarType):
all_dims = gtx_common.order_dimensions(output_type.dims)
else:
assert isinstance(output_type.dtype, ts.ListType)
assert output_type.dtype.offset_type
all_dims = gtx_common.order_dimensions([*output_type.dims, output_type.dtype.offset_type])


@@ -622,7 +622,9 @@ def _handle_dataflow_result_of_nested_sdfg(
None,
dace.Memlet.from_array(outer_dataname, outer_desc),
)
output_expr = gtir_dataflow.ValueExpr(outer_node, inner_data.gt_type.dtype)
output_dtype = inner_data.gt_type.dtype
assert ti.is_concrete_dtype(output_dtype)
output_expr = gtir_dataflow.ValueExpr(outer_node, output_dtype)
return gtir_dataflow.DataflowOutputEdge(outer_ctx.state, output_expr)



@@ -21,7 +21,7 @@
from gt4py.next.iterator import builtins as gtir_builtins
from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args
from gt4py.next.program_processors.runners.dace.lowering import gtir_dataflow, gtir_domain
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.type_system import type_info as ti, type_specifications as ts


@dataclasses.dataclass(frozen=True)
@@ -86,9 +86,9 @@ def get_local_view(
(dim, dace.symbolic.SymExpr(0) if self.origin is None else self.origin[i])
for i, dim in enumerate(self.gt_type.dims)
]
return gtir_dataflow.IteratorExpr(
self.dc_node, self.gt_type.dtype, field_origin, it_indices
)
dtype = self.gt_type.dtype
assert ti.is_concrete_dtype(dtype)
return gtir_dataflow.IteratorExpr(self.dc_node, dtype, field_origin, it_indices)

raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.")
