fix: NaN gradients from jaxnnls backward pass in NNLS reconstruction #279
Merged
Conversation
The curvature matrix passed into `jaxnnls.solve_nnls_primal` from
`reconstruction_positive_only_from` is severely ill-conditioned for
typical MGE / linear-light-profile problems (cond ~ 6.7e10 on a 40x40
Q). This causes:
- forward NNLS to hit its 50-iteration cap without converging,
- the relaxed-KKT backward solver (custom_vjp) to diverge to NaN
from a non-converged seed,
- `jax.value_and_grad` to return all-NaN gradients through the whole
downstream pipeline.
Fix: Jacobi (diagonal) preconditioning inside the JAX branch. Rescale
Q so its diagonal is unit via D = diag(Q)^{-1/2}, solve
`(D Q D) y = D q` with `y >= 0`, recover `x = D y`. D is diagonal and
positive so non-negativity is preserved, and the primal solution is
mathematically equivalent to the raw solve. Empirically cond drops
~4 orders of magnitude (6.7e10 -> 1.1e7), forward converges in ~19
iters, relaxed converges in ~21 iters, and grad norm is finite
(~6.8e4). Forward NNLS also runs ~2x faster.
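The transform can be sanity-checked in isolation. The sketch below is a minimal NumPy illustration, not the autoarray/jaxnnls code: the curvature matrix is synthetic, and a plain linear solve stands in for the non-negative solve (since D > 0 preserves the sign constraint, the rescaling logic is identical).

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic badly-scaled SPD "curvature matrix": a well-conditioned core S
# wrapped in wildly varying per-row/column scales, mimicking the MGE case.
S = rng.normal(size=(40, 40))
S = S @ S.T + 40.0 * np.eye(40)
scales = 10.0 ** rng.uniform(-4.0, 4.0, size=40)
Q = scales[:, None] * S * scales[None, :]
q = Q @ rng.normal(size=40)  # right-hand side with a known consistent solution

# Jacobi preconditioning: D = diag(Q)^{-1/2}, stored as a vector d.
d = 1.0 / np.sqrt(np.diag(Q))
Q_pre = d[:, None] * Q * d[None, :]  # D Q D has unit diagonal
q_pre = d * q                        # D q

# Solve the rescaled system, then map back to the original variables.
y = np.linalg.solve(Q_pre, q_pre)
x = d * y  # recover x = D y
```

Because D is diagonal and positive, `y >= 0` and `x >= 0` are equivalent constraints, which is why the same rescaling carries over to the NNLS solve unchanged.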
Gated by a new `inversion.nnls_jacobi_preconditioning` key in
`autoarray/config/general.yaml`, default True. Falls back to True if
the key is missing so workspace configs that shadow ours do not break.
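The missing-key fallback can be sketched as follows. This uses a plain dict as a stand-in for autoarray's config object, and the helper name is illustrative, not the actual API:

```python
def nnls_jacobi_preconditioning_enabled(config: dict) -> bool:
    """Read inversion.nnls_jacobi_preconditioning, defaulting to True.

    A workspace config that shadows autoarray/config/general.yaml may
    predate this key, so a missing-key lookup must not raise.
    """
    try:
        return bool(config["inversion"]["nnls_jacobi_preconditioning"])
    except KeyError:
        return True  # key absent: keep the preconditioned path on
```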
Adds two regression tests to `test_inversion_util.py` covering the
ill-conditioned-gradient case and primal equivalence with the raw
solve on a well-conditioned problem.
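The primal-equivalence check can be approximated outside JAX. A hedged sketch: `scipy.optimize.nnls` stands in for `jaxnnls.solve_nnls_primal`, the problem is synthetic, and the preconditioning is applied on the normal-equations scale by rescaling the columns of A:

```python
import numpy as np
from scipy.optimize import nnls

rng = np.random.default_rng(1)

# Well-conditioned least-squares problem: min ||A x - b|| subject to x >= 0.
A = rng.normal(size=(60, 20))
b = rng.normal(size=60)

# Raw solve.
x_raw, _ = nnls(A, b)

# Jacobi-preconditioned solve: with Q = A^T A and D = diag(Q)^{-1/2},
# scaling the columns of A by D gives (A D)^T (A D) = D Q D, and
# x = D y maps the solution back to the original variables.
d = 1.0 / np.sqrt(np.diag(A.T @ A))
y, _ = nnls(A * d[None, :], b)
x_pre = d * y
```

On a full-column-rank problem the NNLS minimizer is unique, so the raw and preconditioned solves should agree to roughly machine precision.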
Summary
Fix NaN gradients returned by `jax.value_and_grad` through `reconstruction_positive_only_from` whenever the curvature matrix fed to `jaxnnls.solve_nnls_primal` is ill-conditioned (e.g. 40x40 MGE problems with cond ~ 6.7e10). The jaxnnls forward solver hits its 50-iteration cap without converging, the relaxed-KKT backward pass diverges to NaN from the non-converged seed, and the whole downstream gradient pipeline is poisoned.
Jacobi (diagonal) preconditioning inside the JAX branch rescales Q so its diagonal is unit (D = diag(Q)^{-1/2}, solve `(D Q D) y = D q`, recover `x = D y`). D is diagonal and positive, so `y >= 0` implies `x >= 0`, and the primal solution is mathematically equivalent to the raw solve. Empirically cond drops ~4 orders of magnitude, forward converges in ~19 iters, relaxed converges in ~21 iters, and the backward pass yields finite gradients. Forward NNLS is also ~2x faster as a side effect.
Behaviour is gated by a new config key `general.yaml::inversion.nnls_jacobi_preconditioning` (default `true`). The lookup falls back to `true` if the key is missing, so workspace configs that shadow `autoarray/config/general.yaml` continue to work without change.
Fixes #278.
API Changes
None: internal change to the JAX branch of `reconstruction_positive_only_from`. A new `general.yaml` config key is added (default on). No public Python symbols changed. See full details below.
Test Plan
- `pytest test_autoarray/`: 732 passed (includes 2 new regression tests).
- Gradient check script (`mge_gradients.py`): all 9 steps including the full Fitness pipeline now produce finite gradients (was 4 FAILs).
- Invariance script (`nnls_invariance.py`, MGE + Delaunay, 5 parameter vectors each): max rel diff ~1e-15 vs pre-fix baseline.
- Benchmark (`nnls_precondition_bench.py`): forward solve 2.18x faster, vmap 2.38x faster, value_and_grad 1.37x faster.
Full API Changes (for automation & release notes)
Added
- `general.yaml::inversion.nnls_jacobi_preconditioning` (bool, default `true`): toggles Jacobi preconditioning of Q inside the JAX branch of `reconstruction_positive_only_from`. Not exposed as a Python argument; set to `false` in a project's `config/general.yaml` to restore the raw solve.
Changed Behaviour
- `autoarray.util.inversion.reconstruction_positive_only_from` (JAX path only): by default now preconditions the curvature matrix before calling `jaxnnls.solve_nnls_primal`. Primal solution is equivalent to the raw solve to ~machine epsilon on well-conditioned problems; produces finite rather than NaN gradients on ill-conditioned ones. NumPy path unchanged.
Migration
To opt out, set `inversion.nnls_jacobi_preconditioning: false` in the relevant `config/general.yaml`.
🤖 Generated with Claude Code
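For reference, the opt-out in a shadowing `config/general.yaml` would look like the fragment below (key name taken from this PR; the surrounding file structure is assumed):

```yaml
inversion:
  nnls_jacobi_preconditioning: false
```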