|
1 | | -from functools import wraps |
2 | | -import numpy as np |
3 | | -from typing import Union |
4 | | - |
5 | | -from autoarray.structures.grids.uniform_1d import Grid1D |
6 | | -from autoarray.structures.grids.irregular_2d import Grid2DIrregular |
7 | | -from autoarray.structures.grids.uniform_2d import Grid2D |
8 | | - |
9 | | - |
10 | | -def transform(func): |
11 | | - """ |
12 | | - Checks whether the input Grid2D of (y,x) coordinates have previously been transformed. If they have not |
13 | | - been transformed then they are transformed. |
14 | | -
|
15 | | - Parameters |
16 | | - ---------- |
17 | | - func |
18 | | - A function where the input grid is the grid whose coordinates are transformed. |
19 | | -
|
20 | | - Returns |
21 | | - ------- |
22 | | - A function that can accept cartesian or transformed coordinates |
23 | | - """ |
24 | | - |
25 | | - @wraps(func) |
26 | | - def wrapper( |
27 | | - obj: object, |
28 | | - grid: Union[np.ndarray, Grid2D, Grid2DIrregular, Grid1D], |
29 | | - xp=np, |
30 | | - *args, |
31 | | - **kwargs, |
32 | | - ) -> Union[np.ndarray, Grid2D, Grid2DIrregular]: |
33 | | - """ |
34 | | - This decorator checks whether the input grid has been transformed to the reference frame of the class |
35 | | - that owns the function. If it has not been transformed, it is transformed. |
36 | | -
|
37 | | - A function call which uses this decorator often has many subsequent function calls which also use the |
38 | | - decorator. To ensure the grid is only transformed once, the `is_transformed` keyword is used to track |
39 | | - whether the grid has been transformed. |
40 | | -
|
41 | | - Parameters |
42 | | - ---------- |
43 | | - obj |
44 | | - An object whose function uses grid_like inputs to compute quantities at every coordinate on the grid. |
45 | | - grid |
46 | | - The (y, x) coordinates in the original reference frame of the grid. |
47 | | -
|
48 | | - Returns |
49 | | - ------- |
50 | | - A grid_like object whose coordinates may be transformed. |
51 | | - """ |
52 | | - |
53 | | - if not kwargs.get("is_transformed"): |
54 | | - kwargs["is_transformed"] = True |
55 | | - |
56 | | - transformed_grid = obj.transformed_to_reference_frame_grid_from( |
57 | | - grid, xp, **kwargs |
58 | | - ) |
59 | | - |
60 | | - result = func(obj, transformed_grid, xp, *args, **kwargs) |
61 | | - |
62 | | - else: |
63 | | - result = func(obj, grid, xp, *args, **kwargs) |
64 | | - |
65 | | - return result |
66 | | - |
67 | | - return wrapper |
| 1 | +from functools import wraps |
| 2 | +import numpy as np |
| 3 | +from typing import Union |
| 4 | + |
| 5 | +from autoarray.structures.grids.uniform_1d import Grid1D |
| 6 | +from autoarray.structures.grids.irregular_2d import Grid2DIrregular |
| 7 | +from autoarray.structures.grids.uniform_2d import Grid2D |
| 8 | + |
| 9 | + |
| 10 | +def transform(func=None, *, rotate_back=False): |
| 11 | + """ |
| 12 | + Checks whether the input Grid2D of (y,x) coordinates have previously been transformed. If they have not |
| 13 | + been transformed then they are transformed. |
| 14 | +
|
| 15 | + Can be used with or without arguments:: |
| 16 | +
|
| 17 | + @transform |
| 18 | + def convergence_2d_from(self, grid, xp=np, **kwargs): ... |
| 19 | +
|
| 20 | + @transform(rotate_back=True) |
| 21 | + def deflections_yx_2d_from(self, grid, xp=np, **kwargs): ... |
| 22 | +
|
| 23 | + **Frame conventions and rotate_back** |
| 24 | +
|
| 25 | + This decorator transforms the input grid into the profile's reference frame (centred on the |
| 26 | + profile centre and rotated by its position angle) before calling the decorated function. |
| 27 | +
|
| 28 | + For **scalar** quantities (convergence, potential), the returned value is frame-invariant — no |
| 29 | + back-rotation is needed, so use ``@transform`` without ``rotate_back``. |
| 30 | +
|
| 31 | + For **vector** quantities (e.g. deflection angles), whether back-rotation is needed depends on |
| 32 | + which frame the returned components are expressed in: |
| 33 | +
|
| 34 | + - If the function computes vector components using the rotated grid coordinates (i.e. the |
| 35 | + components are expressed in the profile's frame), they must be rotated back to the observer |
| 36 | + frame before use in ray-tracing. Set ``rotate_back=True`` for this case. |
| 37 | +
|
| 38 | + - If the function reconstructs observer-frame components from scalar quantities (e.g. computing |
| 39 | + a radial deflection magnitude and converting to Cartesian using observer-frame geometry), the |
| 40 | + result is already in the observer frame and ``rotate_back`` should remain ``False``. |
| 41 | +
|
| 42 | + When ``rotate_back=True``, the decorator calls ``obj.rotated_grid_from_reference_frame_from`` |
| 43 | + on the result after evaluation, applying the inverse rotation by the profile's position angle. |
| 44 | +
|
| 45 | + For **spin-2** quantities (shear), the transformation law uses twice the profile angle. This |
| 46 | + is not handled by ``rotate_back`` — shear methods must apply the 2-theta rotation manually. |
| 47 | +
|
| 48 | + Parameters |
| 49 | + ---------- |
| 50 | + func |
| 51 | + A function where the input grid is the grid whose coordinates are transformed. |
| 52 | + rotate_back |
| 53 | + If ``True``, the result is rotated back from the profile's reference frame after |
| 54 | + evaluation. Use this when the decorated function returns vector components that were |
| 55 | + computed in the profile's rotated coordinate basis and need to be expressed in the |
| 56 | + original observer frame. |
| 57 | +
|
| 58 | + Returns |
| 59 | + ------- |
| 60 | + A function that can accept cartesian or transformed coordinates |
| 61 | + """ |
| 62 | + |
| 63 | + def decorator(func): |
| 64 | + @wraps(func) |
| 65 | + def wrapper( |
| 66 | + obj: object, |
| 67 | + grid: Union[np.ndarray, Grid2D, Grid2DIrregular, Grid1D], |
| 68 | + xp=np, |
| 69 | + *args, |
| 70 | + **kwargs, |
| 71 | + ) -> Union[np.ndarray, Grid2D, Grid2DIrregular]: |
| 72 | + """ |
| 73 | + This decorator checks whether the input grid has been transformed to the reference frame of the class |
| 74 | + that owns the function. If it has not been transformed, it is transformed. |
| 75 | +
|
| 76 | + The transform state is tracked via the ``is_transformed`` property on the grid object itself. |
| 77 | + When a decorated function calls another decorated function with the same (already-transformed) |
| 78 | + grid, the flag prevents the grid from being transformed a second time. |
| 79 | +
|
| 80 | + Parameters |
| 81 | + ---------- |
| 82 | + obj |
| 83 | + An object whose function uses grid_like inputs to compute quantities at every coordinate on the grid. |
| 84 | + grid |
| 85 | + The (y, x) coordinates in the original reference frame of the grid. |
| 86 | +
|
| 87 | + Returns |
| 88 | + ------- |
| 89 | + A grid_like object whose coordinates may be transformed. |
| 90 | + """ |
| 91 | + |
| 92 | + if not getattr(grid, "is_transformed", False): |
| 93 | + transformed_grid = obj.transformed_to_reference_frame_grid_from( |
| 94 | + grid, xp, **kwargs |
| 95 | + ) |
| 96 | + transformed_grid.is_transformed = True |
| 97 | + |
| 98 | + result = func(obj, transformed_grid, xp, *args, **kwargs) |
| 99 | + |
| 100 | + else: |
| 101 | + result = func(obj, grid, xp, *args, **kwargs) |
| 102 | + |
| 103 | + if rotate_back: |
| 104 | + result = obj.rotated_grid_from_reference_frame_from( |
| 105 | + grid=result, xp=xp |
| 106 | + ) |
| 107 | + |
| 108 | + return result |
| 109 | + |
| 110 | + return wrapper |
| 111 | + |
| 112 | + if func is not None: |
| 113 | + # Called without arguments: @transform |
| 114 | + return decorator(func) |
| 115 | + |
| 116 | + # Called with arguments: @transform(rotate_back=True) |
| 117 | + return decorator |
0 commit comments