diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 87237a0a75..21b351ad48 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -19,6 +19,7 @@ MissingParameterAnnotationError, UndefinedSymbolError, UnsupportedPythonFeatureError, + did_you_mean, ) @@ -32,4 +33,5 @@ "MissingParameterAnnotationError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", + "did_you_mean", ] diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index fb7e7eaa7d..ab2e83784f 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -36,7 +36,7 @@ def message(self) -> str: return self.args[0] -def _did_you_mean(name: str, candidates: Iterable[str]) -> list[str]: +def did_you_mean(name: str, candidates: Iterable[str]) -> list[str]: """Produce a 'Did you mean ...?' hint if `name` closely matches any candidate.""" # Never suggest the name the user already wrote (it can appear among the # candidates as a same-named symbol from another SSA generation). @@ -157,7 +157,7 @@ def __init__( location, f"Undeclared symbol '{name}'.", label="not defined at this point", - hints=_did_you_mean(name, candidates), + hints=did_you_mean(name, candidates), ) self.sym_name = name @@ -183,8 +183,8 @@ def __init__(self, location: Optional[SourceLocation], arg_name: str, is_kwarg: class DSLTypeError(DSLError): - def __init__(self, location: Optional[SourceLocation], message: str) -> None: - super().__init__(location, message) + def __init__(self, location: Optional[SourceLocation], message: str, **kwargs: Any) -> None: + super().__init__(location, message, **kwargs) class MissingParameterAnnotationError(DSLTypeError): @@ -203,6 +203,11 @@ def __init__(self, location: Optional[SourceLocation], param_name: str, type_: A super().__init__( location, f"Parameter '{param_name}' has invalid type annotation '{type_}'." ) + self.hints = [ + "Annotate parameters with a GT4Py type: a field (e.g. " + "'gtx.Field[[IDim], gtx.float64]'), a scalar (e.g. 'gtx.float64') or a " + "tuple of these." + ] self.param_name = param_name self.annotated_type = type_ diff --git a/src/gt4py/next/ffront/ast_passes/simple_assign.py b/src/gt4py/next/ffront/ast_passes/simple_assign.py index a251490846..3813d78516 100644 --- a/src/gt4py/next/ffront/ast_passes/simple_assign.py +++ b/src/gt4py/next/ffront/ast_passes/simple_assign.py @@ -30,7 +30,13 @@ def generic_visit(self, node: ast.AST) -> Iterator[ast.AST]: # type: ignore[ove """Override generic visit to deal with generators.""" for field, old_value in ast.iter_fields(node): if isinstance(old_value, list): - new_values = [i for j in old_value for i in self.visit(j)] + # fields may contain non-AST values, e.g. the names of a `global` + # statement; those must be kept as-is, not visited. + new_values = [ + i + for j in old_value + for i in (self.visit(j) if isinstance(j, ast.AST) else (j,)) + ] old_value[:] = new_values elif isinstance(old_value, ast.AST): new_node, *_ = list(self.visit(old_value)) or (None,) diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index aa4aae7b49..629211f086 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -8,13 +8,13 @@ import ast import textwrap -import typing from dataclasses import dataclass from typing import Callable, ClassVar, Collection from gt4py.eve.concepts import SourceLocation from gt4py.eve.extended_typing import Any, Generic, TypeVar from gt4py.next import errors +from gt4py.next.ffront import source_utils from gt4py.next.ffront.ast_passes.fix_missing_locations import FixMissingLocations from gt4py.next.ffront.ast_passes.remove_docstrings import RemoveDocstrings from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function @@ -67,6 +67,20 @@ ast.ClassDef: ("class definition", ()), ast.JoinedStr: ("f-string", ("Strings cannot be computed inside GT4Py functions.",)), ast.Match: ("'match' statement", ("Use 'if'/'elif' chains or 'where' instead.",)), + ast.Global: ( + "'global' statement", + ( + "Variables from the surrounding scope are read-only inside GT4Py functions; " + "pass values as parameters and return results instead.", + ), + ), + ast.Nonlocal: ( + "'nonlocal' statement", + ( + "Variables from the surrounding scope are read-only inside GT4Py functions; " + "pass values as parameters and return results instead.", + ), + ), } @@ -136,7 +150,7 @@ def apply( def apply_to_function(cls, function: Callable) -> DialectRootT: src = SourceDefinition.from_function(function) closure_vars = get_closure_vars_from_function(function) - annotations = typing.get_type_hints(function) + annotations = source_utils.get_type_hints_from_function(function, src) return cls.apply(src, closure_vars, annotations) @classmethod diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 1b090489ff..77f24f7a67 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -429,13 +429,64 @@ def visit_Symbol( return new_node return node + def _invalid_attribute_error( + self, node: foast.Attribute, value_type: ts.TypeSpec + ) -> errors.DSLError: + notes: list[str] = [] + hints: list[str] = [] + if isinstance(value_type, ts.NamedCollectionType): + hints = errors.did_you_mean(node.attr, value_type.keys) + elif isinstance(value_type, (ts.FieldType, ts.ScalarType)): + notes = [ + "GT4Py fields and scalars do not provide NumPy-style methods or attributes; " + "use the GT4Py built-in functions instead." + ] + return errors.DSLError( + node.location, + f"Type '{value_type}' has no attribute '{node.attr}'.", + label="attribute does not exist", + notes=notes, + hints=hints, + ) + def visit_Attribute(self, node: foast.Attribute, **kwargs: Any) -> foast.Attribute: new_value = self.visit(node.value, **kwargs) + # Attribute access is only valid on namespaces (e.g. modules) and named + # collections; anything else (e.g. NumPy-style attributes of a field) + # is a user error and must not leak `getattr` failures. + try: + new_type = getattr(new_value.type, node.attr) + except AttributeError as e: + if isinstance(new_value.type, type_translation.NamespaceProxy): + # `getattr` on the underlying Python object failed; its message + # ("module 'numpy' has no attribute ...") is the best we have. + raise errors.DSLError(node.location, f"{e}.") from e + raise self._invalid_attribute_error(node, new_value.type) from e + except ValueError as e: + # the attribute exists on the Python object behind a namespace, but + # its value has no representation in the DSL type system + hints: list[str] = [] + if isinstance(new_value.type, type_translation.NamespaceProxy) and callable( + getattr(new_value.type._object, node.attr, None) + ): + hints = [ + "External Python functions cannot be called inside GT4Py functions. " + "Use the corresponding GT4Py built-in (e.g. 'sin', 'sqrt', 'exp', " + "'maximum' from 'gt4py.next') if one exists." + ] + raise errors.DSLError( + node.location, + f"'{node.attr}' cannot be used inside a GT4Py function.", + label="value not representable in the GT4Py type system", + hints=hints, + ) from e + if not isinstance(new_type, ts.TypeSpec): + raise self._invalid_attribute_error(node, new_value.type) return foast.Attribute( value=new_value, attr=node.attr, location=node.location, - type=getattr(new_value.type, node.attr), + type=new_type, ) def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscript: @@ -452,6 +503,12 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri node.location, f"Tuples need to be indexed with literal integers, got '{node.index}'.", ) from ex + if not -len(types) <= index < len(types): + raise errors.DSLError( + node.location, + f"Tuple index {index} is out of range.", + label=f"this tuple has {len(types)} element{'s' * (len(types) != 1)}", + ) new_type = types[index] case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: @@ -470,7 +527,23 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri ) new_type = new_value.type case ts.FieldType(dims=dims, dtype=dtype): - # e.g. `field[LocalDim(42)]` + # e.g. `field[LocalDim(42)]`: the only valid field subscript is a + # local-dimension index, which removes that dimension + if not isinstance(new_index.type, ts.IndexType): + raise errors.DSLError( + node.location, + f"Fields cannot be indexed with '{new_index.type}'.", + label="invalid field index", + notes=( + "GT4Py expressions operate on whole fields; accessing single " + "grid points by absolute position is not possible.", + ), + hints=( + "To access neighboring grid points, apply a field offset, " + "e.g. 'field(Ioff[1])'. Entries of a local (neighbor) dimension " + "can be selected with 'field[LocalDim(0)]'.", + ), + ) new_type = ts.FieldType( dims=[d for d in dims if d != new_index.type.dim], dtype=dtype, @@ -743,7 +816,21 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: elif isinstance(new_func.type, ts.FieldType): pass elif isinstance(new_func.type, ts.DimensionType): - assert new_func.type.dim.kind == DimensionKind.LOCAL + if new_func.type.dim.kind != DimensionKind.LOCAL: + raise errors.DSLError( + node.location, + f"'{new_func.type.dim.value}' is not a local (neighbor) dimension and " + "cannot be used to construct an index.", + label=f"this dimension is of kind '{new_func.type.dim.kind}'", + notes=( + "Only indices of a local dimension (e.g. 'V2EDim(0)') can be " + "constructed this way, to select one neighbor entry of a field.", + ), + hints=( + "To access neighboring grid points along a dimension, apply a " + "field offset instead, e.g. 'field(Ioff[1])'.", + ), + ) return foast.Call( func=new_func, args=new_args, diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 1497d0552e..1ffe1d83da 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -9,8 +9,6 @@ from __future__ import annotations import ast -import textwrap -import typing from typing import Any, Type import gt4py.eve as eve @@ -71,7 +69,7 @@ def func_to_foast(inp: DSLFieldOperatorDef) -> FOASTOperatorDef: """ source_def = source_utils.SourceDefinition.from_function(inp.definition) closure_vars = source_utils.get_closure_vars_from_function(inp.definition) - annotations = typing.get_type_hints(inp.definition) + annotations = source_utils.get_type_hints_from_function(inp.definition, source_def) try: foast_definition_node = FieldOperatorParser.apply(source_def, closure_vars, annotations) loc = foast_definition_node.location @@ -233,7 +231,12 @@ def visit_arg(self, node: ast.arg) -> foast.DataSymbol: loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: raise errors.MissingParameterAnnotationError(loc, node.arg) - new_type = type_translation.from_type_hint(annotation) + try: + new_type = type_translation.from_type_hint(annotation) + except ValueError as e: + err = errors.InvalidParameterAnnotationError(loc, node.arg, annotation) + err.notes.append(str(e)) + raise err from e if not isinstance(new_type, ts.DataType): raise errors.InvalidParameterAnnotationError(loc, node.arg, new_type) return foast.DataSymbol(id=node.arg, location=loc, type=new_type) @@ -310,8 +313,19 @@ def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs: Any) -> foast.Assign: ), "Annotations should be ast.Constant(string). Use StringifyAnnotationsPass" context = {**fbuiltins.BUILTINS, **self.closure_vars} - annotation = eval(node.annotation.value, context) - target_type = type_translation.from_type_hint(annotation, globalns=context) + try: + annotation = eval(node.annotation.value, context) + target_type = type_translation.from_type_hint(annotation, globalns=context) + except Exception as e: + raise errors.DSLError( + self.get_location(node), + f"Invalid type annotation '{node.annotation.value}': {e}.", + notes=( + "Inside a GT4Py function, annotations can only use GT4Py " + "builtins and names that are also referenced in the function " + "body.", + ), + ) from e else: target_type = ts.DeferredType() @@ -453,14 +467,8 @@ def visit_Compare(self, node: ast.Compare, **kwargs: Any) -> foast.Compare: refactored = UnchainComparesPass.apply(node) raise errors.DSLError( loc, - textwrap.dedent( - f""" - Comparison chains are not allowed. Please replace - {ast.unparse(node)} - by - {ast.unparse(refactored)} - """, - ), + "Comparison chains are not allowed.", + hints=(f"Replace '{ast.unparse(node)}' by '{ast.unparse(refactored)}'.",), ) return foast.Compare( op=self.visit(node.ops[0]), diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 392b6db2a5..7a07130d50 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -10,7 +10,6 @@ import ast import dataclasses -import typing from typing import Any, cast from gt4py._core import definitions as core_defs @@ -63,7 +62,7 @@ def func_to_past(inp: DSLProgramDef) -> PASTProgramDef: """ source_def = source_utils.SourceDefinition.from_function(inp.definition) closure_vars = source_utils.get_closure_vars_from_function(inp.definition) - annotations = typing.get_type_hints(inp.definition) + annotations = source_utils.get_type_hints_from_function(inp.definition, source_def) return ffront_stages.PASTProgramDef( past_node=ProgramParser.apply(source_def, closure_vars, annotations), closure_vars=closure_vars, @@ -109,24 +108,54 @@ def _postprocess_dialect_ast( return ProgramTypeDeduction.apply(output_node) def visit_FunctionDef(self, node: ast.FunctionDef) -> past.Program: - self._check_not_a_reserved_name(node.name, self.get_location(node)) - closure_symbols: list[past.Symbol] = [ - past.Symbol( - id=name, - type=type_translation.from_value(val), - namespace=dialect_ast_enums.Namespace.CLOSURE, - location=self.get_location(node), + loc = self.get_location(node) + self._check_not_a_reserved_name(node.name, loc) + closure_symbols: list[past.Symbol] = [] + for name, val in self.closure_vars.items(): + try: + type_ = type_translation.from_value(val) + except ValueError as e: + hints: tuple[str, ...] = () + if callable(val): + hints = ( + "Only functions decorated with '@field_operator' or " + "'@scan_operator' can be called inside a program.", + ) + raise errors.DSLTypeError( + loc, + f"Unexpected object '{name}' of type '{type(val)}' encountered.", + hints=hints, + ) from e + closure_symbols.append( + past.Symbol( + id=name, + type=type_, + namespace=dialect_ast_enums.Namespace.CLOSURE, + location=loc, + ) ) - for name, val in self.closure_vars.items() - ] + + body: list[past.LocatedNode] = [] + for stmt in node.body: + new_stmt = self.visit(stmt) + if not isinstance(new_stmt, past.Call): + raise errors.DSLError( + self.get_location(stmt), + "Only calls to GT4Py operators are allowed as statements in a program.", + notes=( + "A program orchestrates operator calls that write into 'out' " + "arguments; computations belong inside field operators.", + ), + ) + body.append(new_stmt) return past.Program( id=node.name, type=ts.DeferredType(constraint=ts_ffront.ProgramType), params=self.visit(node.args), - body=[self.visit(node) for node in node.body], + body=body, closure_vars=closure_symbols, - location=self.get_location(node), + location=loc, ) def visit_arguments(self, node: ast.arguments) -> list[past.DataSymbol]: @@ -136,7 +165,12 @@ def visit_arg(self, node: ast.arg) -> past.DataSymbol: loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: raise errors.MissingParameterAnnotationError(loc, node.arg) - new_type = type_translation.from_type_hint(annotation) + try: + new_type = type_translation.from_type_hint(annotation) + except ValueError as e: + err = errors.InvalidParameterAnnotationError(loc, node.arg, annotation) + err.notes.append(str(e)) + raise err from e if not isinstance(new_type, ts.DataType): raise errors.InvalidParameterAnnotationError(loc, node.arg, new_type) return past.DataSymbol(id=node.arg, location=loc, type=new_type) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 9d021ceb51..77813030f6 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -261,9 +261,25 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call: f"Got '{arg_types[0]}' and '{arg_types[1]}'." ) return_type = arg_types[0] + elif isinstance(new_func.type, ts_ffront.ProgramType): + raise errors.DSLError( + node.location, + f"Program '{node.func.id}' cannot be called from within another program.", + label="this is a '@program'", + hints=( + "Call the field operators directly, or compose the programs " + "from plain Python.", + ), + ) else: - raise AssertionError( - "Only calls to 'FieldOperator', 'ScanOperator' or 'minimum' and 'maximum' builtins allowed." + raise errors.DSLError( + node.location, + f"'{node.func.id}' cannot be called inside a program.", + label=f"this has type '{new_func.type}'", + notes=( + "Inside a program, only field operators, scan operators and the " + "builtins 'minimum' and 'maximum' can be called.", + ), ) except ValueError as ex: diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index cc02c1d89b..a2d708eeec 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -13,9 +13,13 @@ import pathlib import symtable import textwrap +import typing from collections.abc import Callable, Iterator from dataclasses import dataclass -from typing import Any, cast +from typing import Any, Optional, cast + +from gt4py.eve.concepts import SourceLocation +from gt4py.next import errors MISSING_FILENAME = "" @@ -28,6 +32,39 @@ def get_closure_vars_from_function(function: Callable) -> dict[str, Any]: return dict(sorted({**builtins, **globals, **nonlocals}.items())) +def get_type_hints_from_function( + function: Callable, source_definition: Optional[SourceDefinition] = None +) -> dict[str, Any]: + """ + Resolve the type annotations of ``function``, reporting failures as :class:`errors.DSLError`. + + Annotations are resolved with :func:`typing.get_type_hints`, which evaluates + them in the function's module namespace; unresolvable annotations (e.g. names + that are not defined, or strings that are not valid types) raise exceptions + that would otherwise leak to the user as plain Python errors. + """ + try: + return typing.get_type_hints(function) + except Exception as err: + location = ( + SourceLocation( + filename=source_definition.filename, + line=source_definition.line_offset + 1, + column=source_definition.column_offset + 1, + ) + if source_definition is not None + else None + ) + raise errors.DSLError( + location, + f"Could not resolve type annotations of '{function.__name__}': {err}.", + hints=( + "Make sure every name used in an annotation is defined or imported in " + "the module where the function is defined.", + ), + ) from err + + def make_source_definition_from_function(func: Callable) -> SourceDefinition: try: filename = str(pathlib.Path(inspect.getabsfile(func)).resolve()) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_diagnostic_messages.py b/tests/next_tests/unit_tests/ffront_tests/test_diagnostic_messages.py index 70573d280e..c3974a9658 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_diagnostic_messages.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_diagnostic_messages.py @@ -23,6 +23,7 @@ import gt4py.next as gtx from gt4py.next import errors, float32, float64 from gt4py.next.ffront.func_to_foast import FieldOperatorParser +from gt4py.next.ffront.func_to_past import ProgramParser IDim = gtx.Dimension("IDim") @@ -34,6 +35,12 @@ def parse_error(func) -> errors.DSLError: return exc_info.value +def parse_program_error(func) -> errors.DSLError: + with pytest.raises(errors.DSLError) as exc_info: + ProgramParser.apply_to_function(func) + return exc_info.value + + def test_undeclared_symbol_suggests_close_match(): def misspelled(temperature: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: tmp_field = temperature * 2.0 @@ -163,6 +170,166 @@ def misspelled(temperature: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], flo assert "While processing the definition of 'misspelled'." in exc_info.value.__notes__ +def test_global_statement_is_rejected_with_friendly_message(): + # 'global' used to crash an AST preprocessing pass with an AttributeError + # because its 'names' field holds plain strings, not AST nodes. + def with_global(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + global IDim + return a + + err = parse_error(with_global) + + assert isinstance(err, errors.UnsupportedPythonFeatureError) + assert err.message == "Unsupported Python syntax: 'global' statement." + assert any("read-only" in hint for hint in err.hints) + + +def test_numpy_style_attribute_on_field_is_a_dsl_error(): + # used to leak AttributeError: 'FieldType' object has no attribute 'T' + def with_numpy_attr(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return a.T + + err = parse_error(with_numpy_attr) + + assert err.message == "Type 'Field[[IDim], float64]' has no attribute 'T'." + assert any("NumPy-style" in note for note in err.notes) + + +def test_numpy_function_call_is_a_dsl_error(): + # used to leak ValueError: Type not supported + import numpy as np + + def with_numpy_call(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return np.sin(a) + + err = parse_error(with_numpy_call) + + assert err.message == "'sin' cannot be used inside a GT4Py function." + assert any("GT4Py built-in" in hint for hint in err.hints) + + +def test_missing_module_attribute_is_a_dsl_error(): + # used to leak AttributeError: module 'numpy' has no attribute 'sinn' + import numpy as np + + def with_missing_attr(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return np.sinn(a) + + err = parse_error(with_missing_attr) + + assert "module 'numpy' has no attribute 'sinn'" in err.message + + +def test_absolute_field_index_is_a_dsl_error(): + # used to leak AttributeError: 'ScalarType' object has no attribute 'dim' + def with_index(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return a[3] + + err = parse_error(with_index) + + assert err.message == "Fields cannot be indexed with 'int32'." + assert any("field offset" in hint for hint in err.hints) + + +def test_tuple_index_out_of_range_is_a_dsl_error(): + # used to leak IndexError: list index out of range + def with_oob_index(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + t = (a, a) + return t[5] + + err = parse_error(with_oob_index) + + assert err.message == "Tuple index 5 is out of range." + assert err.label == "this tuple has 2 elements" + + +def test_non_local_dimension_index_is_a_dsl_error(): + # used to crash with AssertionError on the dimension kind + def with_dim_index(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return a[IDim(3)] + + err = parse_error(with_dim_index) + + assert "'IDim' is not a local (neighbor) dimension" in err.message + + +def test_unresolvable_string_annotation_is_a_dsl_error(): + # used to leak SyntaxError from typing.get_type_hints + def with_bad_annotation(a: "not a type") -> gtx.Field[[IDim], float64]: # noqa: F722 [syntax-error-in-forward-annotation] + return a + + err = parse_error(with_bad_annotation) + + assert "Could not resolve type annotations of 'with_bad_annotation'" in err.message + + +def test_non_gt4py_parameter_annotation_is_a_dsl_error(): + # used to leak ValueError: Type not supported + def with_list_param(a: list) -> gtx.Field[[IDim], float64]: + return a + + err = parse_error(with_list_param) + + assert isinstance(err, errors.InvalidParameterAnnotationError) + assert any("GT4Py type" in hint for hint in err.hints) + + +def test_unresolvable_annotated_assignment_is_a_dsl_error(): + # used to leak NameError from eval'ing the annotation: 'gtx' is only + # visible inside the function if it is also referenced in the body + def with_ann_assign(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + b: gtx.Field[[IDim], float64] = a + return b + + err = parse_error(with_ann_assign) + + assert "Invalid type annotation 'gtx.Field[[IDim], float64]'" in err.message + assert any("GT4Py builtins" in note for note in err.notes) + + +@gtx.field_operator +def _copy_op(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return a + + +def test_expression_statement_in_program_is_a_dsl_error(): + # used to leak a TypeError from IR node validation + def with_expr_stmt(a: gtx.Field[[IDim], float64], out: gtx.Field[[IDim], float64]): + a + a + + err = parse_program_error(with_expr_stmt) + + assert err.message == "Only calls to GT4Py operators are allowed as statements in a program." + + +def test_calling_program_from_program_is_a_dsl_error(): + # used to crash with AssertionError + @gtx.program + def inner(a: gtx.Field[[IDim], float64], out: gtx.Field[[IDim], float64]): + _copy_op(a, out=out) + + def with_program_call(a: gtx.Field[[IDim], float64], out: gtx.Field[[IDim], float64]): + inner(a, out) + + err = parse_program_error(with_program_call) + + assert err.message == "Program 'inner' cannot be called from within another program." + + +def test_plain_python_function_in_program_is_a_dsl_error(): + # used to leak ValueError: Invalid callable annotations ... + def plain(a, out): + return a + + def with_plain_call(a: gtx.Field[[IDim], float64], out: gtx.Field[[IDim], float64]): + plain(a, out=out) + + err = parse_program_error(with_plain_call) + + assert isinstance(err, errors.DSLTypeError) + assert any("@field_operator" in hint for hint in err.hints) + + def test_diagnostic_codes_are_stable(): assert errors.UndefinedSymbolError.code == "undefined-symbol" assert errors.UnsupportedPythonFeatureError.code == "unsupported-syntax" diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 6399133870..f3d3a6eff1 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -91,7 +91,10 @@ def test_mistyped_arg(): def mistyped(inp: gtx.Field): return inp - with pytest.raises(ValueError, match="Field type requires two arguments, got 0."): + with pytest.raises( + errors.InvalidParameterAnnotationError, + match="Field type requires two arguments, got 0.", + ): _ = FieldOperatorParser.apply_to_function(mistyped)