Commit b214413
Fix NaN gradients from jaxnnls backward pass via Jacobi preconditioning
The curvature matrix passed into `jaxnnls.solve_nnls_primal` from
`reconstruction_positive_only_from` is severely ill-conditioned for
typical MGE / linear-light-profile problems (cond ~ 6.7e10 on a 40x40
Q). This causes:
- forward NNLS to hit its 50-iteration cap without converging,
- the relaxed-KKT backward solver (custom_vjp) to diverge to NaN
from a non-converged seed,
- `jax.value_and_grad` to return all-NaN gradients through the whole
downstream pipeline.
Fix: Jacobi (diagonal) preconditioning inside the JAX branch. Rescale
Q so its diagonal is unit via D = diag(Q)^{-1/2}, solve
`(D Q D) y = D q` with `y >= 0`, recover `x = D y`. D is diagonal and
positive so non-negativity is preserved, and the primal solution is
mathematically equivalent to the raw solve. Empirically cond drops
~4 orders of magnitude (6.7e10 -> 1.1e7), forward converges in ~19
iters, relaxed converges in ~21 iters, and grad norm is finite
(~6.8e4). Forward NNLS also runs ~2x faster.
Gated by a new `inversion.nnls_jacobi_preconditioning` key in
`autoarray/config/general.yaml`, default True. Falls back to True if
the key is missing so workspace configs that shadow ours do not break.
Adds two regression tests to `test_inversion_util.py` covering the
ill-conditioned-gradient case and primal equivalence with the raw
solve on a well-conditioned problem.1 parent d74a6d3 commit b214413
3 files changed
Lines changed: 95 additions & 0 deletions
File tree
- autoarray
- config
- inversion/inversion
- test_autoarray/inversion/inversion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
| 9 | + | |
9 | 10 | | |
10 | 11 | | |
11 | 12 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
275 | 275 | | |
276 | 276 | | |
277 | 277 | | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
278 | 301 | | |
279 | 302 | | |
280 | 303 | | |
| |||
Lines changed: 71 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
228 | 228 | | |
229 | 229 | | |
230 | 230 | | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
0 commit comments