Skip to content

Commit 7f869aa

Browse files
committed
mask circular converts and some aspect simplified
1 parent 64d16b6 commit 7f869aa

13 files changed

Lines changed: 89 additions & 128 deletions

File tree

autoarray/geometry/geometry_util.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float]
182182

183183
@numba_util.jit()
184184
def central_pixel_coordinates_2d_from(
185-
shape_native: Tuple[int, int]
185+
shape_native: Tuple[int, int],
186186
) -> Tuple[float, float]:
187187
"""
188188
Returns the central pixel coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``)
@@ -477,7 +477,6 @@ def grid_pixels_2d_slim_from(
477477
pixel_scales=(0.5, 0.5), origin=(0.0, 0.0))
478478
"""
479479

480-
481480
centres_scaled = central_scaled_coordinate_2d_from(
482481
shape_native=shape_native, pixel_scales=pixel_scales, origin=origin
483482
)
@@ -544,7 +543,6 @@ def grid_pixel_centres_2d_slim_from(
544543
pixel_scales=(0.5, 0.5), origin=(0.0, 0.0))
545544
"""
546545

547-
548546
centres_scaled = central_scaled_coordinate_2d_from(
549547
shape_native=shape_native, pixel_scales=pixel_scales, origin=origin
550548
)
@@ -629,8 +627,10 @@ def grid_pixel_indexes_2d_slim_from(
629627

630628
if use_jax:
631629
grid_pixel_indexes_2d_slim = (
632-
grid_pixels_2d_slim * np.array([shape_native[1], 1])
633-
).sum(axis=1).astype(int)
630+
(grid_pixels_2d_slim * np.array([shape_native[1], 1]))
631+
.sum(axis=1)
632+
.astype(int)
633+
)
634634
else:
635635
grid_pixel_indexes_2d_slim = np.zeros(grid_pixels_2d_slim.shape[0])
636636

@@ -690,7 +690,9 @@ def grid_scaled_2d_slim_from(
690690
centres_scaled = np.array(centres_scaled)
691691
pixel_scales = np.array(pixel_scales)
692692
sign = np.array([-1, 1])
693-
grid_scaled_2d_slim = (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign
693+
grid_scaled_2d_slim = (
694+
(grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign
695+
)
694696
else:
695697
grid_scaled_2d_slim = np.zeros((grid_pixels_2d_slim.shape[0], 2))
696698

@@ -755,7 +757,7 @@ def grid_pixel_centres_2d_from(
755757
centres_scaled = np.array(centres_scaled)
756758
pixel_scales = np.array(pixel_scales)
757759
sign = np.array([-1.0, 1.0])
758-
grid_pixels_2d = (
760+
grid_pixels_2d = (
759761
(sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5
760762
).astype(int)
761763
else:
@@ -764,17 +766,21 @@ def grid_pixel_centres_2d_from(
764766
for y in range(grid_scaled_2d.shape[0]):
765767
for x in range(grid_scaled_2d.shape[1]):
766768
grid_pixels_2d[y, x, 0] = int(
767-
(-grid_scaled_2d[y, x, 0] / pixel_scales[0]) + centres_scaled[0] + 0.5
769+
(-grid_scaled_2d[y, x, 0] / pixel_scales[0])
770+
+ centres_scaled[0]
771+
+ 0.5
768772
)
769773
grid_pixels_2d[y, x, 1] = int(
770-
(grid_scaled_2d[y, x, 1] / pixel_scales[1]) + centres_scaled[1] + 0.5
774+
(grid_scaled_2d[y, x, 1] / pixel_scales[1])
775+
+ centres_scaled[1]
776+
+ 0.5
771777
)
772778

773779
return grid_pixels_2d
774780

775781

776782
def extent_symmetric_from(
777-
extent: Tuple[float, float, float, float]
783+
extent: Tuple[float, float, float, float],
778784
) -> Tuple[float, float, float, float]:
779785
"""
780786
Given an input extent of the form (x_min, x_max, y_min, y_max), this function returns an extent which is

autoarray/inversion/pixelization/border_relocator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def sub_slim_indexes_for_slim_index_via_mask_2d_from(
4848
sub_mask_1d_indexes_for_mask_1d_index = sub_mask_1d_indexes_for_mask_1d_index_from(mask=mask, sub_size=2)
4949
"""
5050

51-
total_pixels = mask_2d_util.total_pixels_2d_from(mask_2d=mask_2d)
51+
total_pixels = np.sum(~mask_2d)
5252

5353
sub_slim_indexes_for_slim_index = [[] for _ in range(total_pixels)]
5454

autoarray/inversion/pixelization/mesh/mesh_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
@numba_util.jit()
99
def rectangular_neighbors_from(
10-
shape_native: Tuple[int, int]
10+
shape_native: Tuple[int, int],
1111
) -> Tuple[np.ndarray, np.ndarray]:
1212
"""
1313
Returns the 4 (or less) adjacent neighbors of every pixel on a rectangular pixelization as an ndarray of shape

autoarray/mask/abstract_mask.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55

66
from autoarray.numpy_wrapper import np, use_jax
7+
78
if use_jax:
89
import jax
910
from pathlib import Path

autoarray/mask/mask_2d_util.py

Lines changed: 37 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -8,118 +8,78 @@
88
from autoarray.numpy_wrapper import use_jax, np as jnp
99

1010

11-
@numba_util.jit()
1211
def mask_2d_centres_from(
13-
shape_native: Tuple[int, int],
14-
pixel_scales: ty.PixelScales,
15-
centre: Tuple[float, float],
16-
) -> Tuple[float, float]:
12+
shape_native: tuple[int, int],
13+
pixel_scales: tuple[float, float],
14+
centre: tuple[float, float],
15+
) -> tuple[float, float]:
1716
"""
18-
Returns the (y,x) scaled central coordinates of a mask from its shape, pixel-scales and centre.
17+
Compute the (y, x) scaled central coordinates of a mask given its shape, pixel-scales, and centre.
1918
20-
The coordinate system is defined such that the positive y axis is up and positive x axis is right.
19+
The coordinate system is defined such that the positive y-axis is up and the positive x-axis is right.
2120
2221
Parameters
2322
----------
2423
shape_native
25-
The (y,x) shape of the 2D array the scaled centre is computed for.
24+
The shape of the 2D array in pixels.
2625
pixel_scales
27-
The (y,x) scaled units to pixel units conversion factor of the 2D array.
28-
centre : (float, flloat)
29-
The (y,x) centre of the 2D mask.
30-
31-
Returns
32-
-------
33-
tuple (float, float)
34-
The (y,x) scaled central coordinates of the input array.
35-
36-
Examples
37-
--------
38-
centres_scaled = centres_from(shape=(5,5), pixel_scales=(0.5, 0.5), centre=(0.0, 0.0))
39-
"""
40-
y_centre_scaled = (float(shape_native[0] - 1) / 2) - (centre[0] / pixel_scales[0])
41-
x_centre_scaled = (float(shape_native[1] - 1) / 2) + (centre[1] / pixel_scales[1])
42-
43-
return (y_centre_scaled, x_centre_scaled)
44-
45-
46-
@numba_util.jit()
47-
def total_pixels_2d_from(mask_2d: np.ndarray) -> int:
48-
"""
49-
Returns the total number of unmasked pixels in a mask.
50-
51-
Parameters
52-
----------
53-
mask_2d
54-
A 2D array of bools, where `False` values are unmasked and included when counting pixels.
26+
The conversion factors from pixels to scaled units.
27+
centre
28+
The central coordinate of the mask in scaled units.
5529
5630
Returns
5731
-------
58-
int
59-
The total number of pixels that are unmasked.
32+
The (y, x) scaled central coordinates of the input array.
6033
6134
Examples
6235
--------
63-
64-
mask = np.array([[True, False, True],
65-
[False, False, False]
66-
[True, False, True]])
67-
68-
total_regular_pixels = total_regular_pixels_from(mask=mask)
36+
centres_scaled = mask_2d_centres_from(shape_native=(5, 5), pixel_scales=(0.5, 0.5), centre=(0.0, 0.0))
6937
"""
70-
if use_jax:
71-
return (~mask_2d.astype(bool)).sum()
72-
73-
else:
74-
total_regular_pixels = 0
75-
76-
for y in range(mask_2d.shape[0]):
77-
for x in range(mask_2d.shape[1]):
78-
if not mask_2d[y, x]:
79-
total_regular_pixels += 1
80-
81-
return total_regular_pixels
38+
return (
39+
0.5 * (shape_native[0] - 1) - (centre[0] / pixel_scales[0]),
40+
0.5 * (shape_native[1] - 1) + (centre[1] / pixel_scales[1]),
41+
)
8242

8343

84-
@numba_util.jit()
8544
def mask_2d_circular_from(
86-
shape_native: Tuple[int, int],
87-
pixel_scales: ty.PixelScales,
45+
shape_native: tuple[int, int],
46+
pixel_scales: tuple[float, float],
8847
radius: float,
89-
centre: Tuple[float, float] = (0.0, 0.0),
48+
centre: tuple[float, float] = (0.0, 0.0),
9049
) -> np.ndarray:
9150
"""
92-
Returns a circular mask from the 2D mask array shape and radius of the circle.
51+
Create a circular mask within a 2D array.
9352
94-
This creates a 2D array where all values within the mask radius are unmasked and therefore `False`.
53+
This generates a 2D array where all values within the specified radius are unmasked (set to `False`).
9554
9655
Parameters
9756
----------
98-
shape_native: Tuple[int, int]
99-
The (y,x) shape of the mask in units of pixels.
57+
shape_native
58+
The shape of the mask array in pixels.
10059
pixel_scales
101-
The scaled units to pixel units conversion factor of each pixel.
60+
The conversion factors from pixels to scaled units.
10261
radius
103-
The radius (in scaled units) of the circle within which pixels unmasked.
62+
The radius of the circular mask in scaled units.
10463
centre
105-
The centre of the circle used to mask pixels.
64+
The central coordinate of the circle in scaled units.
10665
10766
Returns
10867
-------
109-
ndarray
110-
The 2D mask array whose central pixels are masked as a circle.
68+
The 2D mask array with the central region defined by the radius unmasked (False).
11169
11270
Examples
11371
--------
114-
mask = mask_circular_from(
115-
shape=(10, 10), pixel_scales=0.1, radius=0.5, centre=(0.0, 0.0))
72+
mask = mask_2d_circular_from(shape_native=(10, 10), pixel_scales=(0.1, 0.1), radius=0.5, centre=(0.0, 0.0))
11673
"""
11774
centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre)
118-
ys, xs = np.indices(shape_native)
119-
return (radius * radius) < (
120-
np.square((ys - centres_scaled[0]) * pixel_scales[0]) +
121-
np.square((xs - centres_scaled[1]) * pixel_scales[1])
122-
)
75+
76+
y, x = np.ogrid[: shape_native[0], : shape_native[1]]
77+
y_scaled = (y - centres_scaled[0]) * pixel_scales[0]
78+
x_scaled = (x - centres_scaled[1]) * pixel_scales[1]
79+
80+
distances_squared = x_scaled**2 + y_scaled**2
81+
82+
return distances_squared >= radius**2
12383

12484

12585
@numba_util.jit()
@@ -1047,7 +1007,7 @@ def native_index_for_slim_index_2d_from(
10471007
if use_jax:
10481008
return jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T
10491009
else:
1050-
total_pixels = total_pixels_2d_from(mask_2d=mask_2d)
1010+
total_pixels = np.sum(~mask_2d)
10511011
native_index_for_slim_index_2d = np.zeros(shape=(total_pixels, 2))
10521012
slim_index = 0
10531013

autoarray/operators/contour.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def contour_array(self):
5858
@property
5959
def contour_list(self):
6060
# make sure to use base numpy to convert JAX array back to a normal array
61-
contour_indices_list = measure.find_contours(numpy.array(self.contour_array.array), 0)
61+
contour_indices_list = measure.find_contours(
62+
numpy.array(self.contour_array.array), 0
63+
)
6264

6365
if len(contour_indices_list) == 0:
6466
return []

autoarray/operators/over_sampling/over_sample_util.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -528,9 +528,7 @@ def binned_array_2d_from(
528528
grid_slim = grid_2d_slim_over_sampled_via_mask_from(mask=mask, pixel_scales=(0.5, 0.5), sub_size=1, origin=(0.0, 0.0))
529529
"""
530530

531-
total_pixels = mask_2d_util.total_pixels_2d_from(
532-
mask_2d=mask_2d,
533-
)
531+
total_pixels = np.sum(~mask_2d)
534532

535533
sub_fraction = 1.0 / sub_size**2
536534

autoarray/operators/over_sampling/over_sampler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from autofit.jax_wrapper import register_pytree_node_class
1313

14+
1415
@register_pytree_node_class
1516
class OverSampler:
1617
def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]):

autoarray/plot/multi_plotters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def output_to_fits(
315315
output_path = self.plotter_list[0].mat_plot_2d.output.output_path_from(
316316
format="fits_multi"
317317
)
318-
output_fits_file = Path(output_path)/ f"{filename}.fits"
318+
output_fits_file = Path(output_path) / f"{filename}.fits"
319319

320320
if remove_fits_first:
321321
output_fits_file.unlink(missing_ok=True)

0 commit comments

Comments
 (0)