Skip to content

Commit 81c368b

Browse files
authored
Merge pull request #272 from PyAutoLabs/feature/transform-decorator
refactor: reduce magic in transform decorator
2 parents 3c25d3c + 04ad882 commit 81c368b

2 files changed

Lines changed: 125 additions & 67 deletions

File tree

autoarray/abstract_ndarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ def __init__(self, array, xp=np):
7676

7777
self.use_jax = xp is not np
7878

79+
@property
80+
def is_transformed(self) -> bool:
81+
return self._is_transformed
82+
83+
@is_transformed.setter
84+
def is_transformed(self, value: bool):
85+
self._is_transformed = value
86+
7987
@property
8088
def _xp(self):
8189
if self.use_jax:
Lines changed: 117 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,117 @@
1-
from functools import wraps
2-
import numpy as np
3-
from typing import Union
4-
5-
from autoarray.structures.grids.uniform_1d import Grid1D
6-
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
7-
from autoarray.structures.grids.uniform_2d import Grid2D
8-
9-
10-
def transform(func):
11-
"""
12-
Checks whether the input Grid2D of (y,x) coordinates have previously been transformed. If they have not
13-
been transformed then they are transformed.
14-
15-
Parameters
16-
----------
17-
func
18-
A function where the input grid is the grid whose coordinates are transformed.
19-
20-
Returns
21-
-------
22-
A function that can accept cartesian or transformed coordinates
23-
"""
24-
25-
@wraps(func)
26-
def wrapper(
27-
obj: object,
28-
grid: Union[np.ndarray, Grid2D, Grid2DIrregular, Grid1D],
29-
xp=np,
30-
*args,
31-
**kwargs,
32-
) -> Union[np.ndarray, Grid2D, Grid2DIrregular]:
33-
"""
34-
This decorator checks whether the input grid has been transformed to the reference frame of the class
35-
that owns the function. If it has not been transformed, it is transformed.
36-
37-
A function call which uses this decorator often has many subsequent function calls which also use the
38-
decorator. To ensure the grid is only transformed once, the `is_transformed` keyword is used to track
39-
whether the grid has been transformed.
40-
41-
Parameters
42-
----------
43-
obj
44-
An object whose function uses grid_like inputs to compute quantities at every coordinate on the grid.
45-
grid
46-
The (y, x) coordinates in the original reference frame of the grid.
47-
48-
Returns
49-
-------
50-
A grid_like object whose coordinates may be transformed.
51-
"""
52-
53-
if not kwargs.get("is_transformed"):
54-
kwargs["is_transformed"] = True
55-
56-
transformed_grid = obj.transformed_to_reference_frame_grid_from(
57-
grid, xp, **kwargs
58-
)
59-
60-
result = func(obj, transformed_grid, xp, *args, **kwargs)
61-
62-
else:
63-
result = func(obj, grid, xp, *args, **kwargs)
64-
65-
return result
66-
67-
return wrapper
1+
from functools import wraps
2+
import numpy as np
3+
from typing import Union
4+
5+
from autoarray.structures.grids.uniform_1d import Grid1D
6+
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
7+
from autoarray.structures.grids.uniform_2d import Grid2D
8+
9+
10+
def transform(func=None, *, rotate_back=False):
11+
"""
12+
Checks whether the input Grid2D of (y,x) coordinates have previously been transformed. If they have not
13+
been transformed then they are transformed.
14+
15+
Can be used with or without arguments::
16+
17+
@transform
18+
def convergence_2d_from(self, grid, xp=np, **kwargs): ...
19+
20+
@transform(rotate_back=True)
21+
def deflections_yx_2d_from(self, grid, xp=np, **kwargs): ...
22+
23+
**Frame conventions and rotate_back**
24+
25+
This decorator transforms the input grid into the profile's reference frame (centred on the
26+
profile centre and rotated by its position angle) before calling the decorated function.
27+
28+
For **scalar** quantities (convergence, potential), the returned value is frame-invariant — no
29+
back-rotation is needed, so use ``@transform`` without ``rotate_back``.
30+
31+
For **vector** quantities (e.g. deflection angles), whether back-rotation is needed depends on
32+
which frame the returned components are expressed in:
33+
34+
- If the function computes vector components using the rotated grid coordinates (i.e. the
35+
components are expressed in the profile's frame), they must be rotated back to the observer
36+
frame before use in ray-tracing. Set ``rotate_back=True`` for this case.
37+
38+
- If the function reconstructs observer-frame components from scalar quantities (e.g. computing
39+
a radial deflection magnitude and converting to Cartesian using observer-frame geometry), the
40+
result is already in the observer frame and ``rotate_back`` should remain ``False``.
41+
42+
When ``rotate_back=True``, the decorator calls ``obj.rotated_grid_from_reference_frame_from``
43+
on the result after evaluation, applying the inverse rotation by the profile's position angle.
44+
45+
For **spin-2** quantities (shear), the transformation law uses twice the profile angle. This
46+
is not handled by ``rotate_back`` — shear methods must apply the 2-theta rotation manually.
47+
48+
Parameters
49+
----------
50+
func
51+
A function where the input grid is the grid whose coordinates are transformed.
52+
rotate_back
53+
If ``True``, the result is rotated back from the profile's reference frame after
54+
evaluation. Use this when the decorated function returns vector components that were
55+
computed in the profile's rotated coordinate basis and need to be expressed in the
56+
original observer frame.
57+
58+
Returns
59+
-------
60+
A function that can accept cartesian or transformed coordinates
61+
"""
62+
63+
def decorator(func):
64+
@wraps(func)
65+
def wrapper(
66+
obj: object,
67+
grid: Union[np.ndarray, Grid2D, Grid2DIrregular, Grid1D],
68+
xp=np,
69+
*args,
70+
**kwargs,
71+
) -> Union[np.ndarray, Grid2D, Grid2DIrregular]:
72+
"""
73+
This decorator checks whether the input grid has been transformed to the reference frame of the class
74+
that owns the function. If it has not been transformed, it is transformed.
75+
76+
The transform state is tracked via the ``is_transformed`` property on the grid object itself.
77+
When a decorated function calls another decorated function with the same (already-transformed)
78+
grid, the flag prevents the grid from being transformed a second time.
79+
80+
Parameters
81+
----------
82+
obj
83+
An object whose function uses grid_like inputs to compute quantities at every coordinate on the grid.
84+
grid
85+
The (y, x) coordinates in the original reference frame of the grid.
86+
87+
Returns
88+
-------
89+
A grid_like object whose coordinates may be transformed.
90+
"""
91+
92+
if not getattr(grid, "is_transformed", False):
93+
transformed_grid = obj.transformed_to_reference_frame_grid_from(
94+
grid, xp, **kwargs
95+
)
96+
transformed_grid.is_transformed = True
97+
98+
result = func(obj, transformed_grid, xp, *args, **kwargs)
99+
100+
else:
101+
result = func(obj, grid, xp, *args, **kwargs)
102+
103+
if rotate_back:
104+
result = obj.rotated_grid_from_reference_frame_from(
105+
grid=result, xp=xp
106+
)
107+
108+
return result
109+
110+
return wrapper
111+
112+
if func is not None:
113+
# Called without arguments: @transform
114+
return decorator(func)
115+
116+
# Called with arguments: @transform(rotate_back=True)
117+
return decorator

0 commit comments

Comments
 (0)