Skip to content

Commit 107e470

Browse files
Jammy2211Jammy2211
authored andcommitted
moved over most zoom-y stuff
1 parent a625672 commit 107e470

14 files changed

Lines changed: 359 additions & 172 deletions

File tree

autoarray/inversion/inversion/abstract.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import copy
2-
2+
import jax.numpy as jnp
33
import numpy as np
44
from scipy.linalg import block_diag
55
from scipy.sparse import csc_matrix
@@ -317,7 +317,7 @@ def operated_mapping_matrix(self) -> np.ndarray:
317317
If there are multiple linear objects, the blurred mapping matrices are stacked such that their simultaneous
318318
linear equations are solved simultaneously.
319319
"""
320-
return np.hstack(self.operated_mapping_matrix_list)
320+
return jnp.hstack(self.operated_mapping_matrix_list)
321321

322322
@cached_property
323323
@profile_func
@@ -495,16 +495,17 @@ def reconstruction(self) -> np.ndarray:
495495
values_to_solve, :
496496
][:, values_to_solve]
497497

498-
solutions = np.zeros(np.shape(self.curvature_reg_matrix)[0])
499-
500-
solutions[values_to_solve] = (
498+
solutions = (
501499
inversion_util.reconstruction_positive_only_from(
502500
data_vector=data_vector_input,
503501
curvature_reg_matrix=curvature_reg_matrix_input,
504502
settings=self.settings,
505503
)
506504
)
507-
return solutions
505+
506+
mask = values_to_solve.astype(bool)
507+
508+
return solutions[mask]
508509
else:
509510
solutions = inversion_util.reconstruction_positive_only_from(
510511
data_vector=self.data_vector,

autoarray/inversion/inversion/imaging/inversion_imaging_util.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,6 @@ def data_vector_via_w_tilde_data_imaging_from(
391391
return data_vector
392392

393393

394-
@numba_util.jit()
395394
def data_vector_via_blurred_mapping_matrix_from(
396395
blurred_mapping_matrix: np.ndarray, image: np.ndarray, noise_map: np.ndarray
397396
) -> np.ndarray:
@@ -408,21 +407,7 @@ def data_vector_via_blurred_mapping_matrix_from(
408407
noise_map
409408
Flattened 1D array of the noise-map used by the inversion during the fit.
410409
"""
411-
412-
data_shape = blurred_mapping_matrix.shape
413-
414-
data_vector = np.zeros(data_shape[1])
415-
416-
for data_index in range(data_shape[0]):
417-
for pix_index in range(data_shape[1]):
418-
data_vector[pix_index] += (
419-
image[data_index]
420-
* blurred_mapping_matrix[data_index, pix_index]
421-
/ (noise_map[data_index] ** 2.0)
422-
)
423-
424-
return data_vector
425-
410+
return (image / noise_map**2.0) @ blurred_mapping_matrix
426411

427412
@numba_util.jit()
428413
def curvature_matrix_via_w_tilde_curvature_preload_imaging_from(

autoarray/inversion/inversion/inversion_util.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import jax.numpy as jnp
2+
import jaxnnls
13
import numpy as np
24

35
from typing import List, Optional, Tuple
@@ -39,7 +41,6 @@ def curvature_matrix_via_w_tilde_from(
3941
return np.dot(mapping_matrix.T, np.dot(w_tilde, mapping_matrix))
4042

4143

42-
@numba_util.jit()
4344
def curvature_matrix_with_added_to_diag_from(
4445
curvature_matrix: np.ndarray,
4546
value: float,
@@ -59,11 +60,7 @@ def curvature_matrix_with_added_to_diag_from(
5960
curvature_matrix
6061
The curvature matrix which is being constructed in order to solve a linear system of equations.
6162
"""
62-
63-
for i in no_regularization_index_list:
64-
curvature_matrix[i, i] += value
65-
66-
return curvature_matrix
63+
return curvature_matrix.at[no_regularization_index_list, no_regularization_index_list].add(value)
6764

6865

6966
@numba_util.jit()
@@ -106,7 +103,7 @@ def curvature_matrix_via_mapping_matrix_from(
106103
Flattened 1D array of the noise-map used by the inversion during the fit.
107104
"""
108105
array = mapping_matrix / noise_map[:, None]
109-
curvature_matrix = np.dot(array.T, array)
106+
curvature_matrix = jnp.dot(array.T, array)
110107

111108
if add_to_curvature_diag and len(no_regularization_index_list) > 0:
112109
curvature_matrix = curvature_matrix_with_added_to_diag_from(
@@ -150,7 +147,6 @@ def mapped_reconstructed_data_via_image_to_pix_unique_from(
150147
return mapped_reconstructed_data
151148

152149

153-
@numba_util.jit()
154150
def mapped_reconstructed_data_via_mapping_matrix_from(
155151
mapping_matrix: np.ndarray, reconstruction: np.ndarray
156152
) -> np.ndarray:
@@ -163,12 +159,7 @@ def mapped_reconstructed_data_via_mapping_matrix_from(
163159
The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels.
164160
165161
"""
166-
mapped_reconstructed_data = np.zeros(mapping_matrix.shape[0])
167-
for i in range(mapping_matrix.shape[0]):
168-
for j in range(reconstruction.shape[0]):
169-
mapped_reconstructed_data[i] += reconstruction[j] * mapping_matrix[i, j]
170-
171-
return mapped_reconstructed_data
162+
return jnp.dot(mapping_matrix, reconstruction)
172163

173164

174165
def reconstruction_positive_negative_from(
@@ -279,26 +270,31 @@ def reconstruction_positive_only_from(
279270
Non-negative S that minimizes the Eq.(2) of https://arxiv.org/pdf/astro-ph/0302587.pdf.
280271
"""
281272

282-
if len(data_vector):
283-
try:
284-
if settings.positive_only_uses_p_initial:
285-
P_initial = np.linalg.solve(curvature_reg_matrix, data_vector) > 0
286-
else:
287-
P_initial = np.zeros(0, dtype=int)
288-
289-
reconstruction = fnnls_cholesky(
290-
curvature_reg_matrix,
291-
(data_vector).T,
292-
P_initial=P_initial,
293-
)
294-
295-
except (RuntimeError, np.linalg.LinAlgError, ValueError) as e:
296-
raise exc.InversionException() from e
297-
298-
else:
299-
raise exc.InversionException()
273+
try:
274+
return jaxnnls.solve_nnls_primal(curvature_reg_matrix, data_vector)
275+
except (RuntimeError, np.linalg.LinAlgError, ValueError) as e:
276+
raise exc.InversionException() from e
300277

301-
return reconstruction
278+
# if len(data_vector):
279+
# try:
280+
# if settings.positive_only_uses_p_initial:
281+
# P_initial = np.linalg.solve(curvature_reg_matrix, data_vector) > 0
282+
# else:
283+
# P_initial = np.zeros(0, dtype=int)
284+
#
285+
# reconstruction = fnnls_cholesky(
286+
# curvature_reg_matrix,
287+
# (data_vector).T,
288+
# P_initial=P_initial,
289+
# )
290+
#
291+
# except (RuntimeError, np.linalg.LinAlgError, ValueError) as e:
292+
# raise exc.InversionException() from e
293+
#
294+
# else:
295+
# raise exc.InversionException()
296+
#
297+
# return reconstruction
302298

303299

304300
def preconditioner_matrix_via_mapping_matrix_from(

autoarray/mask/derive/indexes_2d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import logging
33
import numpy as np
44

5+
from autoconf import cached_property
6+
57
from autoarray.numpy_wrapper import register_pytree_node_class
68
from typing import TYPE_CHECKING
79

@@ -363,7 +365,7 @@ def border_native(self) -> np.ndarray:
363365
"""
364366
return self.native_for_slim[self.border_slim].astype("int")
365367

366-
@property
368+
@cached_property
367369
def native_for_slim(self) -> np.ndarray:
368370
"""
369371
Derives a 1D ``ndarray`` which maps every 1D ``slim`` index of the ``Mask2D`` to its

autoarray/mask/derive/zoom_2d.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from __future__ import annotations
2+
import numpy as np
3+
from typing import List, Tuple, Union
4+
5+
from autoarray.structures.grids import grid_2d_util
6+
7+
8+
class Zoom2D:
9+
10+
def __init__(self, mask: Union[np.ndarray, List]):
11+
"""
12+
Derives a zoomed in `Mask2D` object from a `Mask2D` object, which is typically used to visualize 2D arrays
13+
zoomed in to only the unmasked region an analysis is performed on.
14+
15+
A `Mask2D` masks values which are associated with a uniform 2D rectangular grid of pixels, where unmasked
16+
entries (which are `False`) are used in subsequent calculations and masked values (which are `True`) are
17+
omitted (for a full description see the :meth:`Mask2D` class API
18+
documentation <autoarray.mask.mask_2d.Mask2D.__new__>`).
19+
20+
The `Zoom2D` object calculations many different zoomed in qu
21+
22+
Parameters
23+
----------
24+
mask
25+
The `Mask2D` from which zoomed in `Mask2D` objects are derived.
26+
27+
Examples
28+
--------
29+
30+
.. code-block:: python
31+
32+
import autoarray as aa
33+
34+
mask_2d = aa.Mask2D(
35+
mask=[
36+
[True, True, True, True, True],
37+
[True, False, False, False, True],
38+
[True, False, False, False, True],
39+
[True, False, False, False, True],
40+
[True, True, True, True, True],
41+
],
42+
pixel_scales=1.0,
43+
)
44+
45+
zoom_2d = aa.Zoom2D(mask=mask_2d)
46+
47+
print(zoom_2d.centre)
48+
"""
49+
self.mask = mask
50+
51+
@property
52+
def centre(self) -> Tuple[float, float]:
53+
from autoarray.structures.grids.uniform_2d import Grid2D
54+
55+
grid = grid_2d_util.grid_2d_slim_via_mask_from(
56+
mask_2d=np.array(self.mask),
57+
pixel_scales=self.mask.pixel_scales,
58+
origin=self.mask.origin,
59+
)
60+
61+
grid = Grid2D(values=grid, mask=self.mask)
62+
63+
extraction_grid_1d = self.mask.geometry.grid_pixels_2d_from(grid_scaled_2d=grid)
64+
y_pixels_max = np.max(extraction_grid_1d[:, 0])
65+
y_pixels_min = np.min(extraction_grid_1d[:, 0])
66+
x_pixels_max = np.max(extraction_grid_1d[:, 1])
67+
x_pixels_min = np.min(extraction_grid_1d[:, 1])
68+
69+
return (
70+
((y_pixels_max + y_pixels_min - 1.0) / 2.0),
71+
((x_pixels_max + x_pixels_min - 1.0) / 2.0),
72+
)
73+
74+
@property
75+
def offset_pixels(self) -> Tuple[float, float]:
76+
if self.mask.pixel_scales is None:
77+
return self.mask.geometry.central_pixel_coordinates
78+
79+
return (
80+
self.centre[0] - self.mask.geometry.central_pixel_coordinates[0],
81+
self.centre[1] - self.mask.geometry.central_pixel_coordinates[1],
82+
)
83+
84+
@property
85+
def offset_scaled(self) -> Tuple[float, float]:
86+
return (
87+
-self.mask.pixel_scales[0] * self.offset_pixels[0],
88+
self.mask.pixel_scales[1] * self.offset_pixels[1],
89+
)
90+
91+
@property
92+
def region(self) -> List[int]:
93+
"""
94+
The zoomed rectangular region corresponding to the square encompassing all unmasked values. This zoomed
95+
extraction region is a squuare, even if the mask is rectangular.
96+
97+
This is used to zoom in on the region of an image that is used in an analysis for visualization.
98+
"""
99+
100+
where = np.array(np.where(np.invert(self.mask.astype("bool"))))
101+
y0, x0 = np.amin(where, axis=1)
102+
y1, x1 = np.amax(where, axis=1)
103+
104+
# Have to convert mask to bool for invert function to work.
105+
106+
ylength = y1 - y0
107+
xlength = x1 - x0
108+
109+
if ylength > xlength:
110+
length_difference = ylength - xlength
111+
x1 += int(length_difference / 2.0)
112+
x0 -= int(length_difference / 2.0)
113+
elif xlength > ylength:
114+
length_difference = xlength - ylength
115+
y1 += int(length_difference / 2.0)
116+
y0 -= int(length_difference / 2.0)
117+
118+
return [y0, y1 + 1, x0, x1 + 1]
119+
120+
@property
121+
def shape_native(self) -> Tuple[int, int]:
122+
region = self.region
123+
return (region[1] - region[0], region[3] - region[2])
124+
125+
@property
126+
def mask_unmasked(self) -> "Mask2D":
127+
"""
128+
The scaled-grid of (y,x) coordinates of every pixel.
129+
130+
This is defined from the top-left corner, such that the first pixel at location [0, 0] will have a negative x
131+
value y value in scaled units.
132+
"""
133+
134+
from autoarray.mask.mask_2d import Mask2D
135+
136+
return Mask2D.all_false(
137+
shape_native=self.shape_native,
138+
pixel_scales=self.mask.pixel_scales,
139+
origin=self.offset_scaled,
140+
)

autoarray/mask/mask_2d.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from pathlib import Path
66
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
77

8+
9+
from autoconf import cached_property
10+
811
from autoarray.structures.abstract_structure import Structure
912

1013
if TYPE_CHECKING:
@@ -214,6 +217,8 @@ def __init__(
214217
pixel_scales=pixel_scales,
215218
)
216219

220+
self.derive_indexes.native_for_slim
221+
217222
__no_flatten__ = ("derive_indexes",)
218223

219224
def __array_finalize__(self, obj):
@@ -238,7 +243,7 @@ def geometry(self) -> Geometry2D:
238243
origin=self.origin,
239244
)
240245

241-
@property
246+
@cached_property
242247
def derive_indexes(self) -> DeriveIndexes2D:
243248
return DeriveIndexes2D(mask=self)
244249

autoarray/operators/contour.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ def contour_list(self):
6363

6464
contour_indices_list = measure.find_contours(contour_array, 0)
6565

66-
print(contour_indices_list)
67-
6866
if len(contour_indices_list) == 0:
6967
return []
7068

0 commit comments

Comments
 (0)