Skip to content

Commit ca6a934

Browse files
authored
Merge pull request #170 from Jammy2211/feature/jax_ndarray_casting_rules
Feature/jax ndarray casting rules
2 parents c9268b2 + 135d5b8 commit ca6a934

93 files changed

Lines changed: 1309 additions & 1084 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

autoarray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .fit.fit_imaging import FitImaging
2121
from .fit.fit_interferometer import FitInterferometer
2222
from .geometry.geometry_2d import Geometry2D
23+
from .inversion.convolver import Convolver
2324
from .inversion.pixelization.mappers.abstract import AbstractMapper
2425
from .inversion.pixelization import mesh
2526
from .inversion.pixelization import image_mesh
@@ -44,7 +45,6 @@
4445
from .inversion.inversion.imaging.w_tilde import InversionImagingWTilde
4546
from .inversion.inversion.interferometer.w_tilde import InversionInterferometerWTilde
4647
from .inversion.inversion.interferometer.mapping import InversionInterferometerMapping
47-
from .inversion.inversion.interferometer.lop import InversionInterferometerMappingPyLops
4848
from .inversion.linear_obj.linear_obj import LinearObj
4949
from .inversion.linear_obj.func_list import AbstractLinearObjFuncList
5050
from .mask.derive.indexes_2d import DeriveIndexes2D

autoarray/abstract_ndarray.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
from abc import ABC
66
from abc import abstractmethod
7+
import jax.numpy as jnp
78

89
from autoconf.fitsable import output_to_fits
910

10-
from autoarray.numpy_wrapper import np, register_pytree_node, Array
11+
from autoarray.numpy_wrapper import register_pytree_node, Array
1112

1213
from typing import TYPE_CHECKING
1314

@@ -82,7 +83,7 @@ def __init__(self, array):
8283

8384
def invert(self):
8485
new = self.copy()
85-
new._array = np.invert(new._array)
86+
new._array = jnp.invert(new._array)
8687
return new
8788

8889
@classmethod
@@ -104,7 +105,7 @@ def instance_flatten(cls, instance):
104105
@staticmethod
105106
def flip_hdu_for_ds9(values):
106107
if conf.instance["general"]["fits"]["flip_for_ds9"]:
107-
return np.flipud(values)
108+
return jnp.flipud(values)
108109
return values
109110

110111
@classmethod
@@ -117,7 +118,7 @@ def instance_unflatten(cls, aux_data, children):
117118
setattr(instance, key, value)
118119
return instance
119120

120-
def with_new_array(self, array: np.ndarray) -> "AbstractNDArray":
121+
def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
121122
"""
122123
Copy this object but give it a new array.
123124
@@ -164,7 +165,7 @@ def __iter__(self):
164165

165166
@to_new_array
166167
def sqrt(self):
167-
return np.sqrt(self._array)
168+
return jnp.sqrt(self._array)
168169

169170
@property
170171
def array(self):
@@ -330,13 +331,13 @@ def __getitem__(self, item):
330331
result = self._array[item]
331332
if isinstance(item, slice):
332333
result = self.with_new_array(result)
333-
if isinstance(result, np.ndarray):
334+
if isinstance(result, jnp.ndarray):
334335
result = self.with_new_array(result)
335336
return result
336337

337338
def __setitem__(self, key, value):
338-
if isinstance(key, (np.ndarray, AbstractNDArray, Array)):
339-
self._array = np.where(key, value, self._array)
339+
if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
340+
self._array = jnp.where(key, value, self._array)
340341
else:
341342
self._array[key] = value
342343

autoarray/config/grids.yaml

Lines changed: 0 additions & 3 deletions
This file was deleted.

autoarray/dataset/imaging/dataset.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,9 @@ def __init__(
166166

167167
self.psf = psf
168168

169-
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
170-
raise exc.KernelException("Kernel2D Kernel2D must be odd")
169+
if psf is not None:
170+
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
171+
raise exc.KernelException("Kernel2D Kernel2D must be odd")
171172

172173
@cached_property
173174
def grids(self):
@@ -178,6 +179,27 @@ def grids(self):
178179
psf=self.psf,
179180
)
180181

182+
@cached_property
183+
def convolver(self):
184+
"""
185+
Returns a `Convolver` from a mask and 2D PSF kernel.
186+
187+
The `Convolver` stores in memory the array indexing between the mask and PSF, enabling efficient 2D PSF
188+
convolution of images and matrices used for linear algebra calculations (see `operators.convolver`).
189+
190+
This uses lazy allocation such that the calculation is only performed when the convolver is used, ensuring
191+
efficient set up of the `Imaging` class.
192+
193+
Returns
194+
-------
195+
Convolver
196+
The convolver given the masked imaging data's mask and PSF.
197+
"""
198+
199+
from autoarray.inversion.convolver import Convolver
200+
201+
return Convolver(mask=self.mask, kernel=self.psf)
202+
181203
@cached_property
182204
def w_tilde(self):
183205
"""
@@ -203,9 +225,11 @@ def w_tilde(self):
203225
indexes,
204226
lengths,
205227
) = inversion_imaging_util.w_tilde_curvature_preload_imaging_from(
206-
noise_map_native=np.array(self.noise_map.native),
207-
kernel_native=np.array(self.psf.native),
208-
native_index_for_slim_index=self.mask.derive_indexes.native_for_slim,
228+
noise_map_native=np.array(self.noise_map.native.array).astype("float64"),
229+
kernel_native=np.array(self.psf.native.array).astype("float64"),
230+
native_index_for_slim_index=np.array(
231+
self.mask.derive_indexes.native_for_slim
232+
).astype("int"),
209233
)
210234

211235
return WTildeImaging(
@@ -408,20 +432,20 @@ def apply_noise_scaling(
408432
"""
409433

410434
if signal_to_noise_value is None:
411-
noise_map = self.noise_map.native
412-
noise_map[mask == False] = noise_value
435+
noise_map = np.array(self.noise_map.native.array)
436+
noise_map[mask.array == False] = noise_value
413437
else:
414438
noise_map = np.where(
415439
mask == False,
416-
np.median(self.data.native[mask.derive_mask.edge == False])
440+
np.median(self.data.native.array[mask.derive_mask.edge == False])
417441
/ signal_to_noise_value,
418-
self.noise_map.native,
442+
self.noise_map.native.array,
419443
)
420444

421445
if should_zero_data:
422-
data = np.where(np.invert(mask), 0.0, self.data.native)
446+
data = np.where(np.invert(mask.array), 0.0, self.data.native.array)
423447
else:
424-
data = self.data.native
448+
data = self.data.native.array
425449

426450
data_unmasked = Array2D.no_mask(
427451
values=data,

autoarray/dataset/imaging/simulator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def via_image_from(
151151
pixel_scales=image.pixel_scales,
152152
)
153153

154-
if np.isnan(noise_map).any():
154+
if np.isnan(noise_map.array).any():
155155
raise exc.DatasetException(
156156
"The noise-map has NaN values in it. This suggests your exposure time and / or"
157157
"background sky levels are too low, creating signal counts at or close to 0.0."
@@ -161,7 +161,9 @@ def via_image_from(
161161
image = image - background_sky_map
162162

163163
mask = Mask2D.all_false(
164-
shape_native=image.shape_native, pixel_scales=image.pixel_scales
164+
shape_native=image.shape_native,
165+
pixel_scales=image.pixel_scales,
166+
origin=image.origin,
165167
)
166168

167169
image = Array2D(values=image, mask=mask)

autoarray/dataset/interferometer/dataset.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def w_tilde(self):
193193

194194
w_matrix = inversion_interferometer_util.w_tilde_via_preload_from(
195195
w_tilde_preload=curvature_preload,
196-
native_index_for_slim_index=self.real_space_mask.derive_indexes.native_for_slim,
196+
native_index_for_slim_index=np.array(
197+
self.real_space_mask.derive_indexes.native_for_slim
198+
).astype("int"),
197199
)
198200

199201
dirty_image = self.transformer.image_from(
@@ -205,7 +207,7 @@ def w_tilde(self):
205207
return WTildeInterferometer(
206208
w_matrix=w_matrix,
207209
curvature_preload=curvature_preload,
208-
dirty_image=dirty_image,
210+
dirty_image=np.array(dirty_image.array),
209211
real_space_mask=self.real_space_mask,
210212
noise_map_value=self.noise_map[0],
211213
)

autoarray/dataset/preprocess.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def noise_map_via_data_eps_and_exposure_time_map_from(data_eps, exposure_time_ma
149149
The exposure time at every data-point of the data.
150150
"""
151151
return data_eps.with_new_array(
152-
np.abs(data_eps * exposure_time_map) ** 0.5 / exposure_time_map
152+
np.abs(data_eps.array * exposure_time_map.array) ** 0.5
153+
/ exposure_time_map.array
153154
)
154155

155156

@@ -263,15 +264,17 @@ def edges_from(image, no_edges):
263264
edges = []
264265

265266
for edge_no in range(no_edges):
266-
top_edge = image.native[edge_no, edge_no : image.shape_native[1] - edge_no]
267-
bottom_edge = image.native[
267+
top_edge = image.native.array[
268+
edge_no, edge_no : image.shape_native[1] - edge_no
269+
]
270+
bottom_edge = image.native.array[
268271
image.shape_native[0] - 1 - edge_no,
269272
edge_no : image.shape_native[1] - edge_no,
270273
]
271-
left_edge = image.native[
274+
left_edge = image.native.array[
272275
edge_no + 1 : image.shape_native[0] - 1 - edge_no, edge_no
273276
]
274-
right_edge = image.native[
277+
right_edge = image.native.array[
275278
edge_no + 1 : image.shape_native[0] - 1 - edge_no,
276279
image.shape_native[1] - 1 - edge_no,
277280
]
@@ -406,9 +409,10 @@ def poisson_noise_via_data_eps_from(data_eps, exposure_time_map, seed=-1):
406409
An array describing simulated poisson noise_maps
407410
"""
408411
setup_random_seed(seed)
409-
image_counts = np.multiply(data_eps, exposure_time_map)
412+
413+
image_counts = np.multiply(data_eps.array, exposure_time_map.array)
410414
return data_eps - np.divide(
411-
np.random.poisson(image_counts, data_eps.shape), exposure_time_map
415+
np.random.poisson(image_counts, data_eps.shape), exposure_time_map.array
412416
)
413417

414418

@@ -506,8 +510,6 @@ def noise_map_with_signal_to_noise_limit_from(
506510
from autoarray.structures.arrays.uniform_1d import Array1D
507511
from autoarray.structures.arrays.uniform_2d import Array2D
508512

509-
# TODO : Refacotr into a util
510-
511513
signal_to_noise_map = data / noise_map
512514
signal_to_noise_map[signal_to_noise_map < 0] = 0
513515

@@ -517,12 +519,14 @@ def noise_map_with_signal_to_noise_limit_from(
517519
noise_map_limit = np.where(
518520
(signal_to_noise_map.native > signal_to_noise_limit)
519521
& (noise_limit_mask == False),
520-
np.abs(data.native) / signal_to_noise_limit,
521-
noise_map.native,
522+
np.abs(data.native.array) / signal_to_noise_limit,
523+
noise_map.native.array,
522524
)
523525

524526
mask = Mask2D.all_false(
525-
shape_native=data.shape_native, pixel_scales=data.pixel_scales
527+
shape_native=data.shape_native,
528+
pixel_scales=data.pixel_scales,
529+
origin=data.origin,
526530
)
527531

528532
if len(noise_map.native) == 1:

autoarray/fit/fit_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def chi_squared(self) -> float:
8686
"""
8787
Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map.
8888
"""
89-
return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map)
89+
return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map.array)
9090

9191
@property
9292
def noise_normalization(self) -> float:
@@ -95,7 +95,7 @@ def noise_normalization(self) -> float:
9595
9696
[Noise_Term] = sum(log(2*pi*[Noise]**2.0))
9797
"""
98-
return fit_util.noise_normalization_from(noise_map=self.noise_map)
98+
return fit_util.noise_normalization_from(noise_map=self.noise_map.array)
9999

100100
@property
101101
def log_likelihood(self) -> float:

0 commit comments

Comments
 (0)