Skip to content

Commit d0c324b

Browse files
authored
Merge pull request #165 from Jammy2211/feature/jax_and_numba
Feature/jax and numba
2 parents a331718 + 70843c0 commit d0c324b

25 files changed

Lines changed: 298 additions & 476 deletions

autoarray/dataset/imaging/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def __init__(
162162

163163
if psf is not None and use_normalized_psf:
164164
psf = Kernel2D.no_mask(
165-
values=psf.native, pixel_scales=psf.pixel_scales, normalize=True
165+
values=psf.native._array, pixel_scales=psf.pixel_scales, normalize=True
166166
)
167167

168168
self.psf = psf
@@ -193,7 +193,7 @@ def convolver(self):
193193
The convolver given the masked imaging data's mask and PSF.
194194
"""
195195

196-
return Convolver(mask=self.mask, kernel=self.psf)
196+
return Convolver(mask=self.mask, kernel=Kernel2D(values=self.psf._array, mask=self.psf.mask, header=self.psf.header))
197197

198198
@cached_property
199199
def w_tilde(self):

autoarray/fit/fit_util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from functools import wraps
2+
import jax.numpy as np
23

3-
from autoarray.numpy_wrapper import np
44
from autoarray.mask.abstract_mask import Mask
55

66
from autoarray import type as ty
@@ -83,7 +83,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float:
8383
chi_squared_map
8484
The chi-squared-map of values of the model-data fit to the dataset.
8585
"""
86-
return np.sum(chi_squared_map)
86+
return np.sum(chi_squared_map._array)
8787

8888

8989
def noise_normalization_from(*, noise_map: ty.DataLike) -> float:
@@ -97,7 +97,7 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float:
9797
noise_map
9898
The masked noise-map of the dataset.
9999
"""
100-
return np.sum(np.log(2 * np.pi * noise_map**2.0))
100+
return np.sum(np.log(2 * np.pi * noise_map._array**2.0))
101101

102102

103103
def normalized_residual_map_complex_from(

autoarray/geometry/geometry_2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,9 @@ def scaled_coordinates_2d_from(
184184
-------
185185
A 2D (y,x) pixel-value coordinate.
186186
"""
187+
187188
return geometry_util.scaled_coordinates_2d_from(
188-
pixel_coordinates_2d=pixel_coordinates_2d,
189+
pixel_coordinates_2d=np.array(pixel_coordinates_2d),
189190
shape_native=self.shape_native,
190191
pixel_scales=self.pixel_scales,
191192
origins=self.origin,

autoarray/geometry/geometry_util.py

Lines changed: 102 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import jax.numpy as jnp
2+
import numpy as np
13
from typing import Tuple, Union
2-
from autoarray.numpy_wrapper import np, use_jax
4+
35

46
from autoarray import numba_util
57
from autoarray import type as ty
@@ -179,8 +181,69 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float]
179181

180182
return pixel_scales
181183

184+
@numba_util.jit()
185+
def central_pixel_coordinates_2d_numba_from(
186+
shape_native: Tuple[int, int],
187+
) -> Tuple[float, float]:
188+
"""
189+
Returns the central pixel coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``)
190+
from the shape of that data structure.
191+
192+
Examples of the central pixels are as follows:
193+
194+
- For a 3x3 image, the central pixel is pixel [1, 1].
195+
- For a 4x4 image, the central pixel is [1.5, 1.5].
196+
197+
Parameters
198+
----------
199+
shape_native
200+
The dimensions of the data structure, which can be in 1D, 2D or higher dimensions.
201+
202+
Returns
203+
-------
204+
The central pixel coordinates of the data structure.
205+
"""
206+
return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2)
182207

183208
@numba_util.jit()
209+
def central_scaled_coordinate_2d_numba_from(
210+
shape_native: Tuple[int, int],
211+
pixel_scales: ty.PixelScales,
212+
origin: Tuple[float, float] = (0.0, 0.0),
213+
) -> Tuple[float, float]:
214+
"""
215+
Returns the central scaled coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``)
216+
from the shape of that data structure.
217+
218+
This is computed by using the data structure's shape and converting it to scaled units using an input
219+
pixel-coordinates to scaled-coordinate conversion factor `pixel_scales`.
220+
221+
The origin of the scaled grid can also be input and moved from (0.0, 0.0).
222+
223+
Parameters
224+
----------
225+
shape_native
226+
The 2D shape of the data structure whose central scaled coordinates are computed.
227+
pixel_scales
228+
The (y,x) scaled units to pixel units conversion factor of the 2D data structure.
229+
origin
230+
The (y,x) scaled units origin of the coordinate system the central scaled coordinate is computed on.
231+
232+
Returns
233+
-------
234+
The central coordinates of the 2D data structure in scaled units.
235+
"""
236+
237+
central_pixel_coordinates = central_pixel_coordinates_2d_numba_from(
238+
shape_native=shape_native
239+
)
240+
241+
y_pixel = central_pixel_coordinates[0] + (origin[0] / pixel_scales[0])
242+
x_pixel = central_pixel_coordinates[1] - (origin[1] / pixel_scales[1])
243+
244+
return (y_pixel, x_pixel)
245+
246+
184247
def central_pixel_coordinates_2d_from(
185248
shape_native: Tuple[int, int],
186249
) -> Tuple[float, float]:
@@ -205,7 +268,6 @@ def central_pixel_coordinates_2d_from(
205268
return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2)
206269

207270

208-
@numba_util.jit()
209271
def central_scaled_coordinate_2d_from(
210272
shape_native: Tuple[int, int],
211273
pixel_scales: ty.PixelScales,
@@ -234,7 +296,7 @@ def central_scaled_coordinate_2d_from(
234296
The central coordinates of the 2D data structure in scaled units.
235297
"""
236298

237-
central_pixel_coordinates = central_pixel_coordinates_2d_from(
299+
central_pixel_coordinates = central_pixel_coordinates_2d_numba_from(
238300
shape_native=shape_native
239301
)
240302

@@ -243,8 +305,6 @@ def central_scaled_coordinate_2d_from(
243305

244306
return (y_pixel, x_pixel)
245307

246-
247-
@numba_util.jit()
248308
def pixel_coordinates_2d_from(
249309
scaled_coordinates_2d: Tuple[float, float],
250310
shape_native: Tuple[int, int],
@@ -352,7 +412,7 @@ def scaled_coordinates_2d_from(
352412
origin=(0.0, 0.0)
353413
)
354414
"""
355-
central_scaled_coordinates = central_scaled_coordinate_2d_from(
415+
central_scaled_coordinates = central_scaled_coordinate_2d_numba_from(
356416
shape_native=shape_native, pixel_scales=pixel_scales, origin=origins
357417
)
358418

@@ -382,18 +442,16 @@ def transform_grid_2d_to_reference_frame(
382442
grid
383443
The 2d grid of (y, x) coordinates which are transformed to a new reference frame.
384444
"""
385-
if use_jax:
386-
shifted_grid_2d = grid_2d.array - np.array(centre)
387-
else:
388-
shifted_grid_2d = grid_2d - np.array(centre)
389-
radius = np.sqrt(np.sum(shifted_grid_2d**2.0, axis=1))
390-
theta_coordinate_to_profile = np.arctan2(
445+
shifted_grid_2d = np.array(grid_2d) - jnp.array(centre)
446+
447+
radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1))
448+
theta_coordinate_to_profile = jnp.arctan2(
391449
shifted_grid_2d[:, 0], shifted_grid_2d[:, 1]
392-
) - np.radians(angle)
393-
return np.vstack(
450+
) - jnp.radians(angle)
451+
return jnp.vstack(
394452
[
395-
radius * np.sin(theta_coordinate_to_profile),
396-
radius * np.cos(theta_coordinate_to_profile),
453+
radius * jnp.sin(theta_coordinate_to_profile),
454+
radius * jnp.cos(theta_coordinate_to_profile),
397455
]
398456
).T
399457

@@ -435,7 +493,6 @@ def transform_grid_2d_from_reference_frame(
435493
return np.vstack((y, x)).T
436494

437495

438-
@numba_util.jit()
439496
def grid_pixels_2d_slim_from(
440497
grid_scaled_2d_slim: np.ndarray,
441498
shape_native: Tuple[int, int],
@@ -476,33 +533,15 @@ def grid_pixels_2d_slim_from(
476533
grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_scaled_2d_slim=grid_scaled_2d_slim, shape=(2,2),
477534
pixel_scales=(0.5, 0.5), origin=(0.0, 0.0))
478535
"""
479-
480536
centres_scaled = central_scaled_coordinate_2d_from(
481537
shape_native=shape_native, pixel_scales=pixel_scales, origin=origin
482538
)
483-
if use_jax:
484-
centres_scaled = np.array(centres_scaled)
485-
pixel_scales = np.array(pixel_scales)
486-
sign = np.array([-1, 1])
487-
return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5
488-
else:
489-
grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2))
490-
for slim_index in range(grid_scaled_2d_slim.shape[0]):
491-
grid_pixels_2d_slim[slim_index, 0] = (
492-
(-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0])
493-
+ centres_scaled[0]
494-
+ 0.5
495-
)
496-
grid_pixels_2d_slim[slim_index, 1] = (
497-
(grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1])
498-
+ centres_scaled[1]
499-
+ 0.5
500-
)
501-
502-
return grid_pixels_2d_slim
539+
centres_scaled = np.array(centres_scaled)
540+
pixel_scales = np.array(pixel_scales)
541+
sign = np.array([-1, 1])
542+
return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5
503543

504544

505-
@numba_util.jit()
506545
def grid_pixel_centres_2d_slim_from(
507546
grid_scaled_2d_slim: np.ndarray,
508547
shape_native: Tuple[int, int],
@@ -547,32 +586,14 @@ def grid_pixel_centres_2d_slim_from(
547586
shape_native=shape_native, pixel_scales=pixel_scales, origin=origin
548587
)
549588

550-
if use_jax:
551-
centres_scaled = np.array(centres_scaled)
552-
pixel_scales = np.array(pixel_scales)
553-
sign = np.array([-1.0, 1.0])
554-
grid_pixels_2d_slim = (
555-
(sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5
556-
).astype(int)
557-
else:
558-
grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2))
559-
560-
for slim_index in range(grid_scaled_2d_slim.shape[0]):
561-
grid_pixels_2d_slim[slim_index, 0] = int(
562-
(-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0])
563-
+ centres_scaled[0]
564-
+ 0.5
565-
)
566-
grid_pixels_2d_slim[slim_index, 1] = int(
567-
(grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1])
568-
+ centres_scaled[1]
569-
+ 0.5
570-
)
571-
572-
return grid_pixels_2d_slim
589+
centres_scaled = np.array(centres_scaled)
590+
pixel_scales = np.array(pixel_scales)
591+
sign = np.array([-1.0, 1.0])
592+
return (
593+
(sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5
594+
).astype(int)
573595

574596

575-
@numba_util.jit()
576597
def grid_pixel_indexes_2d_slim_from(
577598
grid_scaled_2d_slim: np.ndarray,
578599
shape_native: Tuple[int, int],
@@ -625,25 +646,13 @@ def grid_pixel_indexes_2d_slim_from(
625646
origin=origin,
626647
)
627648

628-
if use_jax:
629-
grid_pixel_indexes_2d_slim = (
630-
(grid_pixels_2d_slim * np.array([shape_native[1], 1]))
631-
.sum(axis=1)
632-
.astype(int)
633-
)
634-
else:
635-
grid_pixel_indexes_2d_slim = np.zeros(grid_pixels_2d_slim.shape[0])
636-
637-
for slim_index in range(grid_pixels_2d_slim.shape[0]):
638-
grid_pixel_indexes_2d_slim[slim_index] = int(
639-
grid_pixels_2d_slim[slim_index, 0] * shape_native[1]
640-
+ grid_pixels_2d_slim[slim_index, 1]
641-
)
642-
643-
return grid_pixel_indexes_2d_slim
649+
return (
650+
(grid_pixels_2d_slim * np.array([shape_native[1], 1]))
651+
.sum(axis=1)
652+
.astype(int)
653+
)
644654

645655

646-
@numba_util.jit()
647656
def grid_scaled_2d_slim_from(
648657
grid_pixels_2d_slim: np.ndarray,
649658
shape_native: Tuple[int, int],
@@ -682,33 +691,18 @@ def grid_scaled_2d_slim_from(
682691
grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_pixels_2d_slim=grid_pixels_2d_slim, shape=(2,2),
683692
pixel_scales=(0.5, 0.5), origin=(0.0, 0.0))
684693
"""
685-
686694
centres_scaled = central_scaled_coordinate_2d_from(
687695
shape_native=shape_native, pixel_scales=pixel_scales, origin=origin
688696
)
689-
if use_jax:
690-
centres_scaled = np.array(centres_scaled)
691-
pixel_scales = np.array(pixel_scales)
692-
sign = np.array([-1, 1])
693-
grid_scaled_2d_slim = (
694-
(grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign
695-
)
696-
else:
697-
grid_scaled_2d_slim = np.zeros((grid_pixels_2d_slim.shape[0], 2))
698-
699-
for slim_index in range(grid_scaled_2d_slim.shape[0]):
700-
grid_scaled_2d_slim[slim_index, 0] = (
701-
-(grid_pixels_2d_slim[slim_index, 0] - centres_scaled[0] - 0.5)
702-
* pixel_scales[0]
703-
)
704-
grid_scaled_2d_slim[slim_index, 1] = (
705-
grid_pixels_2d_slim[slim_index, 1] - centres_scaled[1] - 0.5
706-
) * pixel_scales[1]
707-
708-
return grid_scaled_2d_slim
697+
698+
centres_scaled = np.array(centres_scaled)
699+
pixel_scales = np.array(pixel_scales)
700+
sign = np.array([-1, 1])
701+
return (
702+
(grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign
703+
)
709704

710705

711-
@numba_util.jit()
712706
def grid_pixel_centres_2d_from(
713707
grid_scaled_2d: np.ndarray,
714708
shape_native: Tuple[int, int],
@@ -753,30 +747,12 @@ def grid_pixel_centres_2d_from(
753747
shape_native=shape_native, pixel_scales=pixel_scales, origin=origin
754748
)
755749

756-
if use_jax:
757-
centres_scaled = np.array(centres_scaled)
758-
pixel_scales = np.array(pixel_scales)
759-
sign = np.array([-1.0, 1.0])
760-
grid_pixels_2d = (
761-
(sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5
762-
).astype(int)
763-
else:
764-
grid_pixels_2d = np.zeros((grid_scaled_2d.shape[0], grid_scaled_2d.shape[1], 2))
765-
766-
for y in range(grid_scaled_2d.shape[0]):
767-
for x in range(grid_scaled_2d.shape[1]):
768-
grid_pixels_2d[y, x, 0] = int(
769-
(-grid_scaled_2d[y, x, 0] / pixel_scales[0])
770-
+ centres_scaled[0]
771-
+ 0.5
772-
)
773-
grid_pixels_2d[y, x, 1] = int(
774-
(grid_scaled_2d[y, x, 1] / pixel_scales[1])
775-
+ centres_scaled[1]
776-
+ 0.5
777-
)
778-
779-
return grid_pixels_2d
750+
centres_scaled = np.array(centres_scaled)
751+
pixel_scales = np.array(pixel_scales)
752+
sign = np.array([-1.0, 1.0])
753+
return (
754+
(sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5
755+
).astype(int)
780756

781757

782758
def extent_symmetric_from(

0 commit comments

Comments
 (0)