diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 838f2c5204..0ec13a76ef 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -177,6 +177,14 @@ def as_field( origin: Mapping[common.Dimension, int] | None = None, ) -> nd_array_field.NdArrayField: """Create a `Field` from an array-like object. See :func:`as_field` for details.""" + if ( + not isinstance(domain, str) + and isinstance(domain, Sequence) + and any(isinstance(d, str) for d in domain) + ): + raise TypeError( + f"Invalid domain {domain!r}: dimensions must be 'Dimension' objects, not strings." + ) if isinstance(domain, Sequence) and all( isinstance(dim, common.Dimension) for dim in domain ): diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 69bb89da3a..2c80f4893c 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -134,7 +134,10 @@ def from_array( if dtype is not None: assert array.dtype.type == core_defs.dtype(dtype).scalar_type - assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) + if not issubclass(array.dtype.type, core_defs.SCALAR_TYPES): + raise ValueError( + f"Cannot construct 'Field' from array with unsupported dtype '{array.dtype}'." + ) assert all(isinstance(d, common.Dimension) for d in domain.dims), domain assert len(domain) == array.ndim @@ -498,7 +501,8 @@ def from_array( # type: ignore[override] if dtype is not None: assert array.dtype.type == core_defs.dtype(dtype).scalar_type - assert issubclass(array.dtype.type, core_defs.INTEGRAL_TYPES) + if not issubclass(array.dtype.type, core_defs.INTEGRAL_TYPES): + raise ValueError(f"Neighbor tables must have an integral dtype, got '{array.dtype}'.") assert all(isinstance(d, common.Dimension) for d in domain.dims), domain assert len(domain) == array.ndim diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 072acb7f8d..cace6771c4 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -118,11 +118,22 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> # without checking if the types are consistent. However, these errors are caught in linting if enabled. container_extracted_out = arguments.extract(out) assert xtyping.is_maybe_nested_in_tuple_of(container_extracted_out, common.MutableField) # type: ignore[type-abstract] # MutableField is abstract/generic - out_domain = ( - utils.tree_map(common.domain)(domain) - if domain is not None - else _get_out_domain(container_extracted_out) - ) + try: + out_domain = ( + utils.tree_map(common.domain)(domain) + if domain is not None + else _get_out_domain(container_extracted_out) + ) + except ValueError as err: + raise errors.DSLTypeError( + None, + f"Invalid 'domain' argument: {err}", + hints=( + "Pass a mapping from dimensions to ranges, e.g. " + "'domain={IDim: (0, 10)}' (or a tuple thereof matching a tuple " + "'out' argument).", + ), + ) from err new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 87237a0a75..21b351ad48 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -19,6 +19,7 @@ MissingParameterAnnotationError, UndefinedSymbolError, UnsupportedPythonFeatureError, + did_you_mean, ) @@ -32,4 +33,5 @@ "MissingParameterAnnotationError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", + "did_you_mean", ] diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index fb7e7eaa7d..ab2e83784f 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -36,7 +36,7 @@ def message(self) -> str: return self.args[0] -def _did_you_mean(name: str, candidates: Iterable[str]) -> list[str]: +def did_you_mean(name: str, candidates: Iterable[str]) -> list[str]: """Produce a 'Did you mean ...?' hint if `name` closely matches any candidate.""" # Never suggest the name the user already wrote (it can appear among the # candidates as a same-named symbol from another SSA generation). @@ -157,7 +157,7 @@ def __init__( location, f"Undeclared symbol '{name}'.", label="not defined at this point", - hints=_did_you_mean(name, candidates), + hints=did_you_mean(name, candidates), ) self.sym_name = name @@ -183,8 +183,8 @@ def __init__(self, location: Optional[SourceLocation], arg_name: str, is_kwarg: class DSLTypeError(DSLError): - def __init__(self, location: Optional[SourceLocation], message: str) -> None: - super().__init__(location, message) + def __init__(self, location: Optional[SourceLocation], message: str, **kwargs: Any) -> None: + super().__init__(location, message, **kwargs) class MissingParameterAnnotationError(DSLTypeError): @@ -203,6 +203,11 @@ def __init__(self, location: Optional[SourceLocation], param_name: str, type_: A super().__init__( location, f"Parameter '{param_name}' has invalid type annotation '{type_}'." ) + self.hints = [ + "Annotate parameters with a GT4Py type: a field (e.g. " + "'gtx.Field[[IDim], gtx.float64]'), a scalar (e.g. 'gtx.float64') or a " + "tuple of these." + ] self.param_name = param_name self.annotated_type = type_ diff --git a/src/gt4py/next/ffront/ast_passes/simple_assign.py b/src/gt4py/next/ffront/ast_passes/simple_assign.py index a251490846..3813d78516 100644 --- a/src/gt4py/next/ffront/ast_passes/simple_assign.py +++ b/src/gt4py/next/ffront/ast_passes/simple_assign.py @@ -30,7 +30,13 @@ def generic_visit(self, node: ast.AST) -> Iterator[ast.AST]: # type: ignore[ove """Override generic visit to deal with generators.""" for field, old_value in ast.iter_fields(node): if isinstance(old_value, list): - new_values = [i for j in old_value for i in self.visit(j)] + # fields may contain non-AST values, e.g. the names of a `global` + # statement; those must be kept as-is, not visited. + new_values = [ + i + for j in old_value + for i in (self.visit(j) if isinstance(j, ast.AST) else (j,)) + ] old_value[:] = new_values elif isinstance(old_value, ast.AST): new_node, *_ = list(self.visit(old_value)) or (None,) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index c749fcec01..0736cc60f9 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -19,7 +19,7 @@ import types import typing import warnings -from collections.abc import Callable +from collections.abc import Callable, Mapping from typing import Any, Generic, Optional, Sequence, TypeAlias from gt4py import eve @@ -59,6 +59,34 @@ ) +def _validate_offset_provider(offset_provider: Any) -> None: + if not isinstance(offset_provider, Mapping): + raise errors.DSLTypeError( + None, + "'offset_provider' must be a mapping from offset names to dimensions or " + f"connectivities, got '{type(offset_provider).__name__}'.", + ) + + +def _type_of_argument(value: Any, description: str, function_name: str) -> ts.TypeSpec: + """Translate a call argument to its GT4Py type, reporting failures as :class:`errors.DSLError`.""" + try: + return type_translation.from_value(value) + except Exception as err: + hints: tuple[str, ...] = () + if hasattr(value, "__array__") or hasattr(value, "__cuda_array_interface__"): + hints = ( + "Wrap raw arrays in a GT4Py field before passing them to a program or " + "operator, e.g. 'gtx.as_field([IDim, JDim], array)'.", + ) + raise errors.DSLTypeError( + None, + f"In call to '{function_name}': {description} has a type not supported by " + f"GT4Py: '{type(value).__name__}'.", + hints=hints, + ) from err + + @hook_machinery.context_hook def program_call_context( program: Program, @@ -98,6 +126,10 @@ class _CompilableGTEntryPointMixin(Generic[ffront_stages.DSLDefinitionT]): def __gt_type__(self) -> ts.CallableType: ... def with_backend(self, backend: next_backend.Backend | None) -> Self: + if backend is not None and not isinstance(backend, next_backend.Backend): + raise TypeError( + f"Expected a 'gt4py.next' backend or 'None', got '{type(backend).__name__}'." + ) return dataclasses.replace(self, backend=backend) def with_compilation_options( @@ -373,6 +405,7 @@ def __call__( ) -> None: if offset_provider is None: offset_provider = {} + _validate_offset_provider(offset_provider) enable_jit = self.compilation_options.enable_jit if enable_jit is None else enable_jit with program_call_context( @@ -386,8 +419,14 @@ def __call__( # TODO: remove or make dependency on self.past_stage optional past_process_args._validate_args( self.past_stage.past_node, - arg_types=[type_translation.from_value(arg) for arg in args], - kwarg_types={k: type_translation.from_value(v) for k, v in kwargs.items()}, + arg_types=[ + _type_of_argument(arg, f"argument {i + 1}", self.__name__) + for i, arg in enumerate(args) + ], + kwarg_types={ + k: _type_of_argument(v, f"keyword argument '{k}'", self.__name__) + for k, v in kwargs.items() + }, ) if self.backend is not None: @@ -445,8 +484,13 @@ def __call__( ) ) - arg_types = [type_translation.from_value(arg) for arg in args] - kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()} + arg_types = [ + _type_of_argument(arg, f"argument {i + 1}", self.__name__) for i, arg in enumerate(args) + ] + kwarg_types = { + k: _type_of_argument(v, f"keyword argument '{k}'", self.__name__) + for k, v in kwargs.items() + } try: # This error is also catched using `accepts_args`, but we do it manually here to give @@ -643,7 +687,50 @@ def __gt_gtir__(self) -> itir.FunctionDefinition: def __gt_closure_vars__(self) -> dict[str, Any]: return self.foast_stage.closure_vars + def _validate_call_args(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None: + """ + Validate the arguments of a direct (outside-a-program) call. + + Mirrors the checks programs run on their arguments; without it, invalid + arguments surface as crashes from deep inside the embedded execution + or the compiled backend. + """ + if "out" not in kwargs: + raise errors.MissingArgumentError(None, "out", True) + if "offset_provider" in kwargs: + _validate_offset_provider(kwargs["offset_provider"]) + operator_type = self.__gt_type__() + name = self.__name__ + arg_types = [ + _type_of_argument(arg, f"argument {i + 1}", name) for i, arg in enumerate(args) + ] + kwarg_types = { + k: _type_of_argument(v, f"keyword argument '{k}'", name) + for k, v in kwargs.items() + if k not in ("out", "offset_provider", "domain") + } + try: + type_info.accepts_args( + operator_type, with_args=arg_types, with_kwargs=kwarg_types, raise_exception=True + ) + except ValueError as err: + raise errors.DSLError( + None, f"Invalid argument types in call to '{name}'.\n{err}" + ) from err + out_type = _type_of_argument(kwargs["out"], "keyword argument 'out'", name) + expected_out_type = type_info.return_type( + operator_type, with_args=arg_types, with_kwargs=kwarg_types + ) + if expected_out_type != out_type: + raise errors.DSLTypeError( + None, + f"In call to '{name}': expected keyword argument 'out' to be of type " + f"'{expected_out_type}', got '{out_type}'.", + ) + def __call__(self, *args: Any, enable_jit: bool | None = None, **kwargs: Any) -> Any: + if __debug__ and not next_embedded.context.within_valid_context(): + self._validate_call_args(args, kwargs) if not next_embedded.context.within_valid_context() and self.backend is not None: # non embedded execution offset_provider = {**kwargs.pop("offset_provider", {})} diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index aa4aae7b49..41c91e4ab6 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -8,13 +8,13 @@ import ast import textwrap -import typing from dataclasses import dataclass from typing import Callable, ClassVar, Collection from gt4py.eve.concepts import SourceLocation from gt4py.eve.extended_typing import Any, Generic, TypeVar from gt4py.next import errors +from gt4py.next.ffront import source_utils from gt4py.next.ffront.ast_passes.fix_missing_locations import FixMissingLocations from gt4py.next.ffront.ast_passes.remove_docstrings import RemoveDocstrings from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function @@ -67,6 +67,20 @@ ast.ClassDef: ("class definition", ()), ast.JoinedStr: ("f-string", ("Strings cannot be computed inside GT4Py functions.",)), ast.Match: ("'match' statement", ("Use 'if'/'elif' chains or 'where' instead.",)), + ast.Global: ( + "'global' statement", + ( + "Variables from the surrounding scope are read-only inside GT4Py functions; " + "pass values as parameters and return results instead.", + ), + ), + ast.Nonlocal: ( + "'nonlocal' statement", + ( + "Variables from the surrounding scope are read-only inside GT4Py functions; " + "pass values as parameters and return results instead.", + ), + ), } @@ -136,7 +150,7 @@ def apply( def apply_to_function(cls, function: Callable) -> DialectRootT: src = SourceDefinition.from_function(function) closure_vars = get_closure_vars_from_function(function) - annotations = typing.get_type_hints(function) + annotations = source_utils.get_type_hints_from_function(function, src) return cls.apply(src, closure_vars, annotations) @classmethod @@ -159,6 +173,35 @@ def generic_visit(self, node: ast.AST) -> None: hints=hints, ) + def _validate_signature(self, node: ast.arguments) -> None: + """Reject parameter-list features that have no DSL semantics with a clear diagnostic.""" + if node.posonlyargs: + raise errors.UnsupportedPythonFeatureError( + self.get_location(node.posonlyargs[0]), + "positional-only parameters", + hints=("Remove the '/' marker from the parameter list.",), + ) + if node.kwonlyargs: + raise errors.UnsupportedPythonFeatureError( + self.get_location(node.kwonlyargs[0]), + "keyword-only parameters", + hints=("Remove the '*' marker from the parameter list.",), + ) + if node.vararg is not None: + raise errors.UnsupportedPythonFeatureError( + self.get_location(node.vararg), "'*args' parameters" + ) + if node.kwarg is not None: + raise errors.UnsupportedPythonFeatureError( + self.get_location(node.kwarg), "'**kwargs' parameters" + ) + if defaults := [d for d in [*node.defaults, *node.kw_defaults] if d is not None]: + raise errors.UnsupportedPythonFeatureError( + self.get_location(defaults[0]), + "default values for parameters", + hints=("Pass the value explicitly at every call site instead.",), + ) + def _check_not_a_reserved_name(self, name: str, location: SourceLocation) -> None: if name in self.reserved_names: raise errors.DSLError( diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 8531ccacf9..ee2154278c 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -470,12 +470,24 @@ def __gt_type__(self) -> ts.OffsetType: def __getitem__(self, offset: int) -> common.Connectivity: """Serve as a connectivity factory.""" - from gt4py.next import embedded # avoid circular import + from gt4py.next import embedded, errors # avoid circular import assert isinstance(self.value, str) current_offset_provider = embedded.context.get_offset_provider(None) assert current_offset_provider is not None - offset_definition = common.get_offset(current_offset_provider, self.value) + try: + offset_definition = common.get_offset(current_offset_provider, self.value) + except KeyError as err: + raise errors.DSLError( + None, + f"Offset '{self.value}' not found in the offset provider.", + hints=( + f"Add an entry for '{self.value}' to the 'offset_provider' argument " + f"of the call: the dimension itself for a Cartesian offset (e.g. " + f"'offset_provider={{\"{self.value}\": {self.source.value}}}'), or a " + "connectivity for an unstructured offset.", + ), + ) from err connectivity: common.Connectivity if isinstance(offset_definition, common.Dimension): @@ -485,25 +497,47 @@ def __getitem__(self, offset: int) -> common.Connectivity: named_index = common.NamedIndex(self.target[-1], offset) connectivity = offset_definition[named_index] else: - raise NotImplementedError() + raise errors.DSLTypeError( + None, + f"Invalid offset provider entry for '{self.value}': expected a " + f"'Dimension' or a connectivity, got '{type(offset_definition).__name__}'.", + ) return connectivity def as_connectivity_field(self) -> common.Connectivity: """Convert to connectivity field using the offset providers in current embedded execution context.""" - from gt4py.next import embedded # avoid circular import + from gt4py.next import embedded, errors # avoid circular import assert isinstance(self.value, str) current_offset_provider = embedded.context.get_offset_provider(None) assert current_offset_provider is not None - offset_definition = common.get_offset(current_offset_provider, self.value) + try: + offset_definition = common.get_offset(current_offset_provider, self.value) + except KeyError as err: + raise errors.DSLError( + None, + f"Offset '{self.value}' not found in the offset provider.", + hints=( + f"Add an entry for '{self.value}' to the 'offset_provider' argument " + "of the call.", + ), + ) from err cache_key = id(offset_definition) if (connectivity := self._cache.get(cache_key, None)) is None: if isinstance(offset_definition, common.Connectivity): connectivity = offset_definition else: - raise NotImplementedError() + raise errors.DSLTypeError( + None, + f"Invalid offset provider entry for '{self.value}': expected a " + f"connectivity, got '{type(offset_definition).__name__}'.", + hints=( + "Construct connectivities from neighbor tables with " + "'gtx.as_connectivity(...)'.", + ), + ) self._cache[cache_key] = connectivity diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 1b090489ff..377b08d1dc 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -429,13 +429,64 @@ def visit_Symbol( return new_node return node + def _invalid_attribute_error( + self, node: foast.Attribute, value_type: ts.TypeSpec + ) -> errors.DSLError: + notes: list[str] = [] + hints: list[str] = [] + if isinstance(value_type, ts.NamedCollectionType): + hints = errors.did_you_mean(node.attr, value_type.keys) + elif isinstance(value_type, (ts.FieldType, ts.ScalarType)): + notes = [ + "GT4Py fields and scalars do not provide NumPy-style methods or attributes; " + "use the GT4Py built-in functions instead." + ] + return errors.DSLError( + node.location, + f"Type '{value_type}' has no attribute '{node.attr}'.", + label="attribute does not exist", + notes=notes, + hints=hints, + ) + def visit_Attribute(self, node: foast.Attribute, **kwargs: Any) -> foast.Attribute: new_value = self.visit(node.value, **kwargs) + # Attribute access is only valid on namespaces (e.g. modules) and named + # collections; anything else (e.g. NumPy-style attributes of a field) + # is a user error and must not leak `getattr` failures. + try: + new_type = getattr(new_value.type, node.attr) + except AttributeError as e: + if isinstance(new_value.type, type_translation.NamespaceProxy): + # `getattr` on the underlying Python object failed; its message + # ("module 'numpy' has no attribute ...") is the best we have. + raise errors.DSLError(node.location, f"{e}.") from e + raise self._invalid_attribute_error(node, new_value.type) from e + except ValueError as e: + # the attribute exists on the Python object behind a namespace, but + # its value has no representation in the DSL type system + hints: list[str] = [] + if isinstance(new_value.type, type_translation.NamespaceProxy) and callable( + getattr(new_value.type._object, node.attr, None) + ): + hints = [ + "External Python functions cannot be called inside GT4Py functions. " + "Use the corresponding GT4Py built-in (e.g. 'sin', 'sqrt', 'exp', " + "'maximum' from 'gt4py.next') if one exists." + ] + raise errors.DSLError( + node.location, + f"'{node.attr}' cannot be used inside a GT4Py function.", + label="value not representable in the GT4Py type system", + hints=hints, + ) from e + if not isinstance(new_type, ts.TypeSpec): + raise self._invalid_attribute_error(node, new_value.type) return foast.Attribute( value=new_value, attr=node.attr, location=node.location, - type=getattr(new_value.type, node.attr), + type=new_type, ) def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscript: @@ -452,6 +503,12 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri node.location, f"Tuples need to be indexed with literal integers, got '{node.index}'.", ) from ex + if not -len(types) <= index < len(types): + raise errors.DSLError( + node.location, + f"Tuple index {index} is out of range.", + label=f"this tuple has {len(types)} element{'s' * (len(types) != 1)}", + ) new_type = types[index] case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: @@ -470,7 +527,23 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri ) new_type = new_value.type case ts.FieldType(dims=dims, dtype=dtype): - # e.g. `field[LocalDim(42)]` + # e.g. `field[LocalDim(42)]`: the only valid field subscript is a + # local-dimension index, which removes that dimension + if not isinstance(new_index.type, ts.IndexType): + raise errors.DSLError( + node.location, + f"Fields cannot be indexed with '{new_index.type}'.", + label="invalid field index", + notes=( + "GT4Py expressions operate on whole fields; accessing single " + "grid points by absolute position is not possible.", + ), + hints=( + "To access neighboring grid points, apply a field offset, " + "e.g. 'field(Ioff[1])'. Entries of a local (neighbor) dimension " + "can be selected with 'field[LocalDim(0)]'.", + ), + ) new_type = ts.FieldType( dims=[d for d in dims if d != new_index.type.dim], dtype=dtype, @@ -743,7 +816,21 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: elif isinstance(new_func.type, ts.FieldType): pass elif isinstance(new_func.type, ts.DimensionType): - assert new_func.type.dim.kind == DimensionKind.LOCAL + if new_func.type.dim.kind != DimensionKind.LOCAL: + raise errors.DSLError( + node.location, + f"'{new_func.type.dim.value}' is not a local (neighbor) dimension and " + "cannot be used to construct an index.", + label=f"this dimension is of kind '{new_func.type.dim.kind}'", + notes=( + "Only indices of a local dimension (e.g. 'V2EDim(0)') can be " + "constructed this way, to select one neighbor entry of a field.", + ), + hints=( + "To access neighboring grid points along a dimension, apply a " + "field offset instead, e.g. 'field(Ioff[1])'.", + ), + ) return foast.Call( func=new_func, args=new_args, @@ -751,6 +838,16 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: location=node.location, type=ts.IndexType(dim=new_func.type.dim), ) + elif isinstance(new_func.type, ts_ffront.ProgramType): + raise errors.DSLError( + node.location, + "Programs cannot be called inside field operators.", + label="this is a '@program'", + hints=( + "Call the field operators directly; programs can only be called " + "from plain Python.", + ), + ) else: raise errors.DSLError( node.location, diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 1497d0552e..a9c7166cb8 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -9,8 +9,6 @@ from __future__ import annotations import ast -import textwrap -import typing from typing import Any, Type import gt4py.eve as eve @@ -71,14 +69,21 @@ def func_to_foast(inp: DSLFieldOperatorDef) -> FOASTOperatorDef: """ source_def = source_utils.SourceDefinition.from_function(inp.definition) closure_vars = source_utils.get_closure_vars_from_function(inp.definition) - annotations = typing.get_type_hints(inp.definition) + annotations = source_utils.get_type_hints_from_function(inp.definition, source_def) try: foast_definition_node = FieldOperatorParser.apply(source_def, closure_vars, annotations) loc = foast_definition_node.location - operator_attribute_nodes = { - key: foast.Constant(value=value, type=type_translation.from_value(value), location=loc) - for key, value in inp.attributes.items() - } + operator_attribute_nodes = {} + for key, value in inp.attributes.items(): + try: + type_ = type_translation.from_value(value) + except Exception as e: + raise errors.DSLTypeError( + loc, + f"Argument '{key}' to operator '{foast_definition_node.id}' has a type " + f"not supported by GT4Py: '{type(value).__name__}'.", + ) from e + operator_attribute_nodes[key] = foast.Constant(value=value, type=type_, location=loc) untyped_foast_node = inp.node_class( id=foast_definition_node.id, definition=foast_definition_node, @@ -227,13 +232,19 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs: Any) -> foast.Funct ) def visit_arguments(self, node: ast.arguments) -> list[foast.DataSymbol]: + self._validate_signature(node) return [self.visit_arg(arg) for arg in node.args] def visit_arg(self, node: ast.arg) -> foast.DataSymbol: loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: raise errors.MissingParameterAnnotationError(loc, node.arg) - new_type = type_translation.from_type_hint(annotation) + try: + new_type = type_translation.from_type_hint(annotation) + except ValueError as e: + err = errors.InvalidParameterAnnotationError(loc, node.arg, annotation) + err.notes.append(str(e)) + raise err from e if not isinstance(new_type, ts.DataType): raise errors.InvalidParameterAnnotationError(loc, node.arg, new_type) return foast.DataSymbol(id=node.arg, location=loc, type=new_type) @@ -249,6 +260,12 @@ def visit_Assign( ] = [] for elt in target.elts: + if isinstance(elt, ast.Tuple): + raise errors.DSLError( + self.get_location(elt), + "Nested tuple unpacking is not supported.", + hints=("Unpack the inner tuple in a separate assignment.",), + ) if isinstance(elt, ast.Starred): new_targets.append( foast.Starred( @@ -310,8 +327,19 @@ def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs: Any) -> foast.Assign: ), "Annotations should be ast.Constant(string). Use StringifyAnnotationsPass" context = {**fbuiltins.BUILTINS, **self.closure_vars} - annotation = eval(node.annotation.value, context) - target_type = type_translation.from_type_hint(annotation, globalns=context) + try: + annotation = eval(node.annotation.value, context) + target_type = type_translation.from_type_hint(annotation, globalns=context) + except Exception as e: + raise errors.DSLError( + self.get_location(node), + f"Invalid type annotation '{node.annotation.value}': {e}.", + notes=( + "Inside a GT4Py function, annotations can only use GT4Py " + "builtins and names that are also referenced in the function " + "body.", + ), + ) from e else: target_type = ts.DeferredType() @@ -453,14 +481,8 @@ def visit_Compare(self, node: ast.Compare, **kwargs: Any) -> foast.Compare: refactored = UnchainComparesPass.apply(node) raise errors.DSLError( loc, - textwrap.dedent( - f""" - Comparison chains are not allowed. Please replace - {ast.unparse(node)} - by - {ast.unparse(refactored)} - """, - ), + "Comparison chains are not allowed.", + hints=(f"Replace '{ast.unparse(node)}' by '{ast.unparse(refactored)}'.",), ) return foast.Compare( op=self.visit(node.ops[0]), @@ -490,14 +512,21 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: def _verify_builtin_type_constructor(self, node: ast.Call) -> None: if len(node.args) > 0: arg = node.args[0] + func_name = self._func_name(node) if not ( isinstance(arg, ast.Constant) or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) ): raise errors.DSLError( - self.get_location(node), - f"'{self._func_name(node)}()' only takes literal arguments.", + self.get_location(node), f"'{func_name}()' only takes literal arguments." ) + try: + fbuiltins.BUILTINS[func_name](ast.literal_eval(arg)) + except Exception as e: + raise errors.DSLError( + self.get_location(node), + f"'{ast.unparse(arg)}' is not a valid literal for '{func_name}': {e}.", + ) from e def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. @@ -520,9 +549,11 @@ def visit_Constant(self, node: ast.Constant, **kwargs: Any) -> foast.Constant: loc = self.get_location(node) try: type_ = type_translation.from_value(node.value) - except ValueError: + except ValueError as e: raise errors.DSLError( - loc, f"Constants of type {type(node.value)} are not permitted." + loc, + f"Invalid constant of type '{type(node.value).__name__}'.", + notes=(str(e),), ) from None return foast.Constant(value=node.value, location=loc, type=type_) diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 392b6db2a5..1db9119117 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -10,7 +10,6 @@ import ast import dataclasses -import typing from typing import Any, cast from gt4py._core import definitions as core_defs @@ -63,7 +62,7 @@ def func_to_past(inp: DSLProgramDef) -> PASTProgramDef: """ source_def = source_utils.SourceDefinition.from_function(inp.definition) closure_vars = source_utils.get_closure_vars_from_function(inp.definition) - annotations = typing.get_type_hints(inp.definition) + annotations = source_utils.get_type_hints_from_function(inp.definition, source_def) return ffront_stages.PASTProgramDef( past_node=ProgramParser.apply(source_def, closure_vars, annotations), closure_vars=closure_vars, @@ -109,34 +108,70 @@ def _postprocess_dialect_ast( return ProgramTypeDeduction.apply(output_node) def visit_FunctionDef(self, node: ast.FunctionDef) -> past.Program: - self._check_not_a_reserved_name(node.name, self.get_location(node)) - closure_symbols: list[past.Symbol] = [ - past.Symbol( - id=name, - type=type_translation.from_value(val), - namespace=dialect_ast_enums.Namespace.CLOSURE, - location=self.get_location(node), + loc = self.get_location(node) + self._check_not_a_reserved_name(node.name, loc) + closure_symbols: list[past.Symbol] = [] + for name, val in self.closure_vars.items(): + try: + type_ = type_translation.from_value(val) + except ValueError as e: + hints: tuple[str, ...] = () + if callable(val): + hints = ( + "Only functions decorated with '@field_operator' or " + "'@scan_operator' can be called inside a program.", + ) + raise errors.DSLTypeError( + loc, + f"Unexpected object '{name}' of type '{type(val)}' encountered.", + hints=hints, + ) from e + closure_symbols.append( + past.Symbol( + id=name, + type=type_, + namespace=dialect_ast_enums.Namespace.CLOSURE, + location=loc, + ) ) - for name, val in self.closure_vars.items() - ] + + body: list[past.LocatedNode] = [] + for stmt in node.body: + new_stmt = self.visit(stmt) + if not isinstance(new_stmt, past.Call): + raise errors.DSLError( + self.get_location(stmt), + "Only calls to GT4Py operators are allowed as statements in a program.", + notes=( + "A program orchestrates operator calls that write into 'out' " + "arguments; computations belong inside field operators.", + ), + ) + body.append(new_stmt) return past.Program( id=node.name, type=ts.DeferredType(constraint=ts_ffront.ProgramType), params=self.visit(node.args), - body=[self.visit(node) for node in node.body], + body=body, closure_vars=closure_symbols, - location=self.get_location(node), + location=loc, ) def visit_arguments(self, node: ast.arguments) -> list[past.DataSymbol]: + self._validate_signature(node) return [self.visit_arg(arg) for arg in node.args] def visit_arg(self, node: ast.arg) -> past.DataSymbol: loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: raise errors.MissingParameterAnnotationError(loc, node.arg) - new_type = type_translation.from_type_hint(annotation) + try: + new_type = type_translation.from_type_hint(annotation) + except ValueError as e: + err = errors.InvalidParameterAnnotationError(loc, node.arg, annotation) + err.notes.append(str(e)) + raise err from e if not isinstance(new_type, ts.DataType): raise errors.InvalidParameterAnnotationError(loc, node.arg, new_type) return past.DataSymbol(id=node.arg, location=loc, type=new_type) @@ -199,8 +234,18 @@ def visit_Attribute(self, node: ast.Attribute) -> past.Attribute: ) def visit_Dict(self, node: ast.Dict) -> past.Dict: + keys = [] + for param in node.keys: + new_key = self.visit(cast(ast.AST, param)) + if not isinstance(new_key, (past.Name, past.Attribute)): + raise errors.DSLError( + self.get_location(cast(ast.AST, param)), + "Dictionary keys must be dimension objects referenced by name " + "(e.g. 'IDim', not '\"IDim\"').", + ) + keys.append(new_key) return past.Dict( - keys_=[self.visit(cast(ast.AST, param)) for param in node.keys], + keys_=keys, values_=[self.visit(param) for param in node.values], location=self.get_location(node), ) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 9d021ceb51..77813030f6 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -261,9 +261,25 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call: f"Got '{arg_types[0]}' and '{arg_types[1]}'." ) return_type = arg_types[0] + elif isinstance(new_func.type, ts_ffront.ProgramType): + raise errors.DSLError( + node.location, + f"Program '{node.func.id}' cannot be called from within another program.", + label="this is a '@program'", + hints=( + "Call the field operators directly, or compose the programs " + "from plain Python.", + ), + ) else: - raise AssertionError( - "Only calls to 'FieldOperator', 'ScanOperator' or 'minimum' and 'maximum' builtins allowed." + raise errors.DSLError( + node.location, + f"'{node.func.id}' cannot be called inside a program.", + label=f"this has type '{new_func.type}'", + notes=( + "Inside a program, only field operators, scan operators and the " + "builtins 'minimum' and 'maximum' can be called.", + ), ) except ValueError as ex: diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index cc02c1d89b..a2d708eeec 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -13,9 +13,13 @@ import pathlib import symtable import textwrap +import typing from collections.abc import Callable, Iterator from dataclasses import dataclass -from typing import Any, cast +from typing import Any, Optional, cast + +from gt4py.eve.concepts import SourceLocation +from gt4py.next import errors MISSING_FILENAME = "" @@ -28,6 +32,39 @@ def get_closure_vars_from_function(function: Callable) -> dict[str, Any]: return dict(sorted({**builtins, **globals, **nonlocals}.items())) +def get_type_hints_from_function( + function: Callable, source_definition: Optional[SourceDefinition] = None +) -> dict[str, Any]: + """ + Resolve the type annotations of ``function``, reporting failures as :class:`errors.DSLError`. + + Annotations are resolved with :func:`typing.get_type_hints`, which evaluates + them in the function's module namespace; unresolvable annotations (e.g. names + that are not defined, or strings that are not valid types) raise exceptions + that would otherwise leak to the user as plain Python errors. + """ + try: + return typing.get_type_hints(function) + except Exception as err: + location = ( + SourceLocation( + filename=source_definition.filename, + line=source_definition.line_offset + 1, + column=source_definition.column_offset + 1, + ) + if source_definition is not None + else None + ) + raise errors.DSLError( + location, + f"Could not resolve type annotations of '{function.__name__}': {err}.", + hints=( + "Make sure every name used in an annotation is defined or imported in " + "the module where the function is defined.", + ), + ) from err + + def make_source_definition_from_function(func: Callable) -> SourceDefinition: try: filename = str(pathlib.Path(inspect.getabsfile(func)).resolve()) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_diagnostic_messages.py b/tests/next_tests/unit_tests/ffront_tests/test_diagnostic_messages.py index 70573d280e..33b6c9b81c 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_diagnostic_messages.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_diagnostic_messages.py @@ -23,6 +23,7 @@ import gt4py.next as gtx from gt4py.next import errors, float32, float64 from gt4py.next.ffront.func_to_foast import FieldOperatorParser +from gt4py.next.ffront.func_to_past import ProgramParser IDim = gtx.Dimension("IDim") @@ -34,6 +35,12 @@ def parse_error(func) -> errors.DSLError: return exc_info.value +def parse_program_error(func) -> errors.DSLError: + with pytest.raises(errors.DSLError) as exc_info: + ProgramParser.apply_to_function(func) + return exc_info.value + + def test_undeclared_symbol_suggests_close_match(): def misspelled(temperature: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: tmp_field = temperature * 2.0 @@ -163,6 +170,396 @@ def misspelled(temperature: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], flo assert "While processing the definition of 'misspelled'." in exc_info.value.__notes__ +def test_global_statement_is_rejected_with_friendly_message(): + # 'global' used to crash an AST preprocessing pass with an AttributeError + # because its 'names' field holds plain strings, not AST nodes. + def with_global(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + global IDim + return a + + err = parse_error(with_global) + + assert isinstance(err, errors.UnsupportedPythonFeatureError) + assert err.message == "Unsupported Python syntax: 'global' statement." + assert any("read-only" in hint for hint in err.hints) + + +def test_numpy_style_attribute_on_field_is_a_dsl_error(): + # used to leak AttributeError: 'FieldType' object has no attribute 'T' + def with_numpy_attr(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return a.T + + err = parse_error(with_numpy_attr) + + assert err.message == "Type 'Field[[IDim], float64]' has no attribute 'T'." + assert any("NumPy-style" in note for note in err.notes) + + +def test_numpy_function_call_is_a_dsl_error(): + # used to leak ValueError: Type not supported + import numpy as np + + def with_numpy_call(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return np.sin(a) + + err = parse_error(with_numpy_call) + + assert err.message == "'sin' cannot be used inside a GT4Py function." + assert any("GT4Py built-in" in hint for hint in err.hints) + + +def test_missing_module_attribute_is_a_dsl_error(): + # used to leak AttributeError: module 'numpy' has no attribute 'sinn' + import numpy as np + + def with_missing_attr(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return np.sinn(a) + + err = parse_error(with_missing_attr) + + assert "module 'numpy' has no attribute 'sinn'" in err.message + + +def test_absolute_field_index_is_a_dsl_error(): + # used to leak AttributeError: 'ScalarType' object has no attribute 'dim' + def with_index(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return a[3] + + err = parse_error(with_index) + + assert err.message == "Fields cannot be indexed with 'int32'." + assert any("field offset" in hint for hint in err.hints) + + +def test_tuple_index_out_of_range_is_a_dsl_error(): + # used to leak IndexError: list index out of range + def with_oob_index(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + t = (a, a) + return t[5] + + err = parse_error(with_oob_index) + + assert err.message == "Tuple index 5 is out of range." + assert err.label == "this tuple has 2 elements" + + +def test_non_local_dimension_index_is_a_dsl_error(): + # used to crash with AssertionError on the dimension kind + def with_dim_index(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return a[IDim(3)] + + err = parse_error(with_dim_index) + + assert "'IDim' is not a local (neighbor) dimension" in err.message + + +def test_unresolvable_string_annotation_is_a_dsl_error(): + # used to leak SyntaxError from typing.get_type_hints + def with_bad_annotation(a: "not a type") -> gtx.Field[[IDim], float64]: # noqa: F722 [syntax-error-in-forward-annotation] + return a + + err = parse_error(with_bad_annotation) + + assert "Could not resolve type annotations of 'with_bad_annotation'" in err.message + + +def test_non_gt4py_parameter_annotation_is_a_dsl_error(): + # used to leak ValueError: Type not supported + def with_list_param(a: list) -> gtx.Field[[IDim], float64]: + return a + + err = parse_error(with_list_param) + + assert isinstance(err, errors.InvalidParameterAnnotationError) + assert any("GT4Py type" in hint for hint in err.hints) + + +def test_unresolvable_annotated_assignment_is_a_dsl_error(): + # used to leak NameError from eval'ing the annotation: 'gtx' is only + # visible inside the function if it is also referenced in the body + def with_ann_assign(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + b: gtx.Field[[IDim], float64] = a + return b + + err = parse_error(with_ann_assign) + + assert "Invalid type annotation 'gtx.Field[[IDim], float64]'" in err.message + assert any("GT4Py builtins" in note for note in err.notes) + + +@gtx.field_operator +def _copy_op(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return a + + +def test_expression_statement_in_program_is_a_dsl_error(): + # used to leak a TypeError from IR node validation + def with_expr_stmt(a: gtx.Field[[IDim], float64], out: gtx.Field[[IDim], float64]): + a + a + + err = parse_program_error(with_expr_stmt) + + assert err.message == "Only calls to GT4Py operators are allowed as statements in a program." + + +def test_calling_program_from_program_is_a_dsl_error(): + # used to crash with AssertionError + @gtx.program + def inner(a: gtx.Field[[IDim], float64], out: gtx.Field[[IDim], float64]): + _copy_op(a, out=out) + + def with_program_call(a: gtx.Field[[IDim], float64], out: gtx.Field[[IDim], float64]): + inner(a, out) + + err = parse_program_error(with_program_call) + + assert err.message == "Program 'inner' cannot be called from within another program." + + +def test_plain_python_function_in_program_is_a_dsl_error(): + # used to leak ValueError: Invalid callable annotations ... + def plain(a, out): + return a + + def with_plain_call(a: gtx.Field[[IDim], float64], out: gtx.Field[[IDim], float64]): + plain(a, out=out) + + err = parse_program_error(with_plain_call) + + assert isinstance(err, errors.DSLTypeError) + assert any("@field_operator" in hint for hint in err.hints) + + +def test_nested_tuple_unpacking_is_a_dsl_error(): + # used to leak AttributeError: 'TupleExpr' object has no attribute 'id' + def with_nested_unpack(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + (b, c), d = (a, a), a + return b + + err = parse_error(with_nested_unpack) + + assert err.message == "Nested tuple unpacking is not supported." + + +def test_invalid_literal_for_type_constructor_is_a_dsl_error(): + # used to leak ValueError at execution time, long after the definition + from gt4py.next import int32 + + def with_bad_cast(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + b = int32("abc") + return a + + err = parse_error(with_bad_cast) + + assert "is not a valid literal for 'int32'" in err.message + + +def test_scan_operator_init_with_unsupported_type_is_a_dsl_error(): + # 'init=np.zeros(...)' used to crash with NumPy's ambiguous-truth-value + # ValueError when fingerprinting the value + import numpy as np + + KDim = gtx.Dimension("KDim", kind=gtx.DimensionKind.VERTICAL) + + with pytest.raises(errors.DSLTypeError, match="Argument 'init'") as exc_info: + + @gtx.scan_operator(axis=KDim, forward=True, init=np.zeros(3)) + def scan_op(state: float64, x: float64) -> float64: + return state + x + + assert "ndarray" in exc_info.value.message + + +def test_string_dimension_key_in_domain_is_a_dsl_error(): + # used to leak a TypeError from IR node validation + def with_str_domain_key(a: gtx.Field[[IDim], float64], out: gtx.Field[[IDim], float64]): + _copy_op(a, out=out, domain={"IDim": (0, 10)}) + + err = parse_program_error(with_str_domain_key) + + assert "Dictionary keys must be dimension objects" in err.message + + +@pytest.mark.parametrize( + "feature, definition", + [ + ("keyword-only parameters", "def f(a: F, *, b: F) -> F:\n return a"), + ("positional-only parameters", "def f(a: F, /) -> F:\n return a"), + ("'*args' parameters", "def f(*args: F) -> F:\n return args[0]"), + ("'**kwargs' parameters", "def f(a: F, **kw: F) -> F:\n return a"), + ("default values for parameters", "def f(a: F, b: float = 1.0) -> F:\n return a"), + ], +) +def test_unsupported_signature_features_are_dsl_errors(feature, definition, tmp_path): + # these used to be silently dropped, leading to misleading + # "Undeclared symbol" errors for the affected parameters + import textwrap + + module = tmp_path / "sig_case.py" + module.write_text( + textwrap.dedent( + """ + import gt4py.next as gtx + from gt4py.next import float64 + IDim = gtx.Dimension("IDim") + F = gtx.Field[[IDim], float64] + """ + ) + + definition + ) + import importlib.util + + spec = importlib.util.spec_from_file_location("sig_case", module) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + with pytest.raises(errors.UnsupportedPythonFeatureError) as exc_info: + FieldOperatorParser.apply_to_function(mod.f) + assert exc_info.value.message == f"Unsupported Python syntax: {feature}." + + +def test_out_of_range_integer_constant_explains_reason(): + def with_huge_constant(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + b = 99999999999999999999999999 + return a + + err = parse_error(with_huge_constant) + + assert err.message == "Invalid constant of type 'int'." + assert any("out of range" in note for note in err.notes) + + +def test_calling_program_inside_field_operator_is_explained(): + # the message used to dump the entire 'ProgramType(...)' spec + @gtx.program + def some_prog(a: gtx.Field[[IDim], float64], out: gtx.Field[[IDim], float64]): + _copy_op(a, out=out) + + def with_prog_call(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return some_prog(a) + + err = parse_error(with_prog_call) + + assert err.message == "Programs cannot be called inside field operators." + + +# --- call-time diagnostics (embedded execution, no backend) --------------- + +import numpy as np # noqa: E402 [import-not-at-top-of-file] + + +def test_numpy_array_argument_in_direct_call_is_a_dsl_error(): + # used to crash an assert deep inside the embedded execution + arr = np.zeros(5) + out = gtx.as_field([IDim], np.zeros(5)) + + with pytest.raises(errors.DSLTypeError) as exc_info: + _copy_op(arr, out=out, offset_provider={}) + + assert "argument 1 has a type not supported by GT4Py: 'ndarray'" in exc_info.value.message + assert any("as_field" in hint for hint in exc_info.value.hints) + + +def test_wrong_dims_argument_in_direct_call_is_a_dsl_error(): + # used to leak ValueError: Incompatible 'Domain' in assignment + JDim = gtx.Dimension("JDim") + a = gtx.as_field([JDim], np.zeros(5)) + out = gtx.as_field([IDim], np.zeros(5)) + + with pytest.raises(errors.DSLError, match="Invalid argument types in call to '_copy_op'"): + _copy_op(a, out=out, offset_provider={}) + + +def test_wrong_out_type_in_direct_call_is_a_dsl_error(): + # used to leak ValueError: Incompatible 'Domain' in assignment + JDim = gtx.Dimension("JDim") + a = gtx.as_field([IDim], np.zeros(5)) + out = gtx.as_field([JDim], np.zeros(5)) + + with pytest.raises(errors.DSLTypeError, match="expected keyword argument 'out' to be of type"): + _copy_op(a, out=out, offset_provider={}) + + +def test_extra_argument_in_direct_call_is_a_dsl_error(): + # used to leak TypeError: ... takes 1 positional argument but 2 were given + a = gtx.as_field([IDim], np.zeros(5)) + out = gtx.as_field([IDim], np.zeros(5)) + + with pytest.raises(errors.DSLError, match="Invalid argument types in call to '_copy_op'"): + _copy_op(a, a, out=out, offset_provider={}) + + +def test_missing_offset_provider_entry_is_a_dsl_error(): + # used to leak a KeyError + Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) + + @gtx.field_operator + def shift_op(f: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return f(Ioff[1]) + + a = gtx.as_field([IDim], np.zeros(5)) + out = gtx.as_field([IDim], np.zeros(5)) + + with pytest.raises(errors.DSLError, match="Offset 'Ioff' not found") as exc_info: + shift_op(a, out=out, offset_provider={}) + assert any("offset_provider" in hint for hint in exc_info.value.hints) + + +def test_invalid_domain_argument_is_a_dsl_error(): + # used to leak ValueError: '0' is not 'DomainLike' + a = gtx.as_field([IDim], np.zeros(5)) + out = gtx.as_field([IDim], np.zeros(5)) + + with pytest.raises(errors.DSLTypeError, match="Invalid 'domain' argument"): + _copy_op(a, out=out, domain=(0, 5), offset_provider={}) + + +def test_numpy_array_argument_in_program_call_is_a_dsl_error(): + # used to leak ValueError: The truth value of an array ... is ambiguous + @gtx.program + def copy_prog(f: gtx.Field[[IDim], float64], out: gtx.Field[[IDim], float64]): + _copy_op(f, out=out) + + arr = np.zeros(5) + out = gtx.as_field([IDim], np.zeros(5)) + + with pytest.raises(errors.DSLTypeError) as exc_info: + copy_prog(arr, out, offset_provider={}) + + assert "argument 1 has a type not supported by GT4Py" in exc_info.value.message + + +def test_non_mapping_offset_provider_is_a_dsl_error(): + # used to be silently accepted until an offset lookup crashed (or not at all) + a = gtx.as_field([IDim], np.zeros(5)) + out = gtx.as_field([IDim], np.zeros(5)) + + with pytest.raises(errors.DSLTypeError, match="'offset_provider' must be a mapping"): + _copy_op(a, out=out, offset_provider=[("Ioff", IDim)]) + + +def test_unsupported_field_dtype_reports_dtype(): + # used to crash an assert in NdArrayField.from_array + with pytest.raises(ValueError, match="unsupported dtype 'float16'"): + gtx.as_field([IDim], np.zeros(5, dtype=np.float16)) + + +def test_string_dimensions_in_as_field_are_rejected(): + # used to fail with a baffling "''D'' cannot be interpreted as 'UnitRange'" + with pytest.raises(TypeError, match="must be 'Dimension' objects"): + gtx.as_field(["IDim"], np.zeros(5)) + + +def test_non_integral_neighbor_table_reports_dtype(): + # used to crash an assert in NdArrayConnectivityField.from_array + Vertex = gtx.Dimension("Vertex") + Edge = gtx.Dimension("Edge") + V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) + + with pytest.raises(ValueError, match="integral dtype"): + gtx.as_connectivity([Vertex, V2EDim], codomain=Edge, data=np.array([[0.5, 1.5]])) + + def test_diagnostic_codes_are_stable(): assert errors.UndefinedSymbolError.code == "undefined-symbol" assert errors.UnsupportedPythonFeatureError.code == "unsupported-syntax" diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 6399133870..f3d3a6eff1 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -91,7 +91,10 @@ def test_mistyped_arg(): def mistyped(inp: gtx.Field): return inp - with pytest.raises(ValueError, match="Field type requires two arguments, got 0."): + with pytest.raises( + errors.InvalidParameterAnnotationError, + match="Field type requires two arguments, got 0.", + ): _ = FieldOperatorParser.apply_to_function(mistyped)