Skip to content

Commit 0fa9646

Browse files
committed
mask array 2d choose type based on if input is jax or numpy
1 parent 87931b6 commit 0fa9646

8 files changed

Lines changed: 37 additions & 27 deletions

File tree

autoarray/dataset/imaging/dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def __init__(
135135
f"the image and noise-map yourself."
136136
)
137137

138+
print(type(data.array))
139+
138140
super().__init__(
139141
data=data,
140142
noise_map=noise_map,

autoarray/structures/arrays/array_2d_util.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray:
2727
except AttributeError:
2828
pass
2929

30-
return jnp.asarray(array)
30+
if isinstance(array, list):
31+
array = np.asarray(array)
32+
33+
return array
3134

3235

3336

@@ -119,27 +122,31 @@ def convert_array_2d(
119122
"""
120123
array_2d = convert_array(array=array_2d).copy()
121124

122-
check_array_2d_and_mask_2d(array_2d=array_2d, mask_2d=mask_2d)
125+
if isinstance(array_2d, np.ndarray):
126+
is_numpy = True
127+
else:
128+
is_numpy = False
123129

124-
mask_2d = jnp.array(mask_2d.array)
130+
check_array_2d_and_mask_2d(array_2d=array_2d, mask_2d=mask_2d)
125131

126132
is_native = len(array_2d.shape) == 2
127133

128134
if is_native and not skip_mask:
129-
array_2d *= jnp.invert(mask_2d)
135+
array_2d *= np.invert(mask_2d)
130136

131137
if is_native == store_native:
132-
return array_2d
138+
return np.array(array_2d) if is_numpy else jnp.array(array_2d)
133139
elif not store_native:
134-
return array_2d_slim_from(
140+
array_2d = array_2d_slim_from(
135141
array_2d_native=array_2d,
136142
mask_2d=mask_2d,
137143
)
144+
return np.array(array_2d) if is_numpy else jnp.array(array_2d)
138145
array_2d = array_2d_native_from(
139146
array_2d_slim=array_2d,
140147
mask_2d=mask_2d,
141148
)
142-
return array_2d
149+
return np.array(array_2d) if is_numpy else jnp.array(array_2d)
143150

144151

145152
def convert_array_2d_to_slim(array_2d: np.ndarray, mask_2d: Mask2D) -> np.ndarray:

autoarray/structures/arrays/uniform_2d.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def zoomed_around_mask(self, buffer: int = 1) -> "Array2D":
459459
"""
460460

461461
extracted_array_2d = array_2d_util.extracted_array_2d_from(
462-
array_2d=np.array(self.native._array),
462+
array_2d=np.array(self.native.array),
463463
y0=self.mask.zoom_region[0] - buffer,
464464
y1=self.mask.zoom_region[1] + buffer,
465465
x0=self.mask.zoom_region[2] - buffer,
@@ -493,7 +493,7 @@ def extent_of_zoomed_array(self, buffer: int = 1) -> np.ndarray:
493493
The number pixels around the extracted array used as a buffer.
494494
"""
495495
extracted_array_2d = array_2d_util.extracted_array_2d_from(
496-
array_2d=np.array(self.native._array),
496+
array_2d=np.array(self.native.array),
497497
y0=self.mask.zoom_region[0] - buffer,
498498
y1=self.mask.zoom_region[1] + buffer,
499499
x0=self.mask.zoom_region[2] - buffer,
@@ -527,7 +527,7 @@ def resized_from(
527527
"""
528528

529529
resized_array_2d = array_2d_util.resized_array_2d_from(
530-
array_2d=np.array(self.native._array), resized_shape=new_shape
530+
array_2d=np.array(self.native.array), resized_shape=new_shape
531531
)
532532

533533
resized_mask = self.mask.resized_from(
@@ -587,14 +587,14 @@ def trimmed_after_convolution_from(
587587
psf_cut_x = int(np.ceil(kernel_shape[1] / 2)) - 1
588588
array_y = int(self.mask.shape[0])
589589
array_x = int(self.mask.shape[1])
590-
trimmed_array_2d = self.native[
590+
trimmed_array_2d = self.native.array[
591591
psf_cut_y : array_y - psf_cut_y, psf_cut_x : array_x - psf_cut_x
592592
]
593593

594594
resized_mask = self.mask.resized_from(new_shape=trimmed_array_2d.shape)
595595

596596
array = array_2d_util.convert_array_2d(
597-
array_2d=trimmed_array_2d._array, mask_2d=resized_mask
597+
array_2d=trimmed_array_2d, mask_2d=resized_mask
598598
)
599599

600600
return Array2D(

test_autoarray/dataset/imaging/test_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def test__from_fits():
105105

106106

107107
def test__output_to_fits(imaging_7x7, test_data_path):
108+
108109
imaging_7x7.output_to_fits(
109110
data_path=path.join(test_data_path, "data.fits"),
110111
psf_path=path.join(test_data_path, "psf.fits"),

test_autoarray/layout/test_region.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ def test__slice_1d__addition():
3131

3232
region = aa.Region1D(region=(0, 1))
3333

34-
arr_1d[region.slice] += image[region.slice]
34+
arr_1d = arr_1d.array.at[region.slice].add(image[region.slice])
3535

3636
assert (arr_1d == np.array([2.0, 2.0, 3.0, 4.0])).all()
3737

3838
arr_1d = aa.Array1D.no_mask(values=np.array([1.0, 2.0, 3.0, 4.0]), pixel_scales=1.0)
3939

4040
region = aa.Region1D(region=(2, 4))
4141

42-
arr_1d[region.slice] += image[region.slice]
42+
arr_1d = arr_1d.array.at[region.slice].add(image[region.slice])
4343

4444
assert (arr_1d == np.array([1.0, 2.0, 4.0, 5.0])).all()
4545

Binary file not shown.

test_autoarray/structures/arrays/test_kernel_2d.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,31 +150,31 @@ def test__rescaled_with_odd_dimensions_from__evens_to_odds():
150150
rescale_factor=0.5, normalize=True
151151
)
152152
assert kernel_2d.pixel_scales == (2.0, 2.0)
153-
assert (kernel_2d.native == (1.0 / 9.0) * np.ones((3, 3))).all()
153+
assert kernel_2d.native == pytest.approx((1.0 / 9.0) * np.ones((3, 3)), 1.0e-4)
154154

155155
array_2d = np.ones((9, 9))
156156
kernel_2d = aa.Kernel2D.no_mask(values=array_2d, pixel_scales=1.0, normalize=False)
157157
kernel_2d = kernel_2d.rescaled_with_odd_dimensions_from(
158158
rescale_factor=0.333333333333333, normalize=True
159159
)
160160
assert kernel_2d.pixel_scales == (3.0, 3.0)
161-
assert (kernel_2d.native == (1.0 / 9.0) * np.ones((3, 3))).all()
161+
assert kernel_2d.native == pytest.approx((1.0 / 9.0) * np.ones((3, 3)), 1.0e-4)
162162

163163
array_2d = np.ones((18, 6))
164164
kernel_2d = aa.Kernel2D.no_mask(values=array_2d, pixel_scales=1.0, normalize=False)
165165
kernel_2d = kernel_2d.rescaled_with_odd_dimensions_from(
166166
rescale_factor=0.5, normalize=True
167167
)
168168
assert kernel_2d.pixel_scales == (2.0, 2.0)
169-
assert (kernel_2d.native == (1.0 / 27.0) * np.ones((9, 3))).all()
169+
assert kernel_2d.native == pytest.approx((1.0 / 27.0) * np.ones((9, 3)), 1.0e-4)
170170

171171
array_2d = np.ones((6, 18))
172172
kernel_2d = aa.Kernel2D.no_mask(values=array_2d, pixel_scales=1.0, normalize=False)
173173
kernel_2d = kernel_2d.rescaled_with_odd_dimensions_from(
174174
rescale_factor=0.5, normalize=True
175175
)
176176
assert kernel_2d.pixel_scales == (2.0, 2.0)
177-
assert (kernel_2d.native == (1.0 / 27.0) * np.ones((3, 9))).all()
177+
assert kernel_2d.native == pytest.approx((1.0 / 27.0) * np.ones((3, 9)), 1.0e-4)
178178

179179

180180
def test__rescaled_with_odd_dimensions_from__different_scalings():
@@ -183,7 +183,7 @@ def test__rescaled_with_odd_dimensions_from__different_scalings():
183183
rescale_factor=2.0, normalize=True
184184
)
185185
assert kernel_2d.pixel_scales == (0.4, 0.4)
186-
assert (kernel_2d.native == (1.0 / 25.0) * np.ones((5, 5))).all()
186+
assert kernel_2d.native == pytest.approx((1.0 / 25.0) * np.ones((5, 5)), 1.0e-4)
187187

188188
kernel_2d = aa.Kernel2D.ones(
189189
shape_native=(40, 40), pixel_scales=1.0, normalize=False
@@ -192,7 +192,7 @@ def test__rescaled_with_odd_dimensions_from__different_scalings():
192192
rescale_factor=0.1, normalize=True
193193
)
194194
assert kernel_2d.pixel_scales == (8.0, 8.0)
195-
assert (kernel_2d.native == (1.0 / 25.0) * np.ones((5, 5))).all()
195+
assert kernel_2d.native == pytest.approx((1.0 / 25.0) * np.ones((5, 5)), 1.0e-4)
196196

197197
kernel_2d = aa.Kernel2D.ones(shape_native=(2, 4), pixel_scales=1.0, normalize=False)
198198
kernel_2d = kernel_2d.rescaled_with_odd_dimensions_from(
@@ -201,23 +201,23 @@ def test__rescaled_with_odd_dimensions_from__different_scalings():
201201

202202
assert kernel_2d.pixel_scales[0] == pytest.approx(0.4, 1.0e-4)
203203
assert kernel_2d.pixel_scales[1] == pytest.approx(0.4444444, 1.0e-4)
204-
assert (kernel_2d.native == (1.0 / 45.0) * np.ones((5, 9))).all()
204+
assert kernel_2d.native == pytest.approx((1.0 / 45.0) * np.ones((5, 9)), 1.0e-4)
205205

206206
kernel_2d = aa.Kernel2D.ones(shape_native=(4, 2), pixel_scales=1.0, normalize=False)
207207
kernel_2d = kernel_2d.rescaled_with_odd_dimensions_from(
208208
rescale_factor=2.0, normalize=True
209209
)
210210
assert kernel_2d.pixel_scales[0] == pytest.approx(0.4444444, 1.0e-4)
211211
assert kernel_2d.pixel_scales[1] == pytest.approx(0.4, 1.0e-4)
212-
assert (kernel_2d.native == (1.0 / 45.0) * np.ones((9, 5))).all()
212+
assert kernel_2d.native == pytest.approx((1.0 / 45.0) * np.ones((9, 5)), 1.0e-4)
213213

214214
kernel_2d = aa.Kernel2D.ones(shape_native=(6, 4), pixel_scales=1.0, normalize=False)
215215
kernel_2d = kernel_2d.rescaled_with_odd_dimensions_from(
216216
rescale_factor=0.5, normalize=True
217217
)
218218

219219
assert kernel_2d.pixel_scales == pytest.approx((2.0, 1.3333333333), 1.0e-4)
220-
assert (kernel_2d.native == (1.0 / 9.0) * np.ones((3, 3))).all()
220+
assert kernel_2d.native == pytest.approx((1.0 / 9.0) * np.ones((3, 3)), 1.0e-4)
221221

222222
kernel_2d = aa.Kernel2D.ones(
223223
shape_native=(9, 12), pixel_scales=1.0, normalize=False
@@ -227,15 +227,15 @@ def test__rescaled_with_odd_dimensions_from__different_scalings():
227227
)
228228

229229
assert kernel_2d.pixel_scales == pytest.approx((3.0, 2.4), 1.0e-4)
230-
assert (kernel_2d.native == (1.0 / 15.0) * np.ones((3, 5))).all()
230+
assert kernel_2d.native == pytest.approx((1.0 / 15.0) * np.ones((3, 5)), 1.0e-4)
231231

232232
kernel_2d = aa.Kernel2D.ones(shape_native=(4, 6), pixel_scales=1.0, normalize=False)
233233
kernel_2d = kernel_2d.rescaled_with_odd_dimensions_from(
234234
rescale_factor=0.5, normalize=True
235235
)
236236

237237
assert kernel_2d.pixel_scales == pytest.approx((1.33333333333, 2.0), 1.0e-4)
238-
assert (kernel_2d.native == (1.0 / 9.0) * np.ones((3, 3))).all()
238+
assert kernel_2d.native == pytest.approx((1.0 / 9.0) * np.ones((3, 3)), 1.0e-4)
239239

240240
kernel_2d = aa.Kernel2D.ones(
241241
shape_native=(12, 9), pixel_scales=1.0, normalize=False
@@ -244,7 +244,7 @@ def test__rescaled_with_odd_dimensions_from__different_scalings():
244244
rescale_factor=0.33333333333, normalize=True
245245
)
246246
assert kernel_2d.pixel_scales == pytest.approx((2.4, 3.0), 1.0e-4)
247-
assert (kernel_2d.native == (1.0 / 15.0) * np.ones((5, 3))).all()
247+
assert kernel_2d.native == pytest.approx((1.0 / 15.0) * np.ones((5, 3)), 1.0e-4)
248248

249249

250250
def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy_gaussian_model():

test_autoarray/structures/arrays/test_repr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33

44
def test_repr():
55
array = aa.Array2D.no_mask([[1, 2], [3, 4]], pixel_scales=1)
6-
assert repr(array) == "Array([1, 2, 3, 4], dtype=int64)"
6+
assert repr(array) == "Array2D([1, 2, 3, 4])"

0 commit comments

Comments
 (0)