11from functools import wraps
2- import jax .numpy as np
2+ import jax .numpy as jnp
3+ import numpy as np
34
45from autoarray .mask .abstract_mask import Mask
56
@@ -83,7 +84,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float:
8384 chi_squared_map
8485 The chi-squared-map of values of the model-data fit to the dataset.
8586 """
86- return np .sum (chi_squared_map . _array )
87+ return jnp .sum (np . array ( chi_squared_map ) )
8788
8889
8990def noise_normalization_from (* , noise_map : ty .DataLike ) -> float :
@@ -97,12 +98,12 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float:
9798 noise_map
9899 The masked noise-map of the dataset.
99100 """
100- return np .sum (np .log (2 * np .pi * noise_map . _array ** 2.0 ))
101+ return jnp .sum (jnp .log (2 * jnp .pi * np . array ( noise_map ) ** 2.0 ))
101102
102103
103104def normalized_residual_map_complex_from (
104- * , residual_map : np .ndarray , noise_map : np .ndarray
105- ) -> np .ndarray :
105+ * , residual_map : jnp .ndarray , noise_map : jnp .ndarray
106+ ) -> jnp .ndarray :
106107 """
107108 Returns the normalized residual-map of the fit of complex model-data to a dataset, where:
108109
@@ -126,8 +127,8 @@ def normalized_residual_map_complex_from(
126127
127128
128129def chi_squared_map_complex_from (
129- * , residual_map : np .ndarray , noise_map : np .ndarray
130- ) -> np .ndarray :
130+ * , residual_map : jnp .ndarray , noise_map : jnp .ndarray
131+ ) -> jnp .ndarray :
131132 """
132133 Returnss the chi-squared-map of the fit of complex model-data to a dataset, where:
133134
@@ -145,7 +146,7 @@ def chi_squared_map_complex_from(
145146 return chi_squared_map_real + 1j * chi_squared_map_imag
146147
147148
148- def chi_squared_complex_from (* , chi_squared_map : np .ndarray ) -> float :
149+ def chi_squared_complex_from (* , chi_squared_map : jnp .ndarray ) -> float :
149150 """
150151 Returns the chi-squared terms of each complex model data's fit to a masked dataset, by summing the masked
151152 chi-squared-map of the fit.
@@ -157,12 +158,12 @@ def chi_squared_complex_from(*, chi_squared_map: np.ndarray) -> float:
157158 chi_squared_map
158159 The chi-squared-map of values of the model-data fit to the dataset.
159160 """
160- chi_squared_real = np .sum (chi_squared_map .real )
161- chi_squared_imag = np .sum (chi_squared_map .imag )
161+ chi_squared_real = jnp .sum (chi_squared_map .real )
162+ chi_squared_imag = jnp .sum (chi_squared_map .imag )
162163 return chi_squared_real + chi_squared_imag
163164
164165
165- def noise_normalization_complex_from (* , noise_map : np .ndarray ) -> float :
166+ def noise_normalization_complex_from (* , noise_map : jnp .ndarray ) -> float :
166167 """
167168 Returns the noise-map normalization terms of a complex noise-map, summing the noise_map value in every pixel as:
168169
@@ -173,8 +174,8 @@ def noise_normalization_complex_from(*, noise_map: np.ndarray) -> float:
173174 noise_map
174175 The masked noise-map of the dataset.
175176 """
176- noise_normalization_real = np .sum (np .log (2 * np .pi * noise_map .real ** 2.0 ))
177- noise_normalization_imag = np .sum (np .log (2 * np .pi * noise_map .imag ** 2.0 ))
177+ noise_normalization_real = jnp .sum (jnp .log (2 * jnp .pi * noise_map .real ** 2.0 ))
178+ noise_normalization_imag = jnp .sum (jnp .log (2 * jnp .pi * noise_map .imag ** 2.0 ))
178179 return noise_normalization_real + noise_normalization_imag
179180
180181
@@ -198,9 +199,7 @@ def residual_map_with_mask_from(
198199 model_data
199200 The model data used to fit the data.
200201 """
201- return np .subtract (
202- data , model_data , out = np .zeros_like (data ), where = np .asarray (mask ) == 0
203- )
202+ return jnp .where (jnp .asarray (mask ) == 0 , jnp .subtract (data , model_data ), 0 )
204203
205204
206205@to_new_array
@@ -223,13 +222,7 @@ def normalized_residual_map_with_mask_from(
223222 mask
224223 The mask applied to the residual-map, where `False` entries are included in the calculation.
225224 """
226- return np .divide (
227- residual_map ,
228- noise_map ,
229- out = np .zeros_like (residual_map ),
230- where = np .asarray (mask ) == 0 ,
231- )
232-
225+ return jnp .where (jnp .asarray (mask ) == 0 , jnp .divide (residual_map , noise_map ), 0 )
233226
234227@to_new_array
235228def chi_squared_map_with_mask_from (
@@ -251,13 +244,10 @@ def chi_squared_map_with_mask_from(
251244 mask
252245 The mask applied to the residual-map, where `False` entries are included in the calculation.
253246 """
254- return np .square (
255- np .divide (
256- residual_map ,
257- noise_map ,
258- out = np .zeros_like (residual_map ),
259- where = np .asarray (mask ) == 0 ,
260- )
247+ return jnp .where (
248+ jnp .asarray (mask ) == 0 ,
249+ jnp .square (residual_map / noise_map ),
250+ 0
261251 )
262252
263253
@@ -275,7 +265,7 @@ def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask) -> f
275265 mask
276266 The mask applied to the chi-squared-map, where `False` entries are included in the calculation.
277267 """
278- return float (np .sum (chi_squared_map [np .asarray (mask ) == 0 ]))
268+ return float (jnp .sum (chi_squared_map [jnp .asarray (mask ) == 0 ]))
279269
280270
281271def chi_squared_with_mask_fast_from (
@@ -302,14 +292,14 @@ def chi_squared_with_mask_fast_from(
302292 The mask applied to the chi-squared-map, where `False` entries are included in the calculation.
303293 """
304294 return float (
305- np .sum (
306- np .square (
307- np .divide (
308- np .subtract (
295+ jnp .sum (
296+ jnp .square (
297+ jnp .divide (
298+ jnp .subtract (
309299 data ,
310300 model_data ,
311- )[np .asarray (mask ) == 0 ],
312- noise_map [np .asarray (mask ) == 0 ],
301+ )[jnp .asarray (mask ) == 0 ],
302+ noise_map [jnp .asarray (mask ) == 0 ],
313303 )
314304 )
315305 )
@@ -331,11 +321,11 @@ def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask) ->
331321 mask
332322 The mask applied to the noise-map, where `False` entries are included in the calculation.
333323 """
334- return float (np .sum (np .log (2 * np .pi * noise_map [np .asarray (mask ) == 0 ] ** 2.0 )))
324+ return float (jnp .sum (jnp .log (2 * jnp .pi * noise_map [jnp .asarray (mask ) == 0 ] ** 2.0 )))
335325
336326
337327def chi_squared_with_noise_covariance_from (
338- * , residual_map : ty .DataLike , noise_covariance_matrix_inv : np .ndarray
328+ * , residual_map : ty .DataLike , noise_covariance_matrix_inv : jnp .ndarray
339329) -> float :
340330 """
341331 Returns the chi-squared value of the fit of model-data to a masked dataset, where
@@ -351,7 +341,7 @@ def chi_squared_with_noise_covariance_from(
351341 The inverse of the noise covariance matrix.
352342 """
353343
354- return residual_map @ noise_covariance_matrix_inv @ residual_map
344+ return residual_map . array @ noise_covariance_matrix_inv @ residual_map . array
355345
356346
357347def log_likelihood_from (* , chi_squared : float , noise_normalization : float ) -> float :
@@ -431,8 +421,8 @@ def log_evidence_from(
431421
432422
433423def residual_flux_fraction_map_from (
434- * , residual_map : np .ndarray , data : np .ndarray
435- ) -> np .ndarray :
424+ * , residual_map : jnp .ndarray , data : jnp .ndarray
425+ ) -> jnp .ndarray :
436426 """
437427 Returns the residual flux fraction map of the fit of model-data to a masked dataset, where:
438428
@@ -445,12 +435,12 @@ def residual_flux_fraction_map_from(
445435 data
446436 The data of the dataset.
447437 """
448- return np . divide ( residual_map , data , out = np . zeros_like ( residual_map ) )
438+ return jnp . where ( data != 0 , residual_map / data , 0 )
449439
450440
451441def residual_flux_fraction_map_with_mask_from (
452- * , residual_map : np .ndarray , data : np .ndarray , mask : Mask
453- ) -> np .ndarray :
442+ * , residual_map : jnp .ndarray , data : jnp .ndarray , mask : Mask
443+ ) -> jnp .ndarray :
454444 """
455445 Returnss the residual flux fraction map of the fit of model-data to a masked dataset, where:
456446
@@ -467,9 +457,4 @@ def residual_flux_fraction_map_with_mask_from(
467457 mask
468458 The mask applied to the residual-map, where `False` entries are included in the calculation.
469459 """
470- return np .divide (
471- residual_map ,
472- data ,
473- out = np .zeros_like (residual_map ),
474- where = np .asarray (mask ) == 0 ,
475- )
460+ return jnp .where (mask == 0 , residual_map / data , 0 )
0 commit comments