diff --git a/docs/development/ADRs/next/0024-Dtype-Generic-Operators.md b/docs/development/ADRs/next/0024-Dtype-Generic-Operators.md new file mode 100644 index 0000000000..33454e033e --- /dev/null +++ b/docs/development/ADRs/next/0024-Dtype-Generic-Operators.md @@ -0,0 +1,143 @@ +--- +tags: [] +--- + +# [Dtype-Generic Operators] + +- **Status**: valid +- **Authors**: Hannes Vogt (@havogt) +- **Created**: 2026-06-15 +- **Updated**: 2026-06-15 + +Field operators (and programs calling them) may be **generic in the field dtype**, +spelled with a native value-constrained `typing.TypeVar` so the same annotation is +meaningful to mypy and to the DSL frontend. Each concrete call is specialized +(monomorphized) at call time. + +```python +FloatT = typing.TypeVar("FloatT", gtx.float32, gtx.float64) + + +@gtx.field_operator +def diff( + a: gtx.Field[gtx.Dims[I, J], FloatT], b: gtx.Field[gtx.Dims[I, J], FloatT] +) -> gtx.Field[gtx.Dims[I, J], FloatT]: + return a - b +``` + +## Context + +`common.Field` is already a runtime-introspectable generic protocol, so +`Field[Dims[I, J], T]` with a value-constrained `TypeVar` is a valid, mypy-visible +annotation today; what was missing is the DSL side. The internal type system had +`DeferredType` ("some type, maybe constrained") but no notion of *identity* — it +could not express "the *same* unknown dtype in two parameters and the return type", +the essence of generics. The runtime monomorphization machinery already existed +(grown for scan operators): `CompiledProgramsPool` keys a per-call specialization +cache on the concrete argument types. + +Prior art (numpy.typing, jaxtyping, Numba, Taichi, Triton, DaCe) converges on the +two choices adopted here: a real generic annotation that static checkers can see, +and monomorphization at call time. + +## Decision + +### User-facing spelling + +A **value-constrained** type parameter inside the real generic `Field` class, +spelled with PEP 695 `def op[T: (float32, float64)](...)` (preferred at the 3.12+ +floor) or the equivalent module-level `TypeVar("T", float32, float64)` (accepted, +produces the same runtime objects). Value-constrained — not `bound=` — because each +use must resolve to exactly one listed type, which makes the dtype predicates +decidable and the variant set finite (eager precompilation possible). `bound=`-only +and unconstrained type variables are rejected with a clear message. + +### `ts.TypeVarType` + +A new `DataType` subclass carrying `name` and `constraints: tuple[ScalarType, ...]`. + +- Subclassing `DataType` lets it fit unchanged into `FieldType.dtype` (widened to + `ScalarType | ListType | TypeVarType`), `TupleType`/`NamedCollectionType` members, + and `foast.Symbol`. +- **Identity is the name**, scoped to one operator signature. Two *distinct* + same-named `TypeVar` objects in one signature are rejected at parse time (with PEP + 695 this is impossible by construction). As a frozen eve `DataModel` it gets + deterministic `eq`/`hash`/`content_hash` for cache keys. +- `ts.DeferredType` is **not** replaced. The two mechanisms coexist: `DeferredType` + means "not yet inferred" (and currently also encodes the scan operators' *dims* + genericity); `TypeVarType` means "universally quantified over the constraint set". + A single `type_info.is_generic` predicate recognizes both. + +### Decisions D1–D5 + +- **D1 — Decoration-time body checking with opaque `TypeVarType`.** The body is + type-checked once, at decoration time, with `T` treated as an opaque scalar. + Errors are reported in the user's vocabulary (`T`), not in instantiated terms. + (Rejected: skip-until-instantiation — breaks the decoration-time-errors UX; + finite monomorph-check — duplicates compile work and reports in the wrong + vocabulary. The finite check survives only as a test-suite cross-check.) +- **D2 — Value-constrained TypeVars only.** Finite variant set ⇒ decidable dtype + predicates and eager `.compile()` of all members. `bound=` is a future extension. +- **D3 — Strict no-promotion.** `promote(T, T) = T`; mixing `T` with a concrete + scalar/dtype (including literals: `a * 2.0`) is a decoration-time error naming the + type variable. `astype(x, T)` is the designated remediation (a fast-follow). We + pre-commit to strict-first rather than inheriting Numba-style silent promotion; + this is expected to be the main ergonomics complaint and is revisited via a named + "generic literals" follow-up, not by relaxing the default. +- **D4 — Monomorphize at FOAST level; never lower generic GTIR.** Specialization is + direct type substitution over the typed FOAST, with a full re-run of type + deduction as a soundness backstop under `__debug__`. (Rejected: lowering generic + FOAST and concretizing at GTIR level — `foast_to_gtir` bakes dtypes into literals + and casts; GTIR has no syntax for "dtype of param x".) +- **D5 — Binding is a first-class `type_info` utility.** + `bind_type_vars(params, args)` (structural match, consistency + exact-match + checks) and `substitute_type_vars(type_, binding)` (recursion over every TypeSpec). + `accepts_args` keeps its boolean interface; callers needing the binding use the + new API. + +### Monomorphization strategy + +- **Direct operator call with a backend:** `FieldOperator.__call__` → + `CompiledProgramsPool`. The pool already detects generic signatures + (`is_generic`), keys the cache on the full concrete substitution + (`arg_specialization_key`), and forwards concrete types as `CompileTimeArgs`. A + new `foast_specialize` toolchain step (after `func_to_foast`) computes the binding + and substitutes throughout the FOAST tree; everything downstream runs on a + concrete artifact. +- **Generic operator called from a concrete program:** the binding is fully static + at program decoration. The fieldop signature checks bind-and-substitute, and a + PAST monomorphization pass (run in `past_to_itir`) recomputes the binding from the + typed call-site args, name-mangles the callee per binding (e.g. + `diff__float32`), and swaps in a specialized callable via a new + `GTCallable.__gt_specialize__(binding)`. Two bindings of one operator naturally + become two GTIR `FunctionDefinition`s. +- **Embedded mode:** nearly free — the original Python definition runs on real + fields once decoration tolerates generic signatures. + +### Cache-key story + +The pool's `arg_specialization_key` hashes all argument types, so the full +substitution is in the key — distinct dtypes hit distinct variants. Value-constrained +TypeVars make eager precompilation of all variants possible via the existing +`.compile()` API. + +## Out of scope / deferred (with forward-compatibility notes) + +- **Generic scan operators** — rejected with a clear message (needs `init: T` + coercion semantics). Nothing in the utilities hardcodes `FieldOperatorType`. +- **`bound=` TypeVars** — infinite constraint sets; predicates by bound; no eager + precompile. +- **`astype(x, T)` / generic scalar constructors** — the D3 remediation; requires a + `ConstructorType` over `TypeVarType`. +- **Builtin coverage** — `where`, `broadcast`, reductions, `concat_where`, neighbor + fields are audited and widened incrementally; until then a generic argument to an + un-audited builtin is a clear decoration-time error (math builtins already work). +- **Dimension genericity** — a separate effort (the true fix for the scan + `DeferredType`/fabricated-`Dimension` hack). The binding utilities are kept + **dtype-scoped** here: `bind_type_vars`/`substitute_type_vars` map names to + `ScalarType` only, and same-name rejection is specified over dtype type variables + only. Widening the binding environment to dimensions and generalizing same-name + rejection across type-parameter kinds is explicitly deferred to that work. +- **PEP 696 dtype defaults** (unparameterized `Field` means `float64`) and + **mypy-plugin un-blurring** of `float32`/`float64` — later, coordinated with the + `Field` annotation cleanup (gt4py #1415/#1416). diff --git a/docs/development/ADRs/next/README.md b/docs/development/ADRs/next/README.md index 9fe0695f24..c6b1907a2c 100644 --- a/docs/development/ADRs/next/README.md +++ b/docs/development/ADRs/next/README.md @@ -26,6 +26,10 @@ Writing a new ADR is simple: - [0010 - Domain in Field View](0010-Domain_in_Field_View.md) - [0013 - Scalar vs 0d-Fields](0013-Scalar_vs_0d_Fields.md) +### Type System + +- [0024 - Dtype-Generic Operators](0024-Dtype-Generic-Operators.md) + ### Iterator IR #iterator - [0003 - Iterator View Tuple Support for Fields](0003-Iterator_View_Tuple_Support_for_Fields.md) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 1b090489ff..2b9693cc1f 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -831,7 +831,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> foast.Call: print(f"Warning: return type of '{func_name}' might be inconsistent (not implemented).") # deduce return type - return_type: Optional[ts.FieldType | ts.ScalarType] = None + return_type: Optional[ts.FieldType | ts.ScalarType | ts.TypeVarType] = None if ( func_name in fbuiltins.UNARY_MATH_NUMBER_BUILTIN_NAMES + fbuiltins.UNARY_MATH_FP_BUILTIN_NAMES diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 49d2f3d5b4..7aa8aa6cd5 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -181,12 +181,16 @@ def _transform_by_pattern( lambda x: next(uids["__tmp"]), result_collection_constructor=as_tuple, )(tmp_expr.type) + # the lowered IR is concrete, so `extract_dtype` never yields a `TypeVarType` here tmp_dtypes: ( ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...] - ) = type_info.tree_map_type( - type_info.extract_dtype, - result_collection_constructor=as_tuple, - )(tmp_expr.type) + ) = cast( + "ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...]", + type_info.tree_map_type( + type_info.extract_dtype, + result_collection_constructor=as_tuple, + )(tmp_expr.type), + ) tmp_domains: SymbolicDomain | tuple[SymbolicDomain | tuple, ...] = tmp_expr.annex.domain diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index f09ae16bd9..64b73ab9f2 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -396,7 +396,8 @@ def __call__( # expensive type deduction for all arguments and not include it in the key. if enable_jit: warnings.warn( - "Calling generic programs / direct calls to scan operators are not optimized. " + "Calling generic programs / operators (e.g. operators with generic dtype " + "or direct calls to scan operators) is not optimized. " "Consider calling a specialized version instead.", stacklevel=3, ) @@ -459,19 +460,11 @@ def _is_generic(self) -> bool: Is the operator or program generic in the sense that it can be called for different argument types. - Right now this is only the case for scan operators. + Right now this is the case for scan operators (genericity communicated via + `DeferredType` parameters created in `type_info.type_in_program_context`) and for + operators with a generic dtype (type variable). """ - # TODO(tehrengruber): This concept does not exist elsewhere and is not properly reflected - # in the type system. For now we just use `DeferredType` to communicate between - # here and `type_info.type_in_program_context`. - return any( - isinstance(t, ts.DeferredType) - for t in itertools.chain( - self.program_type.definition.pos_only_args, - self.program_type.definition.pos_or_kw_args.values(), - self.program_type.definition.kw_only_args.values(), - ) - ) + return type_info.is_generic(self.program_type.definition) @functools.cached_property def _args_canonicalizer(self) -> Callable[..., tuple[tuple, dict[str, Any]]]: diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 2e8983ccb8..ed60aec004 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -104,7 +104,9 @@ class MemletExpr: @property def gt_dtype(self) -> ts.ScalarType | ts.ListType: - return self.gt_field.dtype + dtype = self.gt_field.dtype + assert ti.is_concrete_dtype(dtype) + return dtype def __post_init__(self) -> None: if isinstance(self.gt_dtype, ts.ListType): diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 223eff2d79..aacc0ed7bc 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -896,6 +896,7 @@ def _add_storage( dc_dtype = gtx_dace_args.as_dace_type(gt_type.dtype) all_dims = gt_type.dims else: # for 'ts.ListType' use 'offset_type' as local dimension + assert isinstance(gt_type.dtype, ts.ListType) assert gt_type.dtype.offset_type is not None assert gt_type.dtype.offset_type.kind == gtx_common.DimensionKind.LOCAL assert isinstance(gt_type.dtype.element_type, ts.ScalarType) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py index ded61af77b..fde144dedc 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py @@ -143,6 +143,7 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: if isinstance(output_type.dtype, ts.ScalarType): all_dims = gtx_common.order_dimensions(output_type.dims) else: + assert isinstance(output_type.dtype, ts.ListType) assert output_type.dtype.offset_type all_dims = gtx_common.order_dimensions([*output_type.dims, output_type.dtype.offset_type]) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index fddcb080b8..0169e8b149 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -622,7 +622,9 @@ def _handle_dataflow_result_of_nested_sdfg( None, dace.Memlet.from_array(outer_dataname, outer_desc), ) - output_expr = gtir_dataflow.ValueExpr(outer_node, inner_data.gt_type.dtype) + output_dtype = inner_data.gt_type.dtype + assert ti.is_concrete_dtype(output_dtype) + output_expr = gtir_dataflow.ValueExpr(outer_node, output_dtype) return gtir_dataflow.DataflowOutputEdge(outer_ctx.state, output_expr) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py index 83cd7660d8..30b9e4c7a3 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py @@ -21,7 +21,7 @@ from gt4py.next.iterator import builtins as gtir_builtins from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args from gt4py.next.program_processors.runners.dace.lowering import gtir_dataflow, gtir_domain -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info as ti, type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -86,9 +86,9 @@ def get_local_view( (dim, dace.symbolic.SymExpr(0) if self.origin is None else self.origin[i]) for i, dim in enumerate(self.gt_type.dims) ] - return gtir_dataflow.IteratorExpr( - self.dc_node, self.gt_type.dtype, field_origin, it_indices - ) + dtype = self.gt_type.dtype + assert ti.is_concrete_dtype(dtype) + return gtir_dataflow.IteratorExpr(self.dc_node, dtype, field_origin, it_indices) raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.") diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 25dce1c3f5..2456489de4 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -54,6 +54,84 @@ def is_concrete(symbol_type: ts.TypeSpec) -> TypeGuard[ts.TypeSpec]: return False +def _function_type_arg_groups( + function_type: ts.FunctionType, +) -> tuple[Sequence[ts.TypeSpec], ...]: + """The positional, keyword and return sub-type groups of a function type, in canonical order.""" + return ( + function_type.pos_only_args, + tuple(function_type.pos_or_kw_args.values()), + tuple(function_type.kw_only_args.values()), + (function_type.returns,), + ) + + +def _function_type_children(function_type: ts.FunctionType) -> tuple[ts.TypeSpec, ...]: + """Return the argument and return sub-types of a function type, in canonical order.""" + return tuple(child for group in _function_type_arg_groups(function_type) for child in group) + + +def _map_function_type( + function_type: ts.FunctionType, transform: Callable[[ts.TypeSpec], ts.TypeSpec] +) -> ts.FunctionType: + """Apply ``transform`` to each argument and return sub-type of a function type.""" + return ts.FunctionType( + pos_only_args=[transform(a) for a in function_type.pos_only_args], + pos_or_kw_args={name: transform(a) for name, a in function_type.pos_or_kw_args.items()}, + kw_only_args={name: transform(a) for name, a in function_type.kw_only_args.items()}, + returns=transform(function_type.returns), + ) + + +def _type_params(symbol_type: ts.TypeSpec) -> tuple[ts.TypeSpec, ...]: + """Return the immediate type-parameter sub-types of ``symbol_type``. + + These are its dtype, element type, tuple elements, or function argument / return types. + """ + match symbol_type: + case ts.FieldType(dtype=dtype): + return (dtype,) + case ts.ListType(element_type=element_type): + return (element_type,) + case ts.TupleType(types=types) | ts.NamedCollectionType(types=types): + return tuple(types) + case ts.FunctionType(): + return _function_type_children(symbol_type) + # callable type wrappers (e.g. the field operator types in `ffront`) carry their + # signature in a `definition` attribute + if isinstance(definition := getattr(symbol_type, "definition", None), ts.TypeSpec): + return (definition,) + return () + + +def is_generic(symbol_type: ts.TypeSpec) -> bool: + """ + Figure out if a type contains parts that are only known when concrete arguments are given. + + Recurses into composite types, reporting ``True`` if any nested part is a `DeferredType` or + `TypeVarType`. Unlike :func:`is_concrete` (a shallow top-level check), this is deep, so a + tuple with a nested `DeferredType` is both concrete and generic. + + Note: this returns ``True`` for a bare ``astype`` constructor type, whose ``definition`` + carries a ``DeferredType`` by design; callers that only care about *data* arguments must + filter for ``ts.DataType`` themselves. + + Examples: + >>> is_generic(ts.DeferredType(constraint=None)) + True + + >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) + >>> is_generic(bool_type) + False + + >>> is_generic(ts.TupleType(types=[bool_type, ts.DeferredType(constraint=None)])) + True + """ + if isinstance(symbol_type, (ts.DeferredType, ts.TypeVarType)): + return True + return any(is_generic(p) for p in _type_params(symbol_type)) + + def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: """ Determine which class should be used to create a compatible concrete type. @@ -181,9 +259,9 @@ def tree_map_type( ) -def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType: +def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType | ts.TypeVarType: """ - Extract the data type from ``symbol_type`` if it is either `FieldType` or `ScalarType`. + Extract the data type from ``symbol_type`` if it is `FieldType`, `ScalarType` or `TypeVarType`. Raise an error if no dtype can be found or the result would be ambiguous. @@ -201,9 +279,16 @@ def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType: return dtype case ts.ScalarType() as dtype: return dtype + case ts.TypeVarType() as dtype: + return dtype raise ValueError(f"Can not unambiguosly extract data type from '{symbol_type}'.") +def is_concrete_dtype(dtype: ts.TypeSpec) -> xtyping.TypeIs[ts.ScalarType | ts.ListType]: + """Whether ``dtype`` is a concrete field dtype, i.e. not a (generic) `TypeVarType`.""" + return isinstance(dtype, (ts.ScalarType, ts.ListType)) + + def _scalar_kinds(scalar_types: tuple[type, ...]) -> frozenset[ts.ScalarKind]: # Derived from the canonical scalar-type tuples in `gt4py._core.definitions` so the two # stay in sync; the `int`/`float` builtins collapse onto their fixed-width kind. @@ -212,13 +297,21 @@ def _scalar_kinds(scalar_types: tuple[type, ...]) -> frozenset[ts.ScalarKind]: _FLOATING_POINT_KINDS: Final[frozenset[ts.ScalarKind]] = _scalar_kinds(core_defs.FLOAT_TYPES) _INTEGRAL_KINDS: Final[frozenset[ts.ScalarKind]] = _scalar_kinds(core_defs.INTEGRAL_TYPES) +_ARITHMETIC_KINDS: Final[frozenset[ts.ScalarKind]] = _FLOATING_POINT_KINDS | _INTEGRAL_KINDS def _is_field_or_scalar_of_kind(symbol_type: ts.TypeSpec, kinds: Collection[ts.ScalarKind]) -> bool: - """Check if ``symbol_type`` is a scalar or a field whose dtype kind is in ``kinds``.""" + """Check if ``symbol_type`` is a scalar or a field whose dtype kind is in ``kinds``. + + A type variable has the property iff all of its constraints have it. + """ + if isinstance(symbol_type, ts.TypeVarType): + return all(_is_field_or_scalar_of_kind(c, kinds) for c in symbol_type.constraints) if not isinstance(symbol_type, (ts.ScalarType, ts.FieldType)): return False dtype = extract_dtype(symbol_type) + if isinstance(dtype, ts.TypeVarType): + return all(_is_field_or_scalar_of_kind(c, kinds) for c in dtype.constraints) return isinstance(dtype, ts.ScalarType) and dtype.kind in kinds @@ -289,7 +382,7 @@ def is_arithmetic_scalar(symbol_type: ts.TypeSpec) -> bool: ... ) False """ - if not isinstance(symbol_type, ts.ScalarType): + if not isinstance(symbol_type, (ts.ScalarType, ts.TypeVarType)): return False return is_arithmetic(symbol_type) @@ -323,7 +416,7 @@ def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: >>> is_arithmetic(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) True """ - return is_floating_point(symbol_type) or is_integral(symbol_type) + return _is_field_or_scalar_of_kind(symbol_type, _ARITHMETIC_KINDS) def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.number]: @@ -398,7 +491,7 @@ def extract_dims(symbol_type: ts.TypeSpec) -> list[common.Dimension]: >>> extract_dims(ts.FieldType(dims=[I, J], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64))) [Dimension(value='I', kind=), Dimension(value='J', kind=)] """ - if isinstance(symbol_type, ts.ScalarType): + if isinstance(symbol_type, (ts.ScalarType, ts.TypeVarType)): return [] if isinstance(symbol_type, ts.FieldType): return symbol_type.dims @@ -488,17 +581,12 @@ def is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec) -> bool: for el_type_a, el_type_b in zip(type_a.types, type_b.types, strict=True): is_compatible &= is_compatible_type(el_type_a, el_type_b) elif isinstance(type_a, ts.FunctionType) and isinstance(type_b, ts.FunctionType): - for arg_a, arg_b in zip(type_a.pos_only_args, type_b.pos_only_args, strict=True): - is_compatible &= is_compatible_type(arg_a, arg_b) - for arg_a, arg_b in zip( - type_a.pos_or_kw_args.values(), type_b.pos_or_kw_args.values(), strict=True - ): - is_compatible &= is_compatible_type(arg_a, arg_b) - for arg_a, arg_b in zip( - type_a.kw_only_args.values(), type_b.kw_only_args.values(), strict=True + # zip per group (not flattened) so a positional/keyword arity mismatch is still caught + for group_a, group_b in zip( + _function_type_arg_groups(type_a), _function_type_arg_groups(type_b), strict=True ): - is_compatible &= is_compatible_type(arg_a, arg_b) - is_compatible &= is_compatible_type(type_a.returns, type_b.returns) + for arg_a, arg_b in zip(group_a, group_b, strict=True): + is_compatible &= is_compatible_type(arg_a, arg_b) else: is_compatible &= is_concretizable(type_a, type_b) @@ -558,9 +646,122 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: return False +def _bind_var(var: ts.TypeVarType, dtype: ts.TypeSpec) -> dict[str, ts.ScalarType]: + if not isinstance(dtype, ts.ScalarType): + # not a concrete scalar to bind to -- e.g. a `TypeVarType` (operator-from-operator + # call), a `DeferredType` (scan), or a `ListType` (local field). Leave it unbound; + # the caller is responsible for checking that no type variable remained unbound. + return {} + if dtype not in var.constraints: + raise ValueError(f"'{dtype}' does not satisfy the constraints of type variable '{var}'.") + return {var.name: dtype} + + +def _merge_bindings(parts: Iterable[dict[str, ts.ScalarType]]) -> dict[str, ts.ScalarType]: + binding: dict[str, ts.ScalarType] = {} + for part in parts: + for name, dtype in part.items(): + if (previous := binding.get(name)) is not None and previous != dtype: + raise ValueError( + f"Type variable '{name}' is bound inconsistently:" + f" '{previous}' and '{dtype}' (all arguments using '{name}'" + " must have the same dtype)." + ) + binding[name] = dtype + return binding + + +def _bind(param: ts.TypeSpec, arg: ts.TypeSpec) -> dict[str, ts.ScalarType]: + match param: + case ts.TypeVarType() as var: + return _bind_var(var, arg) + case ts.FieldType(dtype=ts.TypeVarType() as var): + # scalar arguments are promoted to zero-dimensional fields + return _bind_var(var, arg.dtype if isinstance(arg, ts.FieldType) else arg) + case ts.ListType(element_type=element_type) if isinstance(arg, ts.ListType): + return _bind(element_type, arg.element_type) + case ts.TupleType() | ts.NamedCollectionType() if isinstance( + arg, (ts.TupleType, ts.NamedCollectionType) + ): + # tolerant by design: a structural mismatch (e.g. tuple vs scalar) binds nothing + # here and is reported by the regular signature checks instead. + return _merge_bindings(_bind(p, a) for p, a in zip(param.types, arg.types)) + return {} + + +def bind_type_vars( + params: Sequence[ts.TypeSpec], args: Sequence[ts.TypeSpec] +) -> dict[str, ts.ScalarType]: + """ + Compute a binding of all type variables in ``params`` by structurally matching ``args``. + + Concrete (non-generic) parts of the parameters are ignored; a type variable position binds + only if the corresponding argument provides a concrete scalar dtype. The caller is + responsible for checking that no type variable remained unbound. + + Raises: + ValueError: If a type variable would be bound inconsistently or to a dtype that is + not one of its constraints. + + Examples: + >>> var = ts.TypeVarType(name="T", constraints=(ts.ScalarType(kind=ts.ScalarKind.FLOAT64),)) + >>> I = common.Dimension(value="I") + >>> binding = bind_type_vars( + ... [ts.FieldType(dims=[I], dtype=var)], + ... [ts.FieldType(dims=[I], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))], + ... ) + >>> print(binding["T"]) + float64 + """ + return _merge_bindings(_bind(param, arg) for param, arg in zip(params, args)) + + +def substitute_type_vars( + type_: ts.TypeSpec, binding: xtyping.Mapping[str, ts.ScalarType] +) -> ts.TypeSpec: + """ + Replace all type variables in ``type_`` that are bound in ``binding``. + + Unbound type variables and all other generic parts (e.g. `DeferredType`) are kept as-is. + + Examples: + >>> var = ts.TypeVarType(name="T", constraints=(ts.ScalarType(kind=ts.ScalarKind.FLOAT64),)) + >>> I = common.Dimension(value="I") + >>> print( + ... substitute_type_vars( + ... ts.FieldType(dims=[I], dtype=var), + ... {"T": ts.ScalarType(kind=ts.ScalarKind.FLOAT64)}, + ... ) + ... ) + Field[[I], float64] + """ + if not binding: + return type_ + + def substitute_leaf(leaf: ts.TypeSpec) -> ts.TypeSpec: + # `tree_map_type` has already mapped the tuple structure; what is left is to substitute + # inside the primitive constituents, i.e. in their dtype / element type / signature. + match leaf: + case ts.TypeVarType(): + return binding.get(leaf.name, leaf) + case ts.FieldType(dims=dims, dtype=dtype): + new_dtype = substitute_type_vars(dtype, binding) + assert isinstance(new_dtype, (ts.ScalarType, ts.ListType, ts.TypeVarType)) + return ts.FieldType(dims=dims, dtype=new_dtype) + case ts.ListType(element_type=element_type, offset_type=offset_type): + new_element_type = substitute_type_vars(element_type, binding) + assert isinstance(new_element_type, ts.DataType) + return ts.ListType(element_type=new_element_type, offset_type=offset_type) + case ts.FunctionType(): + return _map_function_type(leaf, lambda a: substitute_type_vars(a, binding)) + return leaf + + return tree_map_type(substitute_leaf)(type_) + + def promote( - *types: ts.FieldType | ts.ScalarType, always_field: bool = False -) -> ts.FieldType | ts.ScalarType: + *types: ts.FieldType | ts.ScalarType | ts.TypeVarType, always_field: bool = False +) -> ts.FieldType | ts.ScalarType | ts.TypeVarType: """ Promote a set of field or scalar types to a common type. @@ -582,17 +783,29 @@ def promote( >>> promoted.dims == [I, J, K] and promoted.dtype == dtype True """ - if not always_field and all(isinstance(type_, ts.ScalarType) for type_ in types): + if not always_field and all( + isinstance(type_, (ts.ScalarType, ts.TypeVarType)) for type_ in types + ): if not all(type_ == types[0] for type_ in types): + if any(isinstance(type_, ts.TypeVarType) for type_ in types): + distinct_types = "', '".join(str(t) for t in dict.fromkeys(types)) + raise ValueError( + f"Could not promote '{distinct_types}': a generic dtype (type variable)" + " can only be combined with values of the same type variable," + " not with other dtypes." + ) raise ValueError("Could not promote scalars of different dtype (not implemented).") - if not all(type_.shape is None for type_ in types): # type: ignore[union-attr] + if not all(type_.shape is None for type_ in types if isinstance(type_, ts.ScalarType)): raise NotImplementedError("Shape promotion not implemented.") return types[0] - elif all(isinstance(type_, (ts.ScalarType, ts.FieldType)) for type_ in types): + elif all(isinstance(type_, (ts.ScalarType, ts.FieldType, ts.TypeVarType)) for type_ in types): dims = common.promote_dims(*(extract_dims(type_) for type_ in types)) extracted_dtypes = [extract_dtype(type_) for type_ in types] - assert all(isinstance(dtype, ts.ScalarType) for dtype in extracted_dtypes) - dtype = cast(ts.ScalarType, promote(*extracted_dtypes)) # type: ignore[arg-type] # checked is `ScalarType` + assert all(isinstance(dtype, (ts.ScalarType, ts.TypeVarType)) for dtype in extracted_dtypes) + dtype = cast( # type variables promote like scalars (only with themselves) + ts.ScalarType | ts.TypeVarType, + promote(*extracted_dtypes), # type: ignore[arg-type] # checked above + ) return ts.FieldType(dims=dims, dtype=dtype) raise TypeError("Expected a 'FieldType' or 'ScalarType'.") diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 59ac40f0f3..8450bbb7d6 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -105,6 +105,38 @@ def __str__(self) -> str: return f"{kind_str}{self.shape}" +def _canonicalize_constraints(constraints: Sequence[ScalarType]) -> tuple[ScalarType, ...]: + # A value-constrained type variable resolves to exactly one of its constraints, so their + # order carries no meaning; canonicalize it to make `TypeVarType` identity order-insensitive. + return tuple(sorted(constraints, key=lambda c: c.kind)) + + +class TypeVarType(DataType): + """ + A scalar type variable, universally quantified over its constraints. + + Represents the type of a value-constrained Python ``typing.TypeVar`` (e.g. + ``TypeVar("T", float32, float64)``) used in the signature of a generic operator. + Two occurrences with the same ``name`` within one signature denote the same type. + """ + + name: str + constraints: tuple[ScalarType, ...] = eve_datamodels.field(converter=_canonicalize_constraints) + + def __str__(self) -> str: + return f"{self.name}: ({' | '.join(map(str, self.constraints))})" + + @eve_datamodels.validator("constraints") + def _constraints_validator( + self, attribute: eve_datamodels.Attribute, constraints: tuple[ScalarType, ...] + ) -> None: + if not constraints: + raise ValueError( + f"Type variable '{self.name}' must be value-constrained, i.e. have at" + " least one constraint." + ) + + class ListType(DataType): """Represents a neighbor list in the ITIR representation. @@ -119,7 +151,7 @@ class ListType(DataType): class FieldType(DataType, CallableType): dims: list[common.Dimension] - dtype: ScalarType | ListType + dtype: ScalarType | ListType | TypeVarType def __str__(self) -> str: dims = "..." if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]" diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_info.py b/tests/next_tests/unit_tests/type_system_tests/test_type_info.py index 35c3d2eba1..33b8da2348 100644 --- a/tests/next_tests/unit_tests/type_system_tests/test_type_info.py +++ b/tests/next_tests/unit_tests/type_system_tests/test_type_info.py @@ -373,6 +373,206 @@ def test_type_info_basic(symbol_type, expected): assert getattr(type_info, key)(symbol_type) == expected[key] +def is_generic_cases() -> list[tuple[ts.TypeSpec, bool]]: + deferred_type = ts.DeferredType(constraint=None) + float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + concrete_field_type = ts.FieldType(dims=[TDim], dtype=float_type) + + def function_type(params: list[ts.TypeSpec]) -> ts.FunctionType: + return ts.FunctionType( + pos_only_args=[], + pos_or_kw_args={f"arg{i}": param for i, param in enumerate(params)}, + kw_only_args={}, + returns=ts.VoidType(), + ) + + return [ + (deferred_type, True), + (float_type, False), + (concrete_field_type, False), + (ts.TupleType(types=[float_type, concrete_field_type]), False), + # `DeferredType` nested inside a composite type, e.g. the program context signature + # of a scan operator with tuple arguments + (ts.TupleType(types=[float_type, deferred_type]), True), + (function_type([concrete_field_type]), False), + (function_type([deferred_type]), True), + (function_type([ts.TupleType(types=[deferred_type])]), True), + ( + ts_ffront.ProgramType(definition=function_type([deferred_type])), + True, + ), + ( + ts_ffront.FieldOperatorType(definition=function_type([concrete_field_type])), + False, + ), + ] + + +@pytest.mark.parametrize("symbol_type,expected", is_generic_cases()) +def test_is_generic(symbol_type: ts.TypeSpec, expected: bool): + assert type_info.is_generic(symbol_type) == expected + + +float32_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) +float64_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +int32_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) +float_var = ts.TypeVarType(name="T", constraints=(float32_type, float64_type)) +mixed_var = ts.TypeVarType(name="U", constraints=(float64_type, int32_type)) + + +class TestTypeVarType: + def test_validation(self): + with pytest.raises(ValueError, match="value-constrained"): + ts.TypeVarType(name="T", constraints=()) + + def test_identity_and_hashing(self): + from gt4py.eve import utils as eve_utils + + same_var = ts.TypeVarType(name="T", constraints=(float32_type, float64_type)) + assert float_var == same_var + assert hash(float_var) == hash(same_var) + assert eve_utils.content_hash(float_var) == eve_utils.content_hash(same_var) + assert float_var != ts.TypeVarType(name="S", constraints=(float32_type, float64_type)) + # constraint order is canonicalized, so it is not part of the identity + assert float_var == ts.TypeVarType(name="T", constraints=(float64_type, float32_type)) + + def test_is_generic(self): + assert type_info.is_generic(float_var) + assert type_info.is_generic(ts.FieldType(dims=[TDim], dtype=float_var)) + assert type_info.is_generic( + ts.TupleType(types=[float64_type, ts.FieldType(dims=[TDim], dtype=float_var)]) + ) + + @pytest.mark.parametrize( + "predicate,var,expected", + [ + (type_info.is_floating_point, float_var, True), + (type_info.is_floating_point, mixed_var, False), + (type_info.is_integral, float_var, False), + (type_info.is_integral, ts.TypeVarType(name="I", constraints=(int32_type,)), True), + (type_info.is_arithmetic, float_var, True), + (type_info.is_arithmetic, mixed_var, True), + ( + type_info.is_arithmetic, + ts.TypeVarType(name="B", constraints=(bool_type, float64_type)), + False, + ), + (type_info.is_logical, ts.TypeVarType(name="B", constraints=(bool_type,)), True), + (type_info.is_logical, float_var, False), + (type_info.is_arithmetic_scalar, float_var, True), + ], + ) + def test_predicates_evaluate_over_constraints(self, predicate, var, expected): + assert predicate(var) == expected + if predicate is not type_info.is_arithmetic_scalar: # rejects fields by design + assert predicate(ts.FieldType(dims=[TDim], dtype=var)) == expected + + def test_promote_same_var(self): + assert type_info.promote(float_var, float_var) == float_var + promoted = type_info.promote( + ts.FieldType(dims=[TDim], dtype=float_var), ts.FieldType(dims=[TDim], dtype=float_var) + ) + assert promoted == ts.FieldType(dims=[TDim], dtype=float_var) + + def test_promote_var_with_scalar_arg(self): + promoted = type_info.promote(ts.FieldType(dims=[TDim], dtype=float_var), float_var) + assert promoted == ts.FieldType(dims=[TDim], dtype=float_var) + + @pytest.mark.parametrize( + "types", + [ + (float_var, float64_type), + (float_var, mixed_var), + (ts.FieldType(dims=[TDim], dtype=float_var), float64_type), + ( + ts.FieldType(dims=[TDim], dtype=float_var), + ts.FieldType(dims=[TDim], dtype=float64_type), + ), + ], + ) + def test_promote_mixing_error(self, types): + with pytest.raises(ValueError, match="type variable"): + type_info.promote(*types) + + +class TestBindTypeVars: + def test_bind_from_field(self): + binding = type_info.bind_type_vars( + [ts.FieldType(dims=[TDim], dtype=float_var)], + [ts.FieldType(dims=[TDim], dtype=float32_type)], + ) + assert binding == {"T": float32_type} + + def test_bind_from_scalar_and_nested(self): + binding = type_info.bind_type_vars( + [ts.TupleType(types=[float_var, ts.FieldType(dims=[TDim], dtype=float_var)])], + [ts.TupleType(types=[float64_type, ts.FieldType(dims=[TDim], dtype=float64_type)])], + ) + assert binding == {"T": float64_type} + + def test_concrete_params_dont_bind(self): + assert ( + type_info.bind_type_vars( + [ts.FieldType(dims=[TDim], dtype=float64_type)], + [ts.FieldType(dims=[TDim], dtype=float32_type)], + ) + == {} + ) + + def test_inconsistent_binding(self): + with pytest.raises(ValueError, match="bound inconsistently"): + type_info.bind_type_vars( + [ + ts.FieldType(dims=[TDim], dtype=float_var), + ts.FieldType(dims=[TDim], dtype=float_var), + ], + [ + ts.FieldType(dims=[TDim], dtype=float32_type), + ts.FieldType(dims=[TDim], dtype=float64_type), + ], + ) + + def test_constraint_violation(self): + with pytest.raises(ValueError, match="constraints"): + type_info.bind_type_vars( + [ts.FieldType(dims=[TDim], dtype=float_var)], + [ts.FieldType(dims=[TDim], dtype=int32_type)], + ) + + +class TestSubstituteTypeVars: + def test_substitute(self): + generic = ts.TupleType( + types=[float_var, ts.FieldType(dims=[TDim], dtype=float_var), int32_type] + ) + substituted = type_info.substitute_type_vars(generic, {"T": float32_type}) + assert substituted == ts.TupleType( + types=[float32_type, ts.FieldType(dims=[TDim], dtype=float32_type), int32_type] + ) + assert not type_info.is_generic(substituted) + + def test_unbound_vars_are_kept(self): + generic = ts.FieldType(dims=[TDim], dtype=float_var) + assert type_info.substitute_type_vars(generic, {"S": float32_type}) == generic + + def test_concrete_is_returned_unchanged(self): + concrete = ts.FieldType(dims=[TDim], dtype=float64_type) + assert type_info.substitute_type_vars(concrete, {"T": float32_type}) == concrete + + def test_substitute_function_type(self): + func_type = ts.FunctionType( + pos_only_args=[ts.FieldType(dims=[TDim], dtype=float_var)], + pos_or_kw_args={"a": float_var}, + kw_only_args={}, + returns=ts.FieldType(dims=[TDim], dtype=float_var), + ) + substituted = type_info.substitute_type_vars(func_type, {"T": float64_type}) + assert substituted.pos_only_args[0] == ts.FieldType(dims=[TDim], dtype=float64_type) + assert substituted.pos_or_kw_args["a"] == float64_type + assert substituted.returns == ts.FieldType(dims=[TDim], dtype=float64_type) + + @pytest.mark.parametrize("func_type,args,kwargs,expected,return_type", callable_type_info_cases()) def test_accept_args( func_type: ts.TypeSpec,