|
8 | 8 | from autoarray.numpy_wrapper import use_jax, np as jnp |
9 | 9 |
|
10 | 10 |
|
11 | | -@numba_util.jit() |
12 | 11 | def mask_2d_centres_from( |
13 | | - shape_native: Tuple[int, int], |
14 | | - pixel_scales: ty.PixelScales, |
15 | | - centre: Tuple[float, float], |
16 | | -) -> Tuple[float, float]: |
| 12 | + shape_native: tuple[int, int], |
| 13 | + pixel_scales: tuple[float, float], |
| 14 | + centre: tuple[float, float], |
| 15 | +) -> tuple[float, float]: |
17 | 16 | """ |
18 | | - Returns the (y,x) scaled central coordinates of a mask from its shape, pixel-scales and centre. |
| 17 | + Compute the (y, x) scaled central coordinates of a mask given its shape, pixel-scales, and centre. |
19 | 18 |
|
20 | | - The coordinate system is defined such that the positive y axis is up and positive x axis is right. |
| 19 | + The coordinate system is defined such that the positive y-axis is up and the positive x-axis is right. |
21 | 20 |
|
22 | 21 | Parameters |
23 | 22 | ---------- |
24 | 23 | shape_native |
25 | | - The (y,x) shape of the 2D array the scaled centre is computed for. |
| 24 | + The shape of the 2D array in pixels. |
26 | 25 | pixel_scales |
27 | | - The (y,x) scaled units to pixel units conversion factor of the 2D array. |
28 | | - centre : (float, flloat) |
29 | | - The (y,x) centre of the 2D mask. |
30 | | -
|
31 | | - Returns |
32 | | - ------- |
33 | | - tuple (float, float) |
34 | | - The (y,x) scaled central coordinates of the input array. |
35 | | -
|
36 | | - Examples |
37 | | - -------- |
38 | | - centres_scaled = centres_from(shape=(5,5), pixel_scales=(0.5, 0.5), centre=(0.0, 0.0)) |
39 | | - """ |
40 | | - y_centre_scaled = (float(shape_native[0] - 1) / 2) - (centre[0] / pixel_scales[0]) |
41 | | - x_centre_scaled = (float(shape_native[1] - 1) / 2) + (centre[1] / pixel_scales[1]) |
42 | | - |
43 | | - return (y_centre_scaled, x_centre_scaled) |
44 | | - |
45 | | - |
46 | | -@numba_util.jit() |
47 | | -def total_pixels_2d_from(mask_2d: np.ndarray) -> int: |
48 | | - """ |
49 | | - Returns the total number of unmasked pixels in a mask. |
50 | | -
|
51 | | - Parameters |
52 | | - ---------- |
53 | | - mask_2d |
54 | | - A 2D array of bools, where `False` values are unmasked and included when counting pixels. |
| 26 | + The conversion factors from pixels to scaled units. |
| 27 | + centre |
| 28 | + The central coordinate of the mask in scaled units. |
55 | 29 |
|
56 | 30 | Returns |
57 | 31 | ------- |
58 | | - int |
59 | | - The total number of pixels that are unmasked. |
| 32 | + The (y, x) scaled central coordinates of the input array. |
60 | 33 |
|
61 | 34 | Examples |
62 | 35 | -------- |
63 | | -
|
64 | | - mask = np.array([[True, False, True], |
65 | | - [False, False, False] |
66 | | - [True, False, True]]) |
67 | | -
|
68 | | - total_regular_pixels = total_regular_pixels_from(mask=mask) |
| 36 | + centres_scaled = mask_2d_centres_from(shape_native=(5, 5), pixel_scales=(0.5, 0.5), centre=(0.0, 0.0)) |
69 | 37 | """ |
70 | | - if use_jax: |
71 | | - return (~mask_2d.astype(bool)).sum() |
72 | | - |
73 | | - else: |
74 | | - total_regular_pixels = 0 |
75 | | - |
76 | | - for y in range(mask_2d.shape[0]): |
77 | | - for x in range(mask_2d.shape[1]): |
78 | | - if not mask_2d[y, x]: |
79 | | - total_regular_pixels += 1 |
80 | | - |
81 | | - return total_regular_pixels |
| 38 | + return ( |
| 39 | + 0.5 * (shape_native[0] - 1) - (centre[0] / pixel_scales[0]), |
| 40 | + 0.5 * (shape_native[1] - 1) + (centre[1] / pixel_scales[1]), |
| 41 | + ) |
82 | 42 |
|
83 | 43 |
|
84 | | -@numba_util.jit() |
85 | 44 | def mask_2d_circular_from( |
86 | | - shape_native: Tuple[int, int], |
87 | | - pixel_scales: ty.PixelScales, |
| 45 | + shape_native: tuple[int, int], |
| 46 | + pixel_scales: tuple[float, float], |
88 | 47 | radius: float, |
89 | | - centre: Tuple[float, float] = (0.0, 0.0), |
| 48 | + centre: tuple[float, float] = (0.0, 0.0), |
90 | 49 | ) -> np.ndarray: |
91 | 50 | """ |
92 | | - Returns a circular mask from the 2D mask array shape and radius of the circle. |
| 51 | + Create a circular mask within a 2D array. |
93 | 52 |
|
94 | | - This creates a 2D array where all values within the mask radius are unmasked and therefore `False`. |
| 53 | + This generates a 2D array where all values within the specified radius are unmasked (set to `False`). |
95 | 54 |
|
96 | 55 | Parameters |
97 | 56 | ---------- |
98 | | - shape_native: Tuple[int, int] |
99 | | - The (y,x) shape of the mask in units of pixels. |
| 57 | + shape_native |
| 58 | + The shape of the mask array in pixels. |
100 | 59 | pixel_scales |
101 | | - The scaled units to pixel units conversion factor of each pixel. |
| 60 | + The conversion factors from pixels to scaled units. |
102 | 61 | radius |
103 | | - The radius (in scaled units) of the circle within which pixels unmasked. |
| 62 | + The radius of the circular mask in scaled units. |
104 | 63 | centre |
105 | | - The centre of the circle used to mask pixels. |
| 64 | + The central coordinate of the circle in scaled units. |
106 | 65 |
|
107 | 66 | Returns |
108 | 67 | ------- |
109 | | - ndarray |
110 | | - The 2D mask array whose central pixels are masked as a circle. |
| 68 | + The 2D mask array with the central region defined by the radius unmasked (False). |
111 | 69 |
|
112 | 70 | Examples |
113 | 71 | -------- |
114 | | - mask = mask_circular_from( |
115 | | - shape=(10, 10), pixel_scales=0.1, radius=0.5, centre=(0.0, 0.0)) |
| 72 | + mask = mask_2d_circular_from(shape_native=(10, 10), pixel_scales=(0.1, 0.1), radius=0.5, centre=(0.0, 0.0)) |
116 | 73 | """ |
117 | 74 | centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) |
118 | | - ys, xs = np.indices(shape_native) |
119 | | - return (radius * radius) < ( |
120 | | - np.square((ys - centres_scaled[0]) * pixel_scales[0]) + |
121 | | - np.square((xs - centres_scaled[1]) * pixel_scales[1]) |
122 | | - ) |
| 75 | + |
| 76 | + y, x = np.ogrid[: shape_native[0], : shape_native[1]] |
| 77 | + y_scaled = (y - centres_scaled[0]) * pixel_scales[0] |
| 78 | + x_scaled = (x - centres_scaled[1]) * pixel_scales[1] |
| 79 | + |
| 80 | + distances_squared = x_scaled**2 + y_scaled**2 |
| 81 | + |
| 82 | + return distances_squared >= radius**2 |
123 | 83 |
|
124 | 84 |
|
125 | 85 | @numba_util.jit() |
@@ -1047,7 +1007,7 @@ def native_index_for_slim_index_2d_from( |
1047 | 1007 | if use_jax: |
1048 | 1008 | return jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T |
1049 | 1009 | else: |
1050 | | - total_pixels = total_pixels_2d_from(mask_2d=mask_2d) |
| 1010 | + total_pixels = np.sum(~mask_2d) |
1051 | 1011 | native_index_for_slim_index_2d = np.zeros(shape=(total_pixels, 2)) |
1052 | 1012 | slim_index = 0 |
1053 | 1013 |
|
|
0 commit comments