Skip to content

Commit 72af86b

Browse files
committed
fix all fit tests
1 parent 083ed0b commit 72af86b

2 files changed

Lines changed: 135 additions & 151 deletions

File tree

autoarray/fit/fit_util.py

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from functools import wraps
2-
import jax.numpy as np
2+
import jax.numpy as jnp
3+
import numpy as np
34

45
from autoarray.mask.abstract_mask import Mask
56

@@ -83,7 +84,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float:
8384
chi_squared_map
8485
The chi-squared-map of values of the model-data fit to the dataset.
8586
"""
86-
return np.sum(chi_squared_map._array)
87+
return jnp.sum(np.array(chi_squared_map))
8788

8889

8990
def noise_normalization_from(*, noise_map: ty.DataLike) -> float:
@@ -97,12 +98,12 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float:
9798
noise_map
9899
The masked noise-map of the dataset.
99100
"""
100-
return np.sum(np.log(2 * np.pi * noise_map._array**2.0))
101+
return jnp.sum(jnp.log(2 * jnp.pi * np.array(noise_map)**2.0))
101102

102103

103104
def normalized_residual_map_complex_from(
104-
*, residual_map: np.ndarray, noise_map: np.ndarray
105-
) -> np.ndarray:
105+
*, residual_map: jnp.ndarray, noise_map: jnp.ndarray
106+
) -> jnp.ndarray:
106107
"""
107108
Returns the normalized residual-map of the fit of complex model-data to a dataset, where:
108109
@@ -126,8 +127,8 @@ def normalized_residual_map_complex_from(
126127

127128

128129
def chi_squared_map_complex_from(
129-
*, residual_map: np.ndarray, noise_map: np.ndarray
130-
) -> np.ndarray:
130+
*, residual_map: jnp.ndarray, noise_map: jnp.ndarray
131+
) -> jnp.ndarray:
131132
"""
132133
Returnss the chi-squared-map of the fit of complex model-data to a dataset, where:
133134
@@ -145,7 +146,7 @@ def chi_squared_map_complex_from(
145146
return chi_squared_map_real + 1j * chi_squared_map_imag
146147

147148

148-
def chi_squared_complex_from(*, chi_squared_map: np.ndarray) -> float:
149+
def chi_squared_complex_from(*, chi_squared_map: jnp.ndarray) -> float:
149150
"""
150151
Returns the chi-squared terms of each complex model data's fit to a masked dataset, by summing the masked
151152
chi-squared-map of the fit.
@@ -157,12 +158,12 @@ def chi_squared_complex_from(*, chi_squared_map: np.ndarray) -> float:
157158
chi_squared_map
158159
The chi-squared-map of values of the model-data fit to the dataset.
159160
"""
160-
chi_squared_real = np.sum(chi_squared_map.real)
161-
chi_squared_imag = np.sum(chi_squared_map.imag)
161+
chi_squared_real = jnp.sum(chi_squared_map.real)
162+
chi_squared_imag = jnp.sum(chi_squared_map.imag)
162163
return chi_squared_real + chi_squared_imag
163164

164165

165-
def noise_normalization_complex_from(*, noise_map: np.ndarray) -> float:
166+
def noise_normalization_complex_from(*, noise_map: jnp.ndarray) -> float:
166167
"""
167168
Returns the noise-map normalization terms of a complex noise-map, summing the noise_map value in every pixel as:
168169
@@ -173,8 +174,8 @@ def noise_normalization_complex_from(*, noise_map: np.ndarray) -> float:
173174
noise_map
174175
The masked noise-map of the dataset.
175176
"""
176-
noise_normalization_real = np.sum(np.log(2 * np.pi * noise_map.real**2.0))
177-
noise_normalization_imag = np.sum(np.log(2 * np.pi * noise_map.imag**2.0))
177+
noise_normalization_real = jnp.sum(jnp.log(2 * jnp.pi * noise_map.real**2.0))
178+
noise_normalization_imag = jnp.sum(jnp.log(2 * jnp.pi * noise_map.imag**2.0))
178179
return noise_normalization_real + noise_normalization_imag
179180

180181

@@ -198,9 +199,7 @@ def residual_map_with_mask_from(
198199
model_data
199200
The model data used to fit the data.
200201
"""
201-
return np.subtract(
202-
data, model_data, out=np.zeros_like(data), where=np.asarray(mask) == 0
203-
)
202+
return jnp.where(jnp.asarray(mask) == 0, jnp.subtract(data, model_data), 0)
204203

205204

206205
@to_new_array
@@ -223,13 +222,7 @@ def normalized_residual_map_with_mask_from(
223222
mask
224223
The mask applied to the residual-map, where `False` entries are included in the calculation.
225224
"""
226-
return np.divide(
227-
residual_map,
228-
noise_map,
229-
out=np.zeros_like(residual_map),
230-
where=np.asarray(mask) == 0,
231-
)
232-
225+
return jnp.where(jnp.asarray(mask) == 0, jnp.divide(residual_map, noise_map), 0)
233226

234227
@to_new_array
235228
def chi_squared_map_with_mask_from(
@@ -251,13 +244,10 @@ def chi_squared_map_with_mask_from(
251244
mask
252245
The mask applied to the residual-map, where `False` entries are included in the calculation.
253246
"""
254-
return np.square(
255-
np.divide(
256-
residual_map,
257-
noise_map,
258-
out=np.zeros_like(residual_map),
259-
where=np.asarray(mask) == 0,
260-
)
247+
return jnp.where(
248+
jnp.asarray(mask) == 0,
249+
jnp.square(residual_map / noise_map),
250+
0
261251
)
262252

263253

@@ -275,7 +265,7 @@ def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask) -> f
275265
mask
276266
The mask applied to the chi-squared-map, where `False` entries are included in the calculation.
277267
"""
278-
return float(np.sum(chi_squared_map[np.asarray(mask) == 0]))
268+
return float(jnp.sum(chi_squared_map[jnp.asarray(mask) == 0]))
279269

280270

281271
def chi_squared_with_mask_fast_from(
@@ -302,14 +292,14 @@ def chi_squared_with_mask_fast_from(
302292
The mask applied to the chi-squared-map, where `False` entries are included in the calculation.
303293
"""
304294
return float(
305-
np.sum(
306-
np.square(
307-
np.divide(
308-
np.subtract(
295+
jnp.sum(
296+
jnp.square(
297+
jnp.divide(
298+
jnp.subtract(
309299
data,
310300
model_data,
311-
)[np.asarray(mask) == 0],
312-
noise_map[np.asarray(mask) == 0],
301+
)[jnp.asarray(mask) == 0],
302+
noise_map[jnp.asarray(mask) == 0],
313303
)
314304
)
315305
)
@@ -331,11 +321,11 @@ def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask) ->
331321
mask
332322
The mask applied to the noise-map, where `False` entries are included in the calculation.
333323
"""
334-
return float(np.sum(np.log(2 * np.pi * noise_map[np.asarray(mask) == 0] ** 2.0)))
324+
return float(jnp.sum(jnp.log(2 * jnp.pi * noise_map[jnp.asarray(mask) == 0] ** 2.0)))
335325

336326

337327
def chi_squared_with_noise_covariance_from(
338-
*, residual_map: ty.DataLike, noise_covariance_matrix_inv: np.ndarray
328+
*, residual_map: ty.DataLike, noise_covariance_matrix_inv: jnp.ndarray
339329
) -> float:
340330
"""
341331
Returns the chi-squared value of the fit of model-data to a masked dataset, where
@@ -351,7 +341,7 @@ def chi_squared_with_noise_covariance_from(
351341
The inverse of the noise covariance matrix.
352342
"""
353343

354-
return residual_map @ noise_covariance_matrix_inv @ residual_map
344+
return residual_map.array @ noise_covariance_matrix_inv @ residual_map.array
355345

356346

357347
def log_likelihood_from(*, chi_squared: float, noise_normalization: float) -> float:
@@ -431,8 +421,8 @@ def log_evidence_from(
431421

432422

433423
def residual_flux_fraction_map_from(
434-
*, residual_map: np.ndarray, data: np.ndarray
435-
) -> np.ndarray:
424+
*, residual_map: jnp.ndarray, data: jnp.ndarray
425+
) -> jnp.ndarray:
436426
"""
437427
Returns the residual flux fraction map of the fit of model-data to a masked dataset, where:
438428
@@ -445,12 +435,12 @@ def residual_flux_fraction_map_from(
445435
data
446436
The data of the dataset.
447437
"""
448-
return np.divide(residual_map, data, out=np.zeros_like(residual_map))
438+
return jnp.where(data != 0, residual_map / data, 0)
449439

450440

451441
def residual_flux_fraction_map_with_mask_from(
452-
*, residual_map: np.ndarray, data: np.ndarray, mask: Mask
453-
) -> np.ndarray:
442+
*, residual_map: jnp.ndarray, data: jnp.ndarray, mask: Mask
443+
) -> jnp.ndarray:
454444
"""
455445
Returnss the residual flux fraction map of the fit of model-data to a masked dataset, where:
456446
@@ -467,9 +457,4 @@ def residual_flux_fraction_map_with_mask_from(
467457
mask
468458
The mask applied to the residual-map, where `False` entries are included in the calculation.
469459
"""
470-
return np.divide(
471-
residual_map,
472-
data,
473-
out=np.zeros_like(residual_map),
474-
where=np.asarray(mask) == 0,
475-
)
460+
return jnp.where(mask == 0, residual_map / data, 0)

0 commit comments

Comments
 (0)