Skip to content

Commit 4e0b925

Browse files
committed
maapping matrix convolve works
1 parent 44cd415 commit 4e0b925

5 files changed

Lines changed: 37 additions & 23 deletions

File tree

autoarray/dataset/imaging/dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ def __init__(
166166

167167
self.psf = psf
168168

169+
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
170+
raise exc.KernelException("Kernel2D Kernel2D must be odd")
171+
169172
@cached_property
170173
def grids(self):
171174
return GridsDataset(

autoarray/inversion/inversion/imaging/w_tilde.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,8 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]:
526526
)
527527

528528
mapped_reconstructed_image = self.convolver.convolve_image_no_blurring(
529-
image=mapped_reconstructed_image
529+
image=mapped_reconstructed_image,
530+
mask=self.mask
530531
)
531532

532533
else:

autoarray/structures/arrays/kernel_2d.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@ def __init__(
5454
store_native=store_native,
5555
)
5656

57-
if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0:
58-
raise exc.KernelException("Kernel2D Kernel2D must be odd")
59-
6057
if normalize:
6158
self._array = np.divide(self._array, np.sum(self._array))
6259

@@ -529,7 +526,7 @@ def convolve_image(self, image, blurring_image, jax_method="fft"):
529526

530527
return Array2D(values=convolved_array_1d, mask=image.mask)
531528

532-
def convolve_image_no_blurring(self, image, jax_method="fft"):
529+
def convolve_image_no_blurring(self, image, mask, jax_method="fft"):
533530
"""
534531
For a given 1D array and blurring array, convolve the two using this convolver.
535532
@@ -547,13 +544,13 @@ def convolve_image_no_blurring(self, image, jax_method="fft"):
547544
"""
548545

549546
slim_to_native = jnp.nonzero(
550-
jnp.logical_not(image.mask.array), size=image.shape[0]
547+
jnp.logical_not(mask.array), size=image.shape[0]
551548
)
552549

553-
expanded_array_native = jnp.zeros(image.mask.shape)
550+
expanded_array_native = jnp.zeros(mask.shape)
554551

555552
expanded_array_native = expanded_array_native.at[slim_to_native].set(
556-
image.array
553+
image
557554
)
558555

559556
kernel = np.array(self.native.array)
@@ -564,4 +561,14 @@ def convolve_image_no_blurring(self, image, jax_method="fft"):
564561

565562
convolved_array_1d = convolve_native[slim_to_native]
566563

567-
return Array2D(values=convolved_array_1d, mask=image.mask)
564+
return Array2D(values=convolved_array_1d, mask=mask)
565+
566+
def convolve_mapping_matrix(self, mapping_matrix, mask):
567+
"""For a given 1D array and blurring array, convolve the two using this convolver.
568+
569+
Parameters
570+
----------
571+
image
572+
1D array of the values which are to be blurred with the convolver's PSF.
573+
"""
574+
return jax.vmap(self.convolve_image_no_blurring, in_axes=(1, None))(mapping_matrix, mask).T

test_autoarray/dataset/imaging/test_dataset.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import autoarray as aa
1010

11+
from autoarray import exc
12+
1113
test_data_path = path.join(
1214
"{}".format(path.dirname(path.realpath(__file__))),
1315
"files",
@@ -241,3 +243,8 @@ def test__noise_map_unmasked_has_zeros_or_negative__raises_exception():
241243

242244
with pytest.raises(aa.exc.DatasetException):
243245
aa.Imaging(data=array, noise_map=noise_map)
246+
247+
def test__psf_not_odd_x_odd_kernel__raises_error():
248+
249+
with pytest.raises(exc.KernelException):
250+
aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0)

test_autoarray/structures/arrays/test_kernel_2d.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -292,12 +292,6 @@ def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy
292292
assert kernel_astropy == pytest.approx(kernel_2d.native._array, abs=1e-4)
293293

294294

295-
def test__not_odd_x_odd_kernel__raises_error():
296-
297-
with pytest.raises(exc.KernelException):
298-
aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0)
299-
300-
301295
def test__convolved_array_from():
302296

303297
array_2d = aa.Array2D.no_mask(
@@ -430,13 +424,13 @@ def test__convolve_image_no_blurring():
430424

431425
masked_image = aa.Array2D(values=image.native, mask=mask)
432426

433-
blurred_masked_im_1 = kernel.convolve_image_no_blurring(image=masked_image)
427+
blurred_masked_im_1 = kernel.convolve_image_no_blurring(image=masked_image, mask=mask)
434428

435429
assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4)
436430

437431

438432
def test__convolve_mapping_matrix():
439-
mask = np.array(
433+
mask = aa.Mask2D(mask=np.array(
440434
[
441435
[True, True, True, True, True, True],
442436
[True, False, False, False, False, True],
@@ -445,7 +439,7 @@ def test__convolve_mapping_matrix():
445439
[True, False, False, False, False, True],
446440
[True, True, True, True, True, True],
447441
]
448-
)
442+
), pixel_scales=1.0)
449443

450444
kernel = aa.Kernel2D.no_mask(
451445
values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0
@@ -476,11 +470,11 @@ def test__convolve_mapping_matrix():
476470
]
477471
)
478472

479-
blurred_mapping = kernel.convolve_mapping_matrix(mapping)
473+
blurred_mapping = kernel.convolve_mapping_matrix(mapping, mask)
480474

481475
assert (
482476
blurred_mapping
483-
== np.array(
477+
== pytest.approx(np.array(
484478
[
485479
[0, 0, 0],
486480
[0, 0, 0],
@@ -500,7 +494,7 @@ def test__convolve_mapping_matrix():
500494
[0, 0, 0],
501495
]
502496
)
503-
).all()
497+
), 1.0e-4)
504498

505499
kernel = aa.Kernel2D.no_mask(
506500
values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0
@@ -531,7 +525,9 @@ def test__convolve_mapping_matrix():
531525
]
532526
)
533527

534-
blurred_mapping = kernel.convolve_mapping_matrix(mapping)
528+
blurred_mapping = kernel.convolve_mapping_matrix(mapping, mask)
529+
530+
print(blurred_mapping)
535531

536532
assert blurred_mapping == pytest.approx(
537533
np.array(
@@ -554,5 +550,5 @@ def test__convolve_mapping_matrix():
554550
[0, 0, 0],
555551
]
556552
),
557-
1e-4,
553+
abs=1e-4,
558554
)

0 commit comments

Comments
 (0)