Skip to content

Commit d74a6d3

Browse files
authored
Merge pull request #277 from PyAutoLabs/feature/data-typing-simplify
refactor: simplify decorator type-dispatch system
2 parents 295cead + fddb664 commit d74a6d3

12 files changed

Lines changed: 303 additions & 509 deletions

File tree

autoarray/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@
8484
from .structures.triangles.shape import Triangle
8585
from .structures.triangles.shape import Square
8686
from .structures.triangles.shape import Polygon
87-
from .structures import decorators as grid_dec
87+
from .structures import decorators as grid_dec # deprecated alias
88+
from .structures import decorators
8889
from .structures.header import Header
8990
from .layout.region import Region1D
9091
from .layout.region import Region2D

autoarray/mock.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
from autoarray.structures.mock.mock_grid import MockMeshGrid
2121
from autoarray.structures.mock.mock_decorators import MockGrid1DLikeObj
2222
from autoarray.structures.mock.mock_decorators import MockGrid2DLikeObj
23+
from autoarray.structures.mock.mock_decorators import MockTransformProfile
Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
from typing import Union
2-
31
import numpy as np
42

5-
from autoarray.mask.mask_1d import Mask1D
63
from autoarray.mask.mask_2d import Mask2D
7-
from autoarray.structures.grids.uniform_1d import Grid1D
84
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
95
from autoarray.structures.grids.uniform_2d import Grid2D
106

@@ -18,20 +14,16 @@ def __init__(self, func, obj, grid, xp=np, *args, **kwargs):
1814
This is used by the `to_array`, `to_grid` and `to_vector_yx` decorators to ensure that the input grid and output
1915
data structure are consistent.
2016
21-
There are three types of consistent data structures and therefore decorated function mappings:
17+
There are two types of consistent data structures and therefore decorated function mappings:
2218
2319
- Uniform: 2D structures defined on a uniform grid of data points, for example the `Array2D` and `Grid2D`
2420
objects. Both structures are defined according to a `Mask2D`, which the maker object ensures is passed through
2521
self consistently.
2622
2723
- Irregular: 2D structures defined on an irregular grid of data points, for example an `ArrayIrregular`
28-
and `Grid2DIrregular` objects. Neither structure is defined according to a mask and the maker sures the lack of
24+
and `Grid2DIrregular` objects. Neither structure is defined according to a mask and the maker ensures the lack of
2925
a mask does not prevent the function from being evaluated.
3026
31-
- 1D: 1D structures defined on a 1D grid of data points, for example the `Array1D` and `Grid1D` objects.
32-
These project the 1D grid to a 2D grid to ensure the function can be evaluated, and then deproject the 2D grid
33-
back to a 1D grid to ensure the output data structure is consistent with the input grid.
34-
3527
Parameters
3628
----------
3729
func
@@ -66,7 +58,7 @@ def _xp(self):
6658
return np
6759

6860
@property
69-
def mask(self) -> Union[Mask1D, Mask2D]:
61+
def mask(self) -> Mask2D:
7062
return self.grid.mask
7163

7264
@property
@@ -79,30 +71,8 @@ def via_grid_2d(self, result):
7971
def via_grid_2d_irr(self, result):
8072
raise NotImplementedError
8173

82-
def via_grid_1d(self, result):
83-
raise NotImplementedError
84-
8574
@property
8675
def evaluate_func(self):
87-
"""
88-
Evaluate the function that is being decorated, using the grid that is passed to the maker object when it is
89-
initialized.
90-
91-
In normal usage, the input grid is 2D and it is simply passed to the decorated function.
92-
93-
However, if the input grid is 1D, the grid is projected to a 2D grid before being passed to the function. This
94-
is because the function is expected to evaluate a 2D grid, and the maker object ensures that the function can
95-
be evaluated by projecting the 1D grid to a 2D grid.
96-
97-
Returns
98-
-------
99-
The result of the function that is being decorated, which is the output data structure that is consistent with
100-
the input grid.
101-
"""
102-
103-
if isinstance(self.grid, Grid1D):
104-
grid = self.grid.grid_2d_radial_projected_from()
105-
return self.func(self.obj, grid, self._xp, *self.args, **self.kwargs)
10676
return self.func(self.obj, self.grid, self._xp, *self.args, **self.kwargs)
10777

10878
@property
@@ -111,21 +81,17 @@ def result(self):
11181
The result of the function that is being decorated, which this function converts to the output data structure
11282
that is consistent with the input grid.
11383
114-
This function called one of three methods, depending on the type of the input grid:
84+
This function calls one of two methods, depending on the type of the input grid:
11585
11686
- `via_grid_2d`: If the input grid is a `Grid2D` object.
11787
- `via_grid_2d_irr`: If the input grid is a `Grid2DIrregular` object.
118-
- `via_grid_1d`: If the input grid is a `Grid1D` object.
11988
120-
These functions are over written depending on whether the decorated function returns an array, grid or vector.
121-
The over written functions are in the child classes `ArrayMaker`, `GridMaker` and `VectorYXMaker`.
89+
If the input is a raw ndarray (e.g. numpy or JAX), the function result is returned unchanged.
12290
"""
12391

12492
if isinstance(self.grid, Grid2D):
12593
return self.via_grid_2d(self.evaluate_func)
12694
elif isinstance(self.grid, Grid2DIrregular):
12795
return self.via_grid_2d_irr(self.evaluate_func)
128-
elif isinstance(self.grid, Grid1D):
129-
return self.via_grid_1d(self.evaluate_func)
13096

13197
return self.evaluate_func
Lines changed: 49 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1,125 +1,49 @@
1-
import numpy as np
2-
from functools import wraps
3-
4-
5-
from typing import List, Union
6-
7-
from autoarray.structures.arrays.irregular import ArrayIrregular
8-
from autoarray.structures.arrays.uniform_1d import Array1D
9-
from autoarray.structures.arrays.uniform_2d import Array2D
10-
from autoarray.structures.decorators.abstract import AbstractMaker
11-
from autoarray.structures.grids.uniform_1d import Grid1D
12-
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
13-
from autoarray.structures.grids.uniform_2d import Grid2D
14-
15-
16-
class ArrayMaker(AbstractMaker):
17-
def via_grid_2d(self, result) -> Union[Array2D, List[Array2D]]:
18-
"""
19-
Convert the result of a decorated function which receives as input a `Grid2D` object to an `Array2D` object.
20-
21-
If the result returns a list, a list of `Array2D` objects is returned.
22-
23-
Parameters
24-
----------
25-
result
26-
The input result (e.g. of a decorated function) that is converted to an Array2D or list of Array2D objects.
27-
"""
28-
29-
if not isinstance(result, list):
30-
return Array2D(values=result, mask=self.mask)
31-
return [Array2D(values=res, mask=self.mask) for res in result]
32-
33-
def via_grid_2d_irr(self, result) -> Union[ArrayIrregular, List[ArrayIrregular]]:
34-
"""
35-
Convert the result of a decorated function which receives as input a `Grid2DIrregular` object to an `ArrayIrregular`
36-
object.
37-
38-
If the result returns a list, a list of `ArrayIrregular` objects is returned.
39-
40-
Parameters
41-
----------
42-
result
43-
The input result (e.g. of a decorated function) that is converted to an ArrayIrregular or list of
44-
ArrayIrregular objects.
45-
"""
46-
if not isinstance(result, list):
47-
return ArrayIrregular(values=result)
48-
return [ArrayIrregular(values=res) for res in result]
49-
50-
def via_grid_1d(self, result) -> Union[Array1D, List[Array1D]]:
51-
"""
52-
Convert the result of a decorated function which receives as input a `Grid1D` object to an `Array1D` object.
53-
54-
If the result returns a list, a list of `Array1D` objects is returned.
55-
56-
Parameters
57-
----------
58-
result
59-
The input result (e.g. of a decorated function) that is converted to an Array1D or list of Array1D objects.
60-
"""
61-
if not isinstance(result, list):
62-
return Array1D(values=result, mask=self.mask)
63-
return [Array1D(values=res, mask=self.mask) for res in result]
64-
65-
66-
def to_array(func):
67-
"""
68-
Homogenize the inputs and outputs of functions that take 1D or 2D grids of coordinates and return a 1D ndarray
69-
which is converted to an `Array2D`, `ArrayIrregular` or `Array1D` object.
70-
71-
Parameters
72-
----------
73-
func
74-
A function which computes a set of values from a 1D or 2D grid of coordinates.
75-
76-
Returns
77-
-------
78-
A function that has its outputs homogenized to `Array2D`, `ArrayIrregular` or `Array1D` objects.
79-
"""
80-
81-
@wraps(func)
82-
def wrapper(
83-
obj: object,
84-
grid: Union[np.ndarray, Grid2D, Grid2DIrregular, Grid1D],
85-
xp=np,
86-
*args,
87-
**kwargs,
88-
) -> Union[np.ndarray, Array1D, Array2D, ArrayIrregular, List]:
89-
"""
90-
This decorator homogenizes the input of a "grid_like" 2D structure (`Grid2D`, `Grid2DIrregular` or `Grid1D`)
91-
into a function which outputs an array-like structure (`Array2D`, `ArrayIrregular` or `Array1D`).
92-
93-
It allows these classes to be interchangeably input into a function, such that the grid is used to evaluate
94-
the function at every (y,x) coordinates of the grid using specific functionality of the input grid.
95-
96-
The grid_like objects `Grid2D` and `Grid2DIrregular` are input into the function as a slimmed 2D ndarray array
97-
of shape [total_coordinates, 2] where the second dimension stores the (y,x)
98-
99-
There are three types of consistent data structures and therefore decorated function mappings:
100-
101-
- Uniform (`Grid2D` -> `Array`): 2D structures defined on a uniform grid of data points. Both structures are
102-
defined according to a `Mask2D`, which the maker object ensures is passed through self consistently.
103-
104-
- Irregular (`Grid2DIrregular` -> `ArrayIrregular`: 2D structures defined on an irregular grid of data points,
105-
Neither structure is defined according to a mask and the maker sures the lack of a mask does not prevent the
106-
function from being evaluated.
107-
108-
- 1D (`Grid1D` -> `Array1D`): 1D structures defined on a 1D grid of data points. These project the 1D grid
109-
to a 2D grid to ensure the function can be evaluated, and then deproject the 2D grid back to a 1D grid to
110-
ensure the output data structure is consistent with the input grid.
111-
112-
Parameters
113-
----------
114-
obj
115-
An object whose function uses grid_like inputs to compute quantities at every coordinate on the grid.
116-
grid
117-
A grid_like object of coordinates on which the function values are evaluated.
118-
119-
Returns
120-
-------
121-
The function values evaluated on the grid with the same structure as the input grid_like object.
122-
"""
123-
return ArrayMaker(func=func, obj=obj, grid=grid, xp=xp, *args, **kwargs).result
124-
125-
return wrapper
1+
import numpy as np
2+
from functools import wraps
3+
from typing import List, Union
4+
5+
from autoarray.structures.arrays.irregular import ArrayIrregular
6+
from autoarray.structures.arrays.uniform_2d import Array2D
7+
from autoarray.structures.decorators.abstract import AbstractMaker
8+
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
9+
from autoarray.structures.grids.uniform_2d import Grid2D
10+
11+
12+
class ArrayMaker(AbstractMaker):
13+
def via_grid_2d(self, result) -> Union[Array2D, List[Array2D]]:
14+
if not isinstance(result, list):
15+
return Array2D(values=result, mask=self.mask)
16+
return [Array2D(values=res, mask=self.mask) for res in result]
17+
18+
def via_grid_2d_irr(self, result) -> Union[ArrayIrregular, List[ArrayIrregular]]:
19+
if not isinstance(result, list):
20+
return ArrayIrregular(values=result)
21+
return [ArrayIrregular(values=res) for res in result]
22+
23+
24+
def to_array(func):
25+
"""
26+
Homogenize the inputs and outputs of functions that take 2D grids of coordinates and return a 1D ndarray
27+
which is converted to an `Array2D` or `ArrayIrregular` object.
28+
29+
Parameters
30+
----------
31+
func
32+
A function which computes a set of values from a 2D grid of coordinates.
33+
34+
Returns
35+
-------
36+
A function that has its outputs homogenized to `Array2D` or `ArrayIrregular` objects.
37+
"""
38+
39+
@wraps(func)
40+
def wrapper(
41+
obj: object,
42+
grid: Union[np.ndarray, Grid2D, Grid2DIrregular],
43+
xp=np,
44+
*args,
45+
**kwargs,
46+
) -> Union[np.ndarray, Array2D, ArrayIrregular, List]:
47+
return ArrayMaker(func=func, obj=obj, grid=grid, xp=xp, *args, **kwargs).result
48+
49+
return wrapper

0 commit comments

Comments
 (0)