fix: NaN gradients from jaxnnls backward pass in NNLS reconstruction #279

Merged
Jammy2211 merged 1 commit into main from feature/nnls-gradient-nan-fix
Apr 14, 2026

Conversation

@Jammy2211
Collaborator

Summary

Fix NaN gradients returned by jax.value_and_grad through
reconstruction_positive_only_from whenever the curvature matrix fed to
jaxnnls.solve_nnls_primal is ill-conditioned (e.g. 40x40 MGE problems with
cond ~ 6.7e10). The jaxnnls forward solver hits its 50-iteration cap without
converging, the relaxed-KKT backward pass diverges to NaN from the
non-converged seed, and the whole downstream gradient pipeline is poisoned.
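The "poisoned pipeline" is ordinary NaN propagation through reverse-mode AD: gradients are built by chaining vector-Jacobian products, so once one VJP (here, the diverged relaxed-KKT backward solve) emits NaN, every gradient upstream of it is NaN. A minimal stand-in with plain NumPy matrix products (not the actual jaxnnls VJP):

```python
import numpy as np

# Reverse-mode AD chains VJPs: grad = J1^T @ (J2^T @ dL). If any one VJP
# in the chain returns NaN cotangents, everything upstream becomes NaN.
cotangent = np.ones(3)                      # dL/d(output)
vjp_downstream = np.array([[1.0, 0.0, 2.0],
                           [0.5, 1.0, 0.0],
                           [0.0, 3.0, 1.0]])  # a healthy, finite Jacobian
bad_vjp = np.full((3, 3), np.nan)           # diverged backward NNLS solve

grad = vjp_downstream.T @ (bad_vjp.T @ cotangent)
print(np.isnan(grad).all())                 # True: all gradients are NaN
```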

Jacobi (diagonal) preconditioning inside the JAX branch rescales Q to unit
diagonal (D = diag(Q)^{-1/2}; solve (D Q D) y = D q; recover x = D y). Because
D is diagonal and positive, y >= 0 implies x >= 0, and the primal solution is
mathematically equivalent to the raw solve. Empirically the condition number
drops by ~4 orders of magnitude, the forward solve converges in ~19 iterations,
the relaxed solve in ~21, and the backward pass yields finite gradients.
Forward NNLS is also ~2x faster as a side effect.
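The change of variables can be illustrated with plain NumPy. This is a sketch, with `np.linalg.solve` standing in for `jaxnnls.solve_nnls_primal`; the rescaling algebra and the recovery `x = D y` are the same either way:

```python
import numpy as np

# Symmetric positive-definite Q with a badly scaled diagonal, mimicking
# an ill-conditioned curvature matrix.
rng = np.random.default_rng(0)
A = rng.normal(size=(8, 8))
Q0 = A @ A.T + 0.1 * np.eye(8)
s = np.logspace(-2, 2, 8)              # diagonal scales spanning 4 decades
Q = s[:, None] * Q0 * s[None, :]       # symmetric rescaling keeps Q SPD
x_true = rng.uniform(0.1, 1.0, size=8)
q = Q @ x_true

# Jacobi preconditioning: D = diag(Q)^{-1/2}, solve (D Q D) y = D q.
d = 1.0 / np.sqrt(np.diag(Q))          # D stored as a vector
Q_pre = d[:, None] * Q * d[None, :]    # D Q D has unit diagonal
y = np.linalg.solve(Q_pre, d * q)      # stand-in for the NNLS solve
x = d * y                              # recover x = D y

print(np.linalg.cond(Q) > np.linalg.cond(Q_pre))  # conditioning improves
print(np.allclose(x, x_true, rtol=1e-5))          # same primal solution
```

Because D > 0, the sign of each component is preserved by the change of variables, which is why the non-negativity constraint transfers cleanly between the two problems.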

Behaviour is gated by a new config key
general.yaml::inversion.nnls_jacobi_preconditioning (default true). The
lookup falls back to True if the key is missing so workspace configs that
shadow autoarray/config/general.yaml continue to work without change.
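The fall-back behaviour can be sketched roughly as follows. This is a hypothetical dict-based stand-in, not the real lookup, which goes through autoarray's config machinery:

```python
def nnls_jacobi_preconditioning_enabled(general_config: dict) -> bool:
    """Read inversion.nnls_jacobi_preconditioning, defaulting to True.

    Hypothetical sketch: a missing 'inversion' section or a missing key
    both fall back to True, so shadowed general.yaml files keep working.
    """
    try:
        return bool(general_config["inversion"]["nnls_jacobi_preconditioning"])
    except KeyError:
        return True  # default on when the key is absent
```

For example, a workspace config without the key gets `True`, while an explicit `nnls_jacobi_preconditioning: false` restores the raw solve.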

Fixes #278.

API Changes

None — internal change to the JAX branch of reconstruction_positive_only_from.
A new general.yaml config key is added (default on). No public Python symbols
changed. See full details below.

Test Plan

  • pytest test_autoarray/ — 732 passed (includes 2 new regression tests).
  • Full MGE gradient script (mge_gradients.py) — all 9 steps including
    the full Fitness pipeline now produce finite gradients (was 4 FAILs).
  • Likelihood invariance sweep (nnls_invariance.py, MGE + Delaunay, 5
    parameter vectors each) — max rel diff ~1e-15 vs pre-fix baseline.
  • Benchmark (nnls_precondition_bench.py) — forward solve 2.18x faster,
    vmap 2.38x faster, value_and_grad 1.37x faster.

Full API Changes (for automation & release notes)

Added

  • general.yaml::inversion.nnls_jacobi_preconditioning (bool, default true) —
    toggles Jacobi preconditioning of Q inside the JAX branch of
    reconstruction_positive_only_from. Not exposed as a Python argument; set
    to false in a project's config/general.yaml to restore the raw solve.

Changed Behaviour

  • autoarray.util.inversion.reconstruction_positive_only_from (JAX path
    only) — by default now preconditions the curvature matrix before calling
    jaxnnls.solve_nnls_primal. Primal solution is equivalent to the raw
    solve to ~machine epsilon on well-conditioned problems; produces finite
    rather than NaN gradients on ill-conditioned ones. NumPy path unchanged.

Migration

  • None required. Existing callers get correct gradients by default. To
    opt out, set inversion.nnls_jacobi_preconditioning: false in the
    relevant config/general.yaml.

🤖 Generated with Claude Code

The curvature matrix passed into `jaxnnls.solve_nnls_primal` from
`reconstruction_positive_only_from` is severely ill-conditioned for
typical MGE / linear-light-profile problems (cond ~ 6.7e10 on a 40x40
Q). This causes:

  - forward NNLS to hit its 50-iteration cap without converging,
  - the relaxed-KKT backward solver (custom_vjp) to diverge to NaN
    from a non-converged seed,
  - `jax.value_and_grad` to return all-NaN gradients through the whole
    downstream pipeline.

Fix: Jacobi (diagonal) preconditioning inside the JAX branch. Rescale
Q so its diagonal is unit via D = diag(Q)^{-1/2}, solve
`(D Q D) y = D q` with `y >= 0`, recover `x = D y`. D is diagonal and
positive so non-negativity is preserved, and the primal solution is
mathematically equivalent to the raw solve. Empirically cond drops
~4 orders of magnitude (6.7e10 -> 1.1e7), forward converges in ~19
iters, relaxed converges in ~21 iters, and grad norm is finite
(~6.8e4). Forward NNLS also runs ~2x faster.

Gated by a new `inversion.nnls_jacobi_preconditioning` key in
`autoarray/config/general.yaml`, default True. Falls back to True if
the key is missing so workspace configs that shadow ours do not break.

Adds two regression tests to `test_inversion_util.py` covering the
ill-conditioned-gradient case and primal equivalence with the raw
solve on a well-conditioned problem.
@Jammy2211 Jammy2211 merged commit 8e6848d into main Apr 14, 2026
4 checks passed
@Jammy2211 Jammy2211 deleted the feature/nnls-gradient-nan-fix branch April 14, 2026 21:15