Skip to content
2 changes: 2 additions & 0 deletions src/gt4py/next/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MissingParameterAnnotationError,
UndefinedSymbolError,
UnsupportedPythonFeatureError,
did_you_mean,
)


Expand All @@ -32,4 +33,5 @@
"MissingParameterAnnotationError",
"UndefinedSymbolError",
"UnsupportedPythonFeatureError",
"did_you_mean",
]
13 changes: 9 additions & 4 deletions src/gt4py/next/errors/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def message(self) -> str:
return self.args[0]


def _did_you_mean(name: str, candidates: Iterable[str]) -> list[str]:
def did_you_mean(name: str, candidates: Iterable[str]) -> list[str]:
"""Produce a 'Did you mean ...?' hint if `name` closely matches any candidate."""
# Never suggest the name the user already wrote (it can appear among the
# candidates as a same-named symbol from another SSA generation).
Expand Down Expand Up @@ -157,7 +157,7 @@ def __init__(
location,
f"Undeclared symbol '{name}'.",
label="not defined at this point",
hints=_did_you_mean(name, candidates),
hints=did_you_mean(name, candidates),
)
self.sym_name = name

Expand All @@ -183,8 +183,8 @@ def __init__(self, location: Optional[SourceLocation], arg_name: str, is_kwarg:


class DSLTypeError(DSLError):
def __init__(self, location: Optional[SourceLocation], message: str) -> None:
super().__init__(location, message)
def __init__(self, location: Optional[SourceLocation], message: str, **kwargs: Any) -> None:
super().__init__(location, message, **kwargs)


class MissingParameterAnnotationError(DSLTypeError):
Expand All @@ -203,6 +203,11 @@ def __init__(self, location: Optional[SourceLocation], param_name: str, type_: A
super().__init__(
location, f"Parameter '{param_name}' has invalid type annotation '{type_}'."
)
self.hints = [
"Annotate parameters with a GT4Py type: a field (e.g. "
"'gtx.Field[[IDim], gtx.float64]'), a scalar (e.g. 'gtx.float64') or a "
"tuple of these."
]
self.param_name = param_name
self.annotated_type = type_

Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/next/ffront/ast_passes/simple_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ def generic_visit(self, node: ast.AST) -> Iterator[ast.AST]: # type: ignore[ove
"""Override generic visit to deal with generators."""
for field, old_value in ast.iter_fields(node):
if isinstance(old_value, list):
new_values = [i for j in old_value for i in self.visit(j)]
# fields may contain non-AST values, e.g. the names of a `global`
# statement; those must be kept as-is, not visited.
new_values = [
i
for j in old_value
for i in (self.visit(j) if isinstance(j, ast.AST) else (j,))
]
old_value[:] = new_values
elif isinstance(old_value, ast.AST):
new_node, *_ = list(self.visit(old_value)) or (None,)
Expand Down
18 changes: 16 additions & 2 deletions src/gt4py/next/ffront/dialect_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

import ast
import textwrap
import typing
from dataclasses import dataclass
from typing import Callable, ClassVar, Collection

from gt4py.eve.concepts import SourceLocation
from gt4py.eve.extended_typing import Any, Generic, TypeVar
from gt4py.next import errors
from gt4py.next.ffront import source_utils
from gt4py.next.ffront.ast_passes.fix_missing_locations import FixMissingLocations
from gt4py.next.ffront.ast_passes.remove_docstrings import RemoveDocstrings
from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function
Expand Down Expand Up @@ -67,6 +67,20 @@
ast.ClassDef: ("class definition", ()),
ast.JoinedStr: ("f-string", ("Strings cannot be computed inside GT4Py functions.",)),
ast.Match: ("'match' statement", ("Use 'if'/'elif' chains or 'where' instead.",)),
ast.Global: (
"'global' statement",
(
"Variables from the surrounding scope are read-only inside GT4Py functions; "
"pass values as parameters and return results instead.",
),
),
ast.Nonlocal: (
"'nonlocal' statement",
(
"Variables from the surrounding scope are read-only inside GT4Py functions; "
"pass values as parameters and return results instead.",
),
),
}


Expand Down Expand Up @@ -136,7 +150,7 @@ def apply(
def apply_to_function(cls, function: Callable) -> DialectRootT:
src = SourceDefinition.from_function(function)
closure_vars = get_closure_vars_from_function(function)
annotations = typing.get_type_hints(function)
annotations = source_utils.get_type_hints_from_function(function, src)
return cls.apply(src, closure_vars, annotations)

@classmethod
Expand Down
93 changes: 90 additions & 3 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,64 @@ def visit_Symbol(
return new_node
return node

def _invalid_attribute_error(
self, node: foast.Attribute, value_type: ts.TypeSpec
) -> errors.DSLError:
notes: list[str] = []
hints: list[str] = []
if isinstance(value_type, ts.NamedCollectionType):
hints = errors.did_you_mean(node.attr, value_type.keys)
elif isinstance(value_type, (ts.FieldType, ts.ScalarType)):
notes = [
"GT4Py fields and scalars do not provide NumPy-style methods or attributes; "
"use the GT4Py built-in functions instead."
]
return errors.DSLError(
node.location,
f"Type '{value_type}' has no attribute '{node.attr}'.",
label="attribute does not exist",
notes=notes,
hints=hints,
)

def visit_Attribute(self, node: foast.Attribute, **kwargs: Any) -> foast.Attribute:
new_value = self.visit(node.value, **kwargs)
# Attribute access is only valid on namespaces (e.g. modules) and named
# collections; anything else (e.g. NumPy-style attributes of a field)
# is a user error and must not leak `getattr` failures.
try:
new_type = getattr(new_value.type, node.attr)
except AttributeError as e:
if isinstance(new_value.type, type_translation.NamespaceProxy):
# `getattr` on the underlying Python object failed; its message
# ("module 'numpy' has no attribute ...") is the best we have.
raise errors.DSLError(node.location, f"{e}.") from e
raise self._invalid_attribute_error(node, new_value.type) from e
except ValueError as e:
# the attribute exists on the Python object behind a namespace, but
# its value has no representation in the DSL type system
hints: list[str] = []
if isinstance(new_value.type, type_translation.NamespaceProxy) and callable(
getattr(new_value.type._object, node.attr, None)
):
hints = [
"External Python functions cannot be called inside GT4Py functions. "
"Use the corresponding GT4Py built-in (e.g. 'sin', 'sqrt', 'exp', "
"'maximum' from 'gt4py.next') if one exists."
]
raise errors.DSLError(
node.location,
f"'{node.attr}' cannot be used inside a GT4Py function.",
label="value not representable in the GT4Py type system",
hints=hints,
) from e
if not isinstance(new_type, ts.TypeSpec):
raise self._invalid_attribute_error(node, new_value.type)
return foast.Attribute(
value=new_value,
attr=node.attr,
location=node.location,
type=getattr(new_value.type, node.attr),
type=new_type,
)

def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscript:
Expand All @@ -452,6 +503,12 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri
node.location,
f"Tuples need to be indexed with literal integers, got '{node.index}'.",
) from ex
if not -len(types) <= index < len(types):
raise errors.DSLError(
node.location,
f"Tuple index {index} is out of range.",
label=f"this tuple has {len(types)} element{'s' * (len(types) != 1)}",
)
new_type = types[index]
case ts.OffsetType(source=source, target=(target1, target2)):
if not target2.kind == DimensionKind.LOCAL:
Expand All @@ -470,7 +527,23 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri
)
new_type = new_value.type
case ts.FieldType(dims=dims, dtype=dtype):
# e.g. `field[LocalDim(42)]`
# e.g. `field[LocalDim(42)]`: the only valid field subscript is a
# local-dimension index, which removes that dimension
if not isinstance(new_index.type, ts.IndexType):
raise errors.DSLError(
node.location,
f"Fields cannot be indexed with '{new_index.type}'.",
label="invalid field index",
notes=(
"GT4Py expressions operate on whole fields; accessing single "
"grid points by absolute position is not possible.",
),
hints=(
"To access neighboring grid points, apply a field offset, "
"e.g. 'field(Ioff[1])'. Entries of a local (neighbor) dimension "
"can be selected with 'field[LocalDim(0)]'.",
),
)
new_type = ts.FieldType(
dims=[d for d in dims if d != new_index.type.dim],
dtype=dtype,
Expand Down Expand Up @@ -743,7 +816,21 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call:
elif isinstance(new_func.type, ts.FieldType):
pass
elif isinstance(new_func.type, ts.DimensionType):
assert new_func.type.dim.kind == DimensionKind.LOCAL
if new_func.type.dim.kind != DimensionKind.LOCAL:
raise errors.DSLError(
node.location,
f"'{new_func.type.dim.value}' is not a local (neighbor) dimension and "
"cannot be used to construct an index.",
label=f"this dimension is of kind '{new_func.type.dim.kind}'",
notes=(
"Only indices of a local dimension (e.g. 'V2EDim(0)') can be "
"constructed this way, to select one neighbor entry of a field.",
),
hints=(
"To access neighboring grid points along a dimension, apply a "
"field offset instead, e.g. 'field(Ioff[1])'.",
),
)
return foast.Call(
func=new_func,
args=new_args,
Expand Down
36 changes: 22 additions & 14 deletions src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from __future__ import annotations

import ast
import textwrap
import typing
from typing import Any, Type

import gt4py.eve as eve
Expand Down Expand Up @@ -71,7 +69,7 @@ def func_to_foast(inp: DSLFieldOperatorDef) -> FOASTOperatorDef:
"""
source_def = source_utils.SourceDefinition.from_function(inp.definition)
closure_vars = source_utils.get_closure_vars_from_function(inp.definition)
annotations = typing.get_type_hints(inp.definition)
annotations = source_utils.get_type_hints_from_function(inp.definition, source_def)
try:
foast_definition_node = FieldOperatorParser.apply(source_def, closure_vars, annotations)
loc = foast_definition_node.location
Expand Down Expand Up @@ -233,7 +231,12 @@ def visit_arg(self, node: ast.arg) -> foast.DataSymbol:
loc = self.get_location(node)
if (annotation := self.annotations.get(node.arg, None)) is None:
raise errors.MissingParameterAnnotationError(loc, node.arg)
new_type = type_translation.from_type_hint(annotation)
try:
new_type = type_translation.from_type_hint(annotation)
except ValueError as e:
err = errors.InvalidParameterAnnotationError(loc, node.arg, annotation)
err.notes.append(str(e))
raise err from e
if not isinstance(new_type, ts.DataType):
raise errors.InvalidParameterAnnotationError(loc, node.arg, new_type)
return foast.DataSymbol(id=node.arg, location=loc, type=new_type)
Expand Down Expand Up @@ -310,8 +313,19 @@ def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs: Any) -> foast.Assign:
), "Annotations should be ast.Constant(string). Use StringifyAnnotationsPass"

context = {**fbuiltins.BUILTINS, **self.closure_vars}
annotation = eval(node.annotation.value, context)
target_type = type_translation.from_type_hint(annotation, globalns=context)
try:
annotation = eval(node.annotation.value, context)
target_type = type_translation.from_type_hint(annotation, globalns=context)
except Exception as e:
raise errors.DSLError(
self.get_location(node),
f"Invalid type annotation '{node.annotation.value}': {e}.",
notes=(
"Inside a GT4Py function, annotations can only use GT4Py "
"builtins and names that are also referenced in the function "
"body.",
),
) from e
else:
target_type = ts.DeferredType()

Expand Down Expand Up @@ -453,14 +467,8 @@ def visit_Compare(self, node: ast.Compare, **kwargs: Any) -> foast.Compare:
refactored = UnchainComparesPass.apply(node)
raise errors.DSLError(
loc,
textwrap.dedent(
f"""
Comparison chains are not allowed. Please replace
{ast.unparse(node)}
by
{ast.unparse(refactored)}
""",
),
"Comparison chains are not allowed.",
hints=(f"Replace '{ast.unparse(node)}' by '{ast.unparse(refactored)}'.",),
)
return foast.Compare(
op=self.visit(node.ops[0]),
Expand Down
Loading