Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/gt4py/next/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ def as_field(
origin: Mapping[common.Dimension, int] | None = None,
) -> nd_array_field.NdArrayField:
"""Create a `Field` from an array-like object. See :func:`as_field` for details."""
if (
not isinstance(domain, str)
and isinstance(domain, Sequence)
and any(isinstance(d, str) for d in domain)
):
raise TypeError(
f"Invalid domain {domain!r}: dimensions must be 'Dimension' objects, not strings."
)
if isinstance(domain, Sequence) and all(
isinstance(dim, common.Dimension) for dim in domain
):
Expand Down
8 changes: 6 additions & 2 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def from_array(
if dtype is not None:
assert array.dtype.type == core_defs.dtype(dtype).scalar_type

assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES)
if not issubclass(array.dtype.type, core_defs.SCALAR_TYPES):
raise ValueError(
f"Cannot construct 'Field' from array with unsupported dtype '{array.dtype}'."
)

assert all(isinstance(d, common.Dimension) for d in domain.dims), domain
assert len(domain) == array.ndim
Expand Down Expand Up @@ -498,7 +501,8 @@ def from_array( # type: ignore[override]
if dtype is not None:
assert array.dtype.type == core_defs.dtype(dtype).scalar_type

assert issubclass(array.dtype.type, core_defs.INTEGRAL_TYPES)
if not issubclass(array.dtype.type, core_defs.INTEGRAL_TYPES):
raise ValueError(f"Neighbor tables must have an integral dtype, got '{array.dtype}'.")

assert all(isinstance(d, common.Dimension) for d in domain.dims), domain
assert len(domain) == array.ndim
Expand Down
21 changes: 16 additions & 5 deletions src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,22 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) ->
# without checking if the types are consistent. However, these errors are caught in linting if enabled.
container_extracted_out = arguments.extract(out)
assert xtyping.is_maybe_nested_in_tuple_of(container_extracted_out, common.MutableField) # type: ignore[type-abstract] # MutableField is abstract/generic
out_domain = (
utils.tree_map(common.domain)(domain)
if domain is not None
else _get_out_domain(container_extracted_out)
)
try:
out_domain = (
utils.tree_map(common.domain)(domain)
if domain is not None
else _get_out_domain(container_extracted_out)
)
except ValueError as err:
raise errors.DSLTypeError(
None,
f"Invalid 'domain' argument: {err}",
hints=(
"Pass a mapping from dimensions to ranges, e.g. "
"'domain={IDim: (0, 10)}' (or a tuple thereof matching a tuple "
"'out' argument).",
),
) from err

new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain)

Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MissingParameterAnnotationError,
UndefinedSymbolError,
UnsupportedPythonFeatureError,
did_you_mean,
)


Expand All @@ -32,4 +33,5 @@
"MissingParameterAnnotationError",
"UndefinedSymbolError",
"UnsupportedPythonFeatureError",
"did_you_mean",
]
13 changes: 9 additions & 4 deletions src/gt4py/next/errors/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def message(self) -> str:
return self.args[0]


def _did_you_mean(name: str, candidates: Iterable[str]) -> list[str]:
def did_you_mean(name: str, candidates: Iterable[str]) -> list[str]:
"""Produce a 'Did you mean ...?' hint if `name` closely matches any candidate."""
# Never suggest the name the user already wrote (it can appear among the
# candidates as a same-named symbol from another SSA generation).
Expand Down Expand Up @@ -157,7 +157,7 @@ def __init__(
location,
f"Undeclared symbol '{name}'.",
label="not defined at this point",
hints=_did_you_mean(name, candidates),
hints=did_you_mean(name, candidates),
)
self.sym_name = name

Expand All @@ -183,8 +183,8 @@ def __init__(self, location: Optional[SourceLocation], arg_name: str, is_kwarg:


class DSLTypeError(DSLError):
def __init__(self, location: Optional[SourceLocation], message: str) -> None:
super().__init__(location, message)
def __init__(self, location: Optional[SourceLocation], message: str, **kwargs: Any) -> None:
super().__init__(location, message, **kwargs)


class MissingParameterAnnotationError(DSLTypeError):
Expand All @@ -203,6 +203,11 @@ def __init__(self, location: Optional[SourceLocation], param_name: str, type_: A
super().__init__(
location, f"Parameter '{param_name}' has invalid type annotation '{type_}'."
)
self.hints = [
"Annotate parameters with a GT4Py type: a field (e.g. "
"'gtx.Field[[IDim], gtx.float64]'), a scalar (e.g. 'gtx.float64') or a "
"tuple of these."
]
self.param_name = param_name
self.annotated_type = type_

Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/next/ffront/ast_passes/simple_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ def generic_visit(self, node: ast.AST) -> Iterator[ast.AST]: # type: ignore[ove
"""Override generic visit to deal with generators."""
for field, old_value in ast.iter_fields(node):
if isinstance(old_value, list):
new_values = [i for j in old_value for i in self.visit(j)]
# fields may contain non-AST values, e.g. the names of a `global`
# statement; those must be kept as-is, not visited.
new_values = [
i
for j in old_value
for i in (self.visit(j) if isinstance(j, ast.AST) else (j,))
]
old_value[:] = new_values
elif isinstance(old_value, ast.AST):
new_node, *_ = list(self.visit(old_value)) or (None,)
Expand Down
97 changes: 92 additions & 5 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import types
import typing
import warnings
from collections.abc import Callable
from collections.abc import Callable, Mapping
from typing import Any, Generic, Optional, Sequence, TypeAlias

from gt4py import eve
Expand Down Expand Up @@ -59,6 +59,34 @@
)


def _validate_offset_provider(offset_provider: Any) -> None:
if not isinstance(offset_provider, Mapping):
raise errors.DSLTypeError(
None,
"'offset_provider' must be a mapping from offset names to dimensions or "
f"connectivities, got '{type(offset_provider).__name__}'.",
)


def _type_of_argument(value: Any, description: str, function_name: str) -> ts.TypeSpec:
"""Translate a call argument to its GT4Py type, reporting failures as :class:`errors.DSLError`."""
try:
return type_translation.from_value(value)
except Exception as err:
hints: tuple[str, ...] = ()
if hasattr(value, "__array__") or hasattr(value, "__cuda_array_interface__"):
hints = (
"Wrap raw arrays in a GT4Py field before passing them to a program or "
"operator, e.g. 'gtx.as_field([IDim, JDim], array)'.",
)
raise errors.DSLTypeError(
None,
f"In call to '{function_name}': {description} has a type not supported by "
f"GT4Py: '{type(value).__name__}'.",
hints=hints,
) from err


@hook_machinery.context_hook
def program_call_context(
program: Program,
Expand Down Expand Up @@ -98,6 +126,10 @@ class _CompilableGTEntryPointMixin(Generic[ffront_stages.DSLDefinitionT]):
def __gt_type__(self) -> ts.CallableType: ...

def with_backend(self, backend: next_backend.Backend | None) -> Self:
if backend is not None and not isinstance(backend, next_backend.Backend):
raise TypeError(
f"Expected a 'gt4py.next' backend or 'None', got '{type(backend).__name__}'."
)
return dataclasses.replace(self, backend=backend)

def with_compilation_options(
Expand Down Expand Up @@ -373,6 +405,7 @@ def __call__(
) -> None:
if offset_provider is None:
offset_provider = {}
_validate_offset_provider(offset_provider)
enable_jit = self.compilation_options.enable_jit if enable_jit is None else enable_jit

with program_call_context(
Expand All @@ -386,8 +419,14 @@ def __call__(
# TODO: remove or make dependency on self.past_stage optional
past_process_args._validate_args(
self.past_stage.past_node,
arg_types=[type_translation.from_value(arg) for arg in args],
kwarg_types={k: type_translation.from_value(v) for k, v in kwargs.items()},
arg_types=[
_type_of_argument(arg, f"argument {i + 1}", self.__name__)
for i, arg in enumerate(args)
],
kwarg_types={
k: _type_of_argument(v, f"keyword argument '{k}'", self.__name__)
for k, v in kwargs.items()
},
)

if self.backend is not None:
Expand Down Expand Up @@ -445,8 +484,13 @@ def __call__(
)
)

arg_types = [type_translation.from_value(arg) for arg in args]
kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()}
arg_types = [
_type_of_argument(arg, f"argument {i + 1}", self.__name__) for i, arg in enumerate(args)
]
kwarg_types = {
k: _type_of_argument(v, f"keyword argument '{k}'", self.__name__)
for k, v in kwargs.items()
}

try:
# This error is also catched using `accepts_args`, but we do it manually here to give
Expand Down Expand Up @@ -643,7 +687,50 @@ def __gt_gtir__(self) -> itir.FunctionDefinition:
def __gt_closure_vars__(self) -> dict[str, Any]:
return self.foast_stage.closure_vars

def _validate_call_args(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
"""
Validate the arguments of a direct (outside-a-program) call.

Mirrors the checks programs run on their arguments; without it, invalid
arguments surface as crashes from deep inside the embedded execution
or the compiled backend.
"""
if "out" not in kwargs:
raise errors.MissingArgumentError(None, "out", True)
if "offset_provider" in kwargs:
_validate_offset_provider(kwargs["offset_provider"])
operator_type = self.__gt_type__()
name = self.__name__
arg_types = [
_type_of_argument(arg, f"argument {i + 1}", name) for i, arg in enumerate(args)
]
kwarg_types = {
k: _type_of_argument(v, f"keyword argument '{k}'", name)
for k, v in kwargs.items()
if k not in ("out", "offset_provider", "domain")
}
try:
type_info.accepts_args(
operator_type, with_args=arg_types, with_kwargs=kwarg_types, raise_exception=True
)
except ValueError as err:
raise errors.DSLError(
None, f"Invalid argument types in call to '{name}'.\n{err}"
) from err
out_type = _type_of_argument(kwargs["out"], "keyword argument 'out'", name)
expected_out_type = type_info.return_type(
operator_type, with_args=arg_types, with_kwargs=kwarg_types
)
if expected_out_type != out_type:
raise errors.DSLTypeError(
None,
f"In call to '{name}': expected keyword argument 'out' to be of type "
f"'{expected_out_type}', got '{out_type}'.",
)

def __call__(self, *args: Any, enable_jit: bool | None = None, **kwargs: Any) -> Any:
if __debug__ and not next_embedded.context.within_valid_context():
self._validate_call_args(args, kwargs)
if not next_embedded.context.within_valid_context() and self.backend is not None:
# non embedded execution
offset_provider = {**kwargs.pop("offset_provider", {})}
Expand Down
47 changes: 45 additions & 2 deletions src/gt4py/next/ffront/dialect_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

import ast
import textwrap
import typing
from dataclasses import dataclass
from typing import Callable, ClassVar, Collection

from gt4py.eve.concepts import SourceLocation
from gt4py.eve.extended_typing import Any, Generic, TypeVar
from gt4py.next import errors
from gt4py.next.ffront import source_utils
from gt4py.next.ffront.ast_passes.fix_missing_locations import FixMissingLocations
from gt4py.next.ffront.ast_passes.remove_docstrings import RemoveDocstrings
from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function
Expand Down Expand Up @@ -67,6 +67,20 @@
ast.ClassDef: ("class definition", ()),
ast.JoinedStr: ("f-string", ("Strings cannot be computed inside GT4Py functions.",)),
ast.Match: ("'match' statement", ("Use 'if'/'elif' chains or 'where' instead.",)),
ast.Global: (
"'global' statement",
(
"Variables from the surrounding scope are read-only inside GT4Py functions; "
"pass values as parameters and return results instead.",
),
),
ast.Nonlocal: (
"'nonlocal' statement",
(
"Variables from the surrounding scope are read-only inside GT4Py functions; "
"pass values as parameters and return results instead.",
),
),
}


Expand Down Expand Up @@ -136,7 +150,7 @@ def apply(
def apply_to_function(cls, function: Callable) -> DialectRootT:
src = SourceDefinition.from_function(function)
closure_vars = get_closure_vars_from_function(function)
annotations = typing.get_type_hints(function)
annotations = source_utils.get_type_hints_from_function(function, src)
return cls.apply(src, closure_vars, annotations)

@classmethod
Expand All @@ -159,6 +173,35 @@ def generic_visit(self, node: ast.AST) -> None:
hints=hints,
)

def _validate_signature(self, node: ast.arguments) -> None:
"""Reject parameter-list features that have no DSL semantics with a clear diagnostic."""
if node.posonlyargs:
raise errors.UnsupportedPythonFeatureError(
self.get_location(node.posonlyargs[0]),
"positional-only parameters",
hints=("Remove the '/' marker from the parameter list.",),
)
if node.kwonlyargs:
raise errors.UnsupportedPythonFeatureError(
self.get_location(node.kwonlyargs[0]),
"keyword-only parameters",
hints=("Remove the '*' marker from the parameter list.",),
)
if node.vararg is not None:
raise errors.UnsupportedPythonFeatureError(
self.get_location(node.vararg), "'*args' parameters"
)
if node.kwarg is not None:
raise errors.UnsupportedPythonFeatureError(
self.get_location(node.kwarg), "'**kwargs' parameters"
)
if defaults := [d for d in [*node.defaults, *node.kw_defaults] if d is not None]:
raise errors.UnsupportedPythonFeatureError(
self.get_location(defaults[0]),
"default values for parameters",
hints=("Pass the value explicitly at every call site instead.",),
)

def _check_not_a_reserved_name(self, name: str, location: SourceLocation) -> None:
if name in self.reserved_names:
raise errors.DSLError(
Expand Down
Loading