-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathabstract.py
More file actions
97 lines (74 loc) · 3.73 KB
/
abstract.py
File metadata and controls
97 lines (74 loc) · 3.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
from autoarray.mask.mask_2d import Mask2D
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
from autoarray.structures.grids.uniform_2d import Grid2D
class AbstractMaker:
def __init__(self, func, obj, grid, xp=np, *args, **kwargs):
"""
Makes 2D data structures from an input function and grid, ensuring that the structure of the input grid is
paired to the structure of the output data structure.
This is used by the `to_array`, `to_grid` and `to_vector_yx` decorators to ensure that the input grid and output
data structure are consistent.
There are two types of consistent data structures and therefore decorated function mappings:
- Uniform: 2D structures defined on a uniform grid of data points, for example the `Array2D` and `Grid2D`
objects. Both structures are defined according to a `Mask2D`, which the maker object ensures is passed through
self consistently.
- Irregular: 2D structures defined on an irregular grid of data points, for example an `ArrayIrregular`
and `Grid2DIrregular` objects. Neither structure is defined according to a mask and the maker ensures the lack of
a mask does not prevent the function from being evaluated.
Parameters
----------
func
The function that is being decorated, which the maker object evaluates to create the output data structure.
obj
The object that the function is a method of, which is passed to the function when it is evaluated. This
is typically the self object of the class the function is a method of.
grid
The grid that is passed to the function when it is evaluated, which is used to evaluate the function.
For example, this might be the `Grid2D` object that the function checks the typing of when determining
the output data structure.
args
Any arguments that are passed to the function when it is evaluated.
kwargs
Any keyword arguments that are passed to the function when it is evaluated.
"""
self.func = func
self.obj = obj
self.grid = grid
self.args = args
self.kwargs = kwargs
self.use_jax = xp is not np
@property
def _xp(self):
if self.use_jax:
import jax.numpy as jnp
return jnp
return np
@property
def mask(self) -> Mask2D:
return self.grid.mask
@property
def over_sample_size(self) -> np.ndarray:
return self.grid.over_sample_size
def via_grid_2d(self, result):
raise NotImplementedError
def via_grid_2d_irr(self, result):
raise NotImplementedError
@property
def evaluate_func(self):
return self.func(self.obj, self.grid, self._xp, *self.args, **self.kwargs)
@property
def result(self):
"""
The result of the function that is being decorated, which this function converts to the output data structure
that is consistent with the input grid.
This function calls one of two methods, depending on the type of the input grid:
- `via_grid_2d`: If the input grid is a `Grid2D` object.
- `via_grid_2d_irr`: If the input grid is a `Grid2DIrregular` object.
If the input is a raw ndarray (e.g. numpy or JAX), the function result is returned unchanged.
"""
if isinstance(self.grid, Grid2D):
return self.via_grid_2d(self.evaluate_func)
elif isinstance(self.grid, Grid2DIrregular):
return self.via_grid_2d_irr(self.evaluate_func)
return self.evaluate_func