Skip to content

Commit 8e6848d

Browse files
Jammy2211 and co-author authored
Fix NaN gradients from jaxnnls backward pass via Jacobi preconditioning (#279)
The curvature matrix passed into `jaxnnls.solve_nnls_primal` from `reconstruction_positive_only_from` is severely ill-conditioned for typical MGE / linear-light-profile problems (cond ~ 6.7e10 on a 40x40 Q). This causes: - forward NNLS to hit its 50-iteration cap without converging, - the relaxed-KKT backward solver (custom_vjp) to diverge to NaN from a non-converged seed, - `jax.value_and_grad` to return all-NaN gradients through the whole downstream pipeline. Fix: Jacobi (diagonal) preconditioning inside the JAX branch. Rescale Q so its diagonal is unit via D = diag(Q)^{-1/2}, solve `(D Q D) y = D q` with `y >= 0`, recover `x = D y`. D is diagonal and positive so non-negativity is preserved, and the primal solution is mathematically equivalent to the raw solve. Empirically cond drops ~4 orders of magnitude (6.7e10 -> 1.1e7), forward converges in ~19 iters, relaxed converges in ~21 iters, and grad norm is finite (~6.8e4). Forward NNLS also runs ~2x faster. Gated by a new `inversion.nnls_jacobi_preconditioning` key in `autoarray/config/general.yaml`, default True. Falls back to True if the key is missing so workspace configs that shadow ours do not break. Adds two regression tests to `test_inversion_util.py` covering the ill-conditioned-gradient case and primal equivalence with the raw solve on a well-conditioned problem. Co-authored-by: Jammy2211 <JNightingale2211@gmail.com>
1 parent d74a6d3 commit 8e6848d

3 files changed

Lines changed: 95 additions & 0 deletions

File tree

autoarray/config/general.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ inversion:
66
use_edge_zeroed_pixels : true # If True, the edge pixels of a pixelization are set to zero, which prevents unphysical values in the reconstructed solution at the edge of the pixelization.
77
no_regularization_add_to_curvature_diag_value : 1.0e-3 # The default value added to the curvature matrix's diagonal when regularization is not applied to a linear object, which prevents inversion's failing due to the matrix being singular.
88
use_border_relocator: false # If True, by default a pixelization's border is used to relocate all pixels outside its border to the border.
9+
nnls_jacobi_preconditioning: true # If True (default), the curvature matrix passed to jaxnnls.solve_nnls_primal is Jacobi-preconditioned (D Q D y = D q, x = D y). Fixes NaN backward-pass gradients on ill-conditioned Q and roughly halves forward solve time. Set False to restore the raw unpreconditioned solve.
910
reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor.
1011
numba:
1112
use_numba: true

autoarray/inversion/inversion/inversion_util.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,29 @@ def reconstruction_positive_only_from(
275275
if xp.__name__.startswith("jax"):
276276

277277
import jaxnnls
278+
from autoconf import conf
279+
280+
try:
281+
use_jacobi = conf.instance["general"]["inversion"][
282+
"nnls_jacobi_preconditioning"
283+
]
284+
except KeyError:
285+
# Workspaces ship their own general.yaml that shadows autoarray's;
286+
# default to True so gradients remain well-defined unless the user
287+
# explicitly disables preconditioning in the shadowing config.
288+
use_jacobi = True
289+
290+
if use_jacobi:
291+
# Ill-conditioned Q makes jaxnnls's relaxed-KKT backward pass
292+
# produce NaN gradients. Rescale Q so its diagonal is unit:
293+
# solve (D Q D) y = D q with y >= 0, recover x = D y. D is
294+
# diagonal positive, so non-negativity is preserved and the
295+
# primal solution is mathematically equivalent.
296+
d = xp.sqrt(xp.diag(curvature_reg_matrix))
297+
D = 1.0 / d
298+
Q_pc = (curvature_reg_matrix * D[:, None]) * D[None, :]
299+
q_pc = data_vector * D
300+
return jaxnnls.solve_nnls_primal(Q_pc, q_pc) * D
278301

279302
return jaxnnls.solve_nnls_primal(curvature_reg_matrix, data_vector)
280303

test_autoarray/inversion/inversion/test_inversion_util.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,74 @@ def test__preconditioner_matrix_via_mapping_matrix_from():
228228
preconditioner_matrix
229229
== np.array([[5.0, 2.0, 3.0], [4.0, 9.0, 6.0], [7.0, 8.0, 13.0]])
230230
).all()
231+
232+
233+
def test__reconstruction_positive_only_from__jax_ill_conditioned_grad_is_finite():
    """
    Regression test: on ill-conditioned curvature matrices the jaxnnls
    backward pass used to return NaN gradients because the relaxed-KKT solver
    diverged. Jacobi preconditioning inside
    `reconstruction_positive_only_from` re-parameterises the NNLS problem so
    the solve converges and `jax.value_and_grad` yields finite gradients.
    Skipped when jax / jaxnnls are not installed.
    """
    jax = pytest.importorskip("jax")
    import jax.numpy as jnp

    pytest.importorskip("jaxnnls")

    # Deliberately ill-conditioned symmetric positive-definite matrix:
    # eigenvalues span seven decades (cond ~ 1e7), enough to break the raw
    # jaxnnls backward pass without preconditioning.
    generator = np.random.default_rng(0)
    size = 10
    basis, _ = np.linalg.qr(generator.standard_normal((size, size)))
    spectrum = np.logspace(-4, 3, size)
    curvature_np = (basis * spectrum) @ basis.T
    curvature_np = 0.5 * (curvature_np + curvature_np.T)  # symmetrise round-off
    vector_np = generator.standard_normal(size)

    curvature = jnp.array(curvature_np)
    vector = jnp.array(vector_np)

    def loss(q_in):
        # Differentiate a scalar summary of the NNLS solution w.r.t. the
        # data vector, exercising the custom_vjp backward pass.
        x = aa.util.inversion.reconstruction_positive_only_from(
            data_vector=q_in, curvature_reg_matrix=curvature, xp=jnp,
        )
        return jnp.sum(x)

    value, grad = jax.value_and_grad(loss)(vector)

    assert np.isfinite(float(value))
    grad_np = np.array(grad)
    assert np.all(np.isfinite(grad_np)), (
        f"gradient has {np.sum(~np.isfinite(grad_np))} non-finite entries"
    )
271+
272+
273+
def test__reconstruction_positive_only_from__jax_matches_unpreconditioned_primal():
    """
    Jacobi preconditioning is a change of coordinates; the forward primal
    solution must match the raw jaxnnls solve to within solver tolerance for
    a moderately-conditioned problem where the raw solver also converges.
    Skipped when jax / jaxnnls are not installed.
    """
    # Skip-only check: the bound name was previously assigned but never used.
    pytest.importorskip("jax")
    import jax.numpy as jnp

    jaxnnls = pytest.importorskip("jaxnnls")

    # Well-conditioned SPD matrix: eigenvalues in [0.5, 5.0], cond = 10,
    # so both the raw and preconditioned solves converge and must agree.
    rng = np.random.default_rng(1)
    n = 8
    U, _ = np.linalg.qr(rng.standard_normal((n, n)))
    eigs = np.linspace(0.5, 5.0, n)  # well-conditioned
    Q_np = (U * eigs) @ U.T
    Q_np = 0.5 * (Q_np + Q_np.T)  # symmetrise against round-off
    q_np = rng.standard_normal(n)

    Q = jnp.array(Q_np)
    q = jnp.array(q_np)

    # Raw (unpreconditioned) reference solve versus the library entry point,
    # which applies Jacobi preconditioning when the config key is enabled.
    x_raw = np.array(jaxnnls.solve_nnls_primal(Q, q))
    x_pc = np.array(
        aa.util.inversion.reconstruction_positive_only_from(
            data_vector=q, curvature_reg_matrix=Q, xp=jnp,
        )
    )

    np.testing.assert_allclose(x_pc, x_raw, rtol=1e-6, atol=1e-8)

0 commit comments

Comments (0)