Skip to content

Commit a4ce9db

Browse files
Jammy2211Jammy2211
authored andcommitted
update extracted_array_2d_from to not use numba
1 parent ec8423a commit a4ce9db

11 files changed

Lines changed: 320 additions & 69 deletions

File tree

autoarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from .layout.layout import Layout2D
6464
from .structures.arrays.uniform_1d import Array1D
6565
from .structures.arrays.uniform_2d import Array2D
66+
from .structures.arrays.rgb import Array2DRGB
6667
from .structures.arrays.irregular import ArrayIrregular
6768
from .structures.grids.uniform_1d import Grid1D
6869
from .structures.grids.uniform_2d import Grid2D

autoarray/abstract_ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def output_to_fits(self, file_path: str, overwrite: bool = False):
282282
If a file already exists at the path, if overwrite=True it is overwritten else an error is raised.
283283
"""
284284
output_to_fits(
285-
values=self.native.array,
285+
values=self.native.array.astype("float"),
286286
file_path=file_path,
287287
overwrite=overwrite,
288288
header_dict=self.mask.header_dict,

autoarray/fixtures.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def make_array_1d_7():
6767
def make_array_2d_7x7():
6868
return aa.Array2D.ones(shape_native=(7, 7), pixel_scales=(1.0, 1.0))
6969

70+
def make_array_2d_rgb_7x7():
71+
return aa.Array2DRGB(values=np.ones((7, 7, 3)), mask=make_mask_2d_7x7())
72+
7073

7174
def make_layout_2d_7x7():
7275
return aa.Layout2D(

autoarray/inversion/inversion/abstract.py

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def operated_mapping_matrix(self) -> np.ndarray:
317317
If there are multiple linear objects, the blurred mapping matrices are stacked such that their simultaneous
318318
linear equations are solved simultaneously.
319319
"""
320-
return jnp.hstack(self.operated_mapping_matrix_list)
320+
return np.hstack(self.operated_mapping_matrix_list)
321321

322322
@cached_property
323323
@profile_func
@@ -446,6 +446,7 @@ def mapper_zero_pixel_list(self) -> np.ndarray:
446446
)
447447
return mapper_zero_pixel_list
448448

449+
449450
@cached_property
450451
@profile_func
451452
def reconstruction(self) -> np.ndarray:
@@ -495,15 +496,16 @@ def reconstruction(self) -> np.ndarray:
495496
values_to_solve, :
496497
][:, values_to_solve]
497498

498-
solutions = inversion_util.reconstruction_positive_only_from(
499-
data_vector=data_vector_input,
500-
curvature_reg_matrix=curvature_reg_matrix_input,
501-
settings=self.settings,
502-
)
503-
504-
mask = values_to_solve.astype(bool)
499+
solutions = np.zeros(np.shape(self.curvature_reg_matrix)[0])
505500

506-
return solutions[mask]
501+
solutions[values_to_solve] = (
502+
inversion_util.reconstruction_positive_only_from(
503+
data_vector=data_vector_input,
504+
curvature_reg_matrix=curvature_reg_matrix_input,
505+
settings=self.settings,
506+
)
507+
)
508+
return solutions
507509
else:
508510
solutions = inversion_util.reconstruction_positive_only_from(
509511
data_vector=self.data_vector,
@@ -521,6 +523,81 @@ def reconstruction(self) -> np.ndarray:
521523
mapper_param_range_list=mapper_param_range_list,
522524
)
523525

526+
# @cached_property
527+
# @profile_func
528+
# def reconstruction(self) -> np.ndarray:
529+
# """
530+
# Solve the linear system [F + reg_coeff*H] S = D -> S = [F + reg_coeff*H]^-1 D given by equation (12)
531+
# of https://arxiv.org/pdf/astro-ph/0302587.pdf (Positive-Negative solution)
532+
#
533+
# ============================================================================================
534+
#
535+
# Solve the Eq.(2) of https://arxiv.org/pdf/astro-ph/0302587.pdf (Non-negative solution)
536+
# Find non-negative solution that minimizes |Z * S - x|^2.
537+
#
538+
# We use fnnls (https://github.com/jvendrow/fnnls) to optimize the quadratic value. Two commonly used
539+
# variables in the code are defined as follows:
540+
# ZTZ := np.dot(Z.T, Z)
541+
# ZTx := np.dot(Z.T, x)
542+
# """
543+
# if self.settings.use_positive_only_solver:
544+
# """
545+
# For the new implementation, we now need to take out the cols and rows of
546+
# the curvature_reg_matrix that corresponds to the parameters we force to be 0.
547+
# Similar for the data vector.
548+
#
549+
# What we actually doing is that we have set the correspoding cols of the Z to be 0.
550+
# As the curvature_reg_matrix = ZTZ, so the cols and rows are all taken out.
551+
# And the data_vector = ZTx, so the corresponding row is also taken out.
552+
# """
553+
#
554+
# if self.settings.force_edge_pixels_to_zeros:
555+
# if self.settings.force_edge_image_pixels_to_zeros:
556+
# ids_zeros = np.unique(
557+
# np.append(
558+
# self.mapper_edge_pixel_list, self.mapper_zero_pixel_list
559+
# )
560+
# )
561+
# else:
562+
# ids_zeros = self.mapper_edge_pixel_list
563+
#
564+
# values_to_solve = np.ones(
565+
# np.shape(self.curvature_reg_matrix)[0], dtype=bool
566+
# )
567+
# values_to_solve[ids_zeros] = False
568+
#
569+
# data_vector_input = self.data_vector[values_to_solve]
570+
#
571+
# curvature_reg_matrix_input = self.curvature_reg_matrix[
572+
# values_to_solve, :
573+
# ][:, values_to_solve]
574+
#
575+
# solutions = inversion_util.reconstruction_positive_only_from(
576+
# data_vector=data_vector_input,
577+
# curvature_reg_matrix=curvature_reg_matrix_input,
578+
# settings=self.settings,
579+
# )
580+
#
581+
# mask = values_to_solve.astype(bool)
582+
#
583+
# return solutions[mask]
584+
# else:
585+
# solutions = inversion_util.reconstruction_positive_only_from(
586+
# data_vector=self.data_vector,
587+
# curvature_reg_matrix=self.curvature_reg_matrix,
588+
# settings=self.settings,
589+
# )
590+
#
591+
# return solutions
592+
#
593+
# mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)
594+
#
595+
# return inversion_util.reconstruction_positive_negative_from(
596+
# data_vector=self.data_vector,
597+
# curvature_reg_matrix=self.curvature_reg_matrix,
598+
# mapper_param_range_list=mapper_param_range_list,
599+
# )
600+
524601
@cached_property
525602
@profile_func
526603
def reconstruction_reduced(self) -> np.ndarray:

autoarray/inversion/inversion/inversion_util.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def curvature_matrix_via_w_tilde_from(
4141
return np.dot(mapping_matrix.T, np.dot(w_tilde, mapping_matrix))
4242

4343

44+
@numba_util.jit()
4445
def curvature_matrix_with_added_to_diag_from(
4546
curvature_matrix: np.ndarray,
4647
value: float,
@@ -60,9 +61,35 @@ def curvature_matrix_with_added_to_diag_from(
6061
curvature_matrix
6162
The curvature matrix which is being constructed in order to solve a linear system of equations.
6263
"""
63-
return curvature_matrix.at[
64-
no_regularization_index_list, no_regularization_index_list
65-
].add(value)
64+
65+
for i in no_regularization_index_list:
66+
curvature_matrix[i, i] += value
67+
68+
return curvature_matrix
69+
70+
71+
# def curvature_matrix_with_added_to_diag_from(
72+
# curvature_matrix: np.ndarray,
73+
# value: float,
74+
# no_regularization_index_list: Optional[List] = None,
75+
# ) -> np.ndarray:
76+
# """
77+
# It is common for the `curvature_matrix` computed to not be positive-definite, leading for the inversion
78+
# via `np.linalg.solve` to fail and raise a `LinAlgError`.
79+
#
80+
# In many circumstances, adding a small numerical value of `1.0e-8` to the diagonal of the `curvature_matrix`
81+
# makes it positive definite, such that the inversion is performed without raising an error.
82+
#
83+
# This function adds this numerical value to the diagonal of the curvature matrix.
84+
#
85+
# Parameters
86+
# ----------
87+
# curvature_matrix
88+
# The curvature matrix which is being constructed in order to solve a linear system of equations.
89+
# """
90+
# return curvature_matrix.at[
91+
# no_regularization_index_list, no_regularization_index_list
92+
# ].add(value)
6693

6794

6895
@numba_util.jit()
@@ -105,7 +132,7 @@ def curvature_matrix_via_mapping_matrix_from(
105132
Flattened 1D array of the noise-map used by the inversion during the fit.
106133
"""
107134
array = mapping_matrix / noise_map[:, None]
108-
curvature_matrix = jnp.dot(array.T, array)
135+
curvature_matrix = np.dot(array.T, array)
109136

110137
if add_to_curvature_diag and len(no_regularization_index_list) > 0:
111138
curvature_matrix = curvature_matrix_with_added_to_diag_from(
@@ -161,7 +188,7 @@ def mapped_reconstructed_data_via_mapping_matrix_from(
161188
The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels.
162189
163190
"""
164-
return jnp.dot(mapping_matrix, reconstruction)
191+
return np.dot(mapping_matrix, reconstruction)
165192

166193

167194
def reconstruction_positive_negative_from(
@@ -272,31 +299,31 @@ def reconstruction_positive_only_from(
272299
Non-negative S that minimizes the Eq.(2) of https://arxiv.org/pdf/astro-ph/0302587.pdf.
273300
"""
274301

275-
try:
276-
return jaxnnls.solve_nnls_primal(curvature_reg_matrix, data_vector)
277-
except (RuntimeError, np.linalg.LinAlgError, ValueError) as e:
278-
raise exc.InversionException() from e
302+
# try:
303+
# return jaxnnls.solve_nnls_primal(curvature_reg_matrix, data_vector)
304+
# except (RuntimeError, np.linalg.LinAlgError, ValueError) as e:
305+
# raise exc.InversionException() from e
306+
307+
if len(data_vector):
308+
try:
309+
if settings.positive_only_uses_p_initial:
310+
P_initial = np.linalg.solve(curvature_reg_matrix, data_vector) > 0
311+
else:
312+
P_initial = np.zeros(0, dtype=int)
313+
314+
reconstruction = fnnls_cholesky(
315+
curvature_reg_matrix,
316+
(data_vector).T,
317+
P_initial=P_initial,
318+
)
279319

280-
# if len(data_vector):
281-
# try:
282-
# if settings.positive_only_uses_p_initial:
283-
# P_initial = np.linalg.solve(curvature_reg_matrix, data_vector) > 0
284-
# else:
285-
# P_initial = np.zeros(0, dtype=int)
286-
#
287-
# reconstruction = fnnls_cholesky(
288-
# curvature_reg_matrix,
289-
# (data_vector).T,
290-
# P_initial=P_initial,
291-
# )
292-
#
293-
# except (RuntimeError, np.linalg.LinAlgError, ValueError) as e:
294-
# raise exc.InversionException() from e
295-
#
296-
# else:
297-
# raise exc.InversionException()
298-
#
299-
# return reconstruction
320+
except (RuntimeError, np.linalg.LinAlgError, ValueError) as e:
321+
raise exc.InversionException() from e
322+
323+
else:
324+
raise exc.InversionException()
325+
326+
return reconstruction
300327

301328

302329
def preconditioner_matrix_via_mapping_matrix_from(

autoarray/mask/derive/zoom_2d.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
if TYPE_CHECKING:
66
from autoarray.structures.arrays.uniform_2d import Array2D
7+
from autoarray.structures.arrays.rgb import Array2DRGB
78

89
from autoarray.structures.arrays import array_2d_util
910
from autoarray.structures.grids import grid_2d_util
@@ -229,7 +230,7 @@ def mask_2d_from(self, buffer: int = 1) -> "Mask2D":
229230
origin=self.mask.origin,
230231
)
231232

232-
def array_2d_from(self, array : Array2D, buffer: int = 1) -> Array2D:
233+
def array_2d_from(self, array: Array2D, buffer: int = 1) -> Array2D:
233234
"""
234235
Extract the 2D region of an array corresponding to the rectangle encompassing all unmasked values.
235236
@@ -241,24 +242,82 @@ def array_2d_from(self, array : Array2D, buffer: int = 1) -> Array2D:
241242
The number pixels around the extracted array used as a buffer.
242243
"""
243244
from autoarray.structures.arrays.uniform_2d import Array2D
245+
from autoarray.structures.arrays.rgb import Array2DRGB
244246
from autoarray.mask.mask_2d import Mask2D
245247

248+
if isinstance(array, Array2DRGB):
249+
return self.array_2d_rgb_from(array=array, buffer=buffer)
250+
246251
extracted_array_2d = array_2d_util.extracted_array_2d_from(
247-
array_2d=np.array(array.native),
252+
array_2d=array.native.array,
248253
y0=self.region[0] - buffer,
249254
y1=self.region[1] + buffer,
250255
x0=self.region[2] - buffer,
251256
x1=self.region[3] + buffer,
252257
)
253258

254-
mask = Mask2D.all_false(
255-
shape_native=extracted_array_2d.shape,
259+
extracted_mask_2d = array_2d_util.extracted_array_2d_from(
260+
array_2d=np.array(self.mask),
261+
y0=self.region[0] - buffer,
262+
y1=self.region[1] + buffer,
263+
x0=self.region[2] - buffer,
264+
x1=self.region[3] + buffer,
265+
)
266+
267+
mask = Mask2D(
268+
mask=extracted_mask_2d,
256269
pixel_scales=array.pixel_scales,
257270
origin=array.mask.mask_centre,
258271
)
259272

260-
arr = array_2d_util.convert_array_2d(
261-
array_2d=extracted_array_2d, mask_2d=mask
273+
arr = array_2d_util.convert_array_2d(array_2d=extracted_array_2d, mask_2d=mask)
274+
275+
return Array2D(values=arr, mask=mask, header=array.header)
276+
277+
def array_2d_rgb_from(self, array: Array2DRGB, buffer: int = 1) -> Array2DRGB:
278+
"""
279+
Extract the 2D region of an RGB array corresponding to the rectangle encompassing all unmasked values.
280+
281+
This works the same as the `array_2d_from` method, but for RGB arrays, meaning that it iterates over the three
282+
channels of the RGB array and extracts the region for each channel separately.
283+
284+
This is used to extract and visualize only the region of an RGB image that is used in an analysis.
285+
286+
Parameters
287+
----------
288+
buffer
289+
The number pixels around the extracted array used as a buffer.
290+
"""
291+
from autoarray.structures.arrays.rgb import Array2DRGB
292+
from autoarray.mask.mask_2d import Mask2D
293+
294+
for i in range(3):
295+
296+
extracted_array_2d = array_2d_util.extracted_array_2d_from(
297+
array_2d=np.array(array.native[:, :, i]),
298+
y0=self.region[0] - buffer,
299+
y1=self.region[1] + buffer,
300+
x0=self.region[2] - buffer,
301+
x1=self.region[3] + buffer,
302+
)
303+
304+
if i == 0:
305+
array_2d_rgb = np.zeros((extracted_array_2d.shape[0], extracted_array_2d.shape[1], 3))
306+
307+
array_2d_rgb[:, :, i] = extracted_array_2d
308+
309+
extracted_mask_2d = array_2d_util.extracted_array_2d_from(
310+
array_2d=np.array(self.mask),
311+
y0=self.region[0] - buffer,
312+
y1=self.region[1] + buffer,
313+
x0=self.region[2] - buffer,
314+
x1=self.region[3] + buffer,
315+
)
316+
317+
mask = Mask2D(
318+
mask=extracted_mask_2d,
319+
pixel_scales=array.pixel_scales,
320+
origin=array.mask.mask_centre,
262321
)
263322

264-
return Array2D(values=arr, mask=mask, header=array.header)
323+
return Array2DRGB(values=array_2d_rgb.astype("int"), mask=mask)

0 commit comments

Comments
 (0)