Fix exploding JAX gradient in rectangular interpolator (duplicate sort-knots) #281

Open
Jammy2211 wants to merge 2 commits into main from feature/rectangular-interp-grad-fix

Conversation

@Jammy2211
Collaborator

Summary

Fixes an ~O(1e24) gradient blow-up inside autoarray.inversion.mesh.interpolator.rectangular.create_transforms that silently poisoned the entire JAX pixelization likelihood gradient. Downstream of this function, FitImaging with any rectangular pixelization mesh (RectangularUniform, RectangularAdaptDensity) produced effectively-zero or NaN gradients through the mapping matrix, making the pixelization path unusable for gradient-based optimisation / HMC.

After the patch, the mapping matrix, data vector, and curvature matrix all return finite gradients that agree with finite differences up to expected bin-boundary O(1) noise.

What was going wrong — detailed narrative for the reviewer

The rectangular interpolator converts each ray-traced source-plane point into a rank-space coordinate via

sort_points = jnp.sort(traced_points, axis=0)
t           = jnp.arange(1, N+1) / (N+1)
transform(q) = jnp.interp(q, sort_points, t)   # per column, via vmap

and then uses the output to index into a regular source_grid_size × source_grid_size mesh with bilinear weights (floor/ceil + sub-bin t_row/t_col).
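The transform above can be sketched in a few lines of JAX. This is a minimal illustration, not the code in create_transforms: the function name rank_transform and the toy input grid are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def rank_transform(traced_points, query):
    """Map query points (M, 2) to rank-space coordinates, one column at a time."""
    n = traced_points.shape[0]
    sort_points = jnp.sort(traced_points, axis=0)  # each column sorted independently
    t = jnp.arange(1, n + 1) / (n + 1)             # rank targets in (0, 1)
    # vmap over the two coordinate columns; t is shared via closure.
    per_column = jax.vmap(
        lambda q, knots: jnp.interp(q, knots, t), in_axes=(1, 1), out_axes=1
    )
    return per_column(query, sort_points)


# Toy grid (hypothetical): column 0 is descending, column 1 ascending.
points = jnp.stack([jnp.linspace(1.0, -1.0, 5), jnp.linspace(0.0, 2.0, 5)], axis=1)
ranks = rank_transform(points, points)  # each point maps to its own rank
```

Feeding the knots back in as queries returns each point's own rank value k / (N + 1), which is a quick sanity check that the per-column sort and the vmap axes line up.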

The bug. Ray-traced source-plane grids from realistic lens models contain massive numbers of duplicate coordinates. For an Isothermal lens over a circular mask of radius 3.5 (the standard HST test setup), the probe in autolens_workspace_developer/jax_profiling/imaging/mapper_grad_isolate.py measured:

data_grid N                 = 15361
y sort gap min / median     = 0.000e+00 / 6.663e-09
x sort gap min / median     = 0.000e+00 / 6.670e-09
# y gaps <= 1e-12           = 7680   (~50% of all sort-adjacent pairs)
# x gaps <= 1e-12           = 7680
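Statistics of this kind could be gathered with a helper along the following lines. This is a sketch only: sort_gap_stats and the toy coordinates are hypothetical, and the actual probe script may compute them differently.

```python
import jax.numpy as jnp


def sort_gap_stats(coords, tol=1e-12):
    """Return (min gap, median gap, number of gaps <= tol) between sort-adjacent values."""
    gaps = jnp.diff(jnp.sort(coords))
    return float(gaps.min()), float(jnp.median(gaps)), int(jnp.sum(gaps <= tol))


# Toy 1D coordinate column with two exact duplicates (hypothetical data).
coords = jnp.array([0.3, 0.1, 0.1, 0.7, 0.3])
gap_min, gap_median, n_zero = sort_gap_stats(coords)
```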

Half of all adjacent pairs in the sorted source grid are at exactly zero gap (floating-point zero — these are genuinely coincident rays, not rounding noise). This breaks jnp.interp's vjp in two distinct ways, which is the crucial subtlety:

  1. Knot-gradient term. The vjp of jnp.interp w.r.t. its knot array xp contains a division by xp[i+1] - xp[i]. With 7680 exact-zero gaps this is literally 0/0, which JAX propagates as O(1e24) cotangents back through sort_points → source_plane_data_grid → deflections → lens model parameters.

  2. Query-gradient term. The vjp w.r.t. the query x is the local slope, (yp[i+1] - yp[i]) / (xp[i+1] - xp[i]). Even if the knots themselves were frozen, any query landing in a zero-gap bin sees slope = Δt / 0. An isolated one-bin duplicate doesn't trigger this in practice (JAX's indexing dodges it), but long runs of consecutive duplicates, which is exactly what caustics produce, do.
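Both terms can be seen on a four-knot toy problem. The sketch below uses made-up knot and rank values and a nonzero gap, so it shows only how the two gradients scale as the gap shrinks (roughly as 1/gap), not the exact-zero case from the probe.

```python
import jax
import jax.numpy as jnp

FP = jnp.array([0.2, 0.4, 0.6, 0.8])  # toy rank values (hypothetical)


def knots_with_gap(gap):
    # Knots 1.0 and 1.0 + gap bracket the bin under test.
    return jnp.array([0.0, 1.0, 1.0 + gap, 2.0])


def query_slope(gap):
    """d interp / d query at the middle of the narrow bin."""
    xp = knots_with_gap(gap)
    return jax.grad(lambda q: jnp.interp(q, xp, FP))(1.0 + 0.5 * gap)


def knot_sensitivity(gap):
    """d interp / d knots at the same query; only the two bracketing knots are nonzero."""
    q = 1.0 + 0.5 * gap
    return jax.grad(lambda xp: jnp.interp(q, xp, FP))(knots_with_gap(gap))


slope_wide = query_slope(1e-1)    # roughly 0.2 / 0.1
slope_narrow = query_slope(1e-3)  # roughly 0.2 / 0.001
knot_grad = knot_sensitivity(1e-1)
```

Shrinking the gap by a factor of 100 grows the query slope by the same factor, which is the scaling that long runs of near-coincident knots amplify.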

Stage-by-stage gradient vs finite-difference before the fix (from mapper_grad_isolate.py):

stage                  JAX grad     FD grad      ratio
source_grid_scaled      2.04e+00    2.04e+00     1.00      # OK
sort_points             2.04e+00    2.04e+00     1.00      # OK
grid_over_scaled        2.57e-01    2.57e-01     1.00      # OK
grid_over_transformed  -6.62e+24   -7.82e+02     8.5e+21   # <-- blow-up starts here
grid_over_index        -4.47e+27   -5.20e+05     8.6e+21
t_row_t_col            -1.99e+26    1.34e+05     1.5e+21
weights                -3.97e+25    6.25e+02     6.4e+22

The blow-up is injected exactly at the jnp.interp call and then propagated by linearity through every downstream operation, including the final inversion.operated_mapping_matrix, D, F, NNLS, and log_evidence.
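A stage-by-stage comparison like the one tabulated above boils down to evaluating an autodiff gradient and a finite difference of the same scalar function. The sketch below is a generic helper, not the probe script itself: grad_vs_fd and smooth_stage are hypothetical names, and the step h=1e-3 is chosen for JAX's default float32 rather than taken from the PR.

```python
import jax
import jax.numpy as jnp


def grad_vs_fd(f, x, h=1e-3):
    """Return (autodiff gradient, central finite difference, their ratio) for scalar x."""
    g = jax.grad(f)(x)
    fd = (f(x + h) - f(x - h)) / (2 * h)
    return g, fd, g / fd


def smooth_stage(x):
    # Hypothetical smooth stand-in for one pipeline stage reduced to a scalar.
    return jnp.sum(jnp.sin(x * jnp.arange(1.0, 4.0)) ** 2)


g, fd, ratio = grad_vs_fd(smooth_stage, 0.3)
```

For a smooth stage the ratio sits near 1; a ratio many orders of magnitude away from 1 localises the stage where the gradient goes wrong.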

The fix

Inside the xp.__name__.startswith("jax") branch of create_transforms:

JITTER = 1e-7
jitter = xp.arange(N, dtype=sort_points.dtype) * JITTER
jitter = xp.stack([jitter, jitter], axis=1)
sort_points = jax.lax.stop_gradient(sort_points + jitter)

Two patches, each addressing one of the two failure modes above:

  • stop_gradient(sort_points) kills the knot-gradient term (failure mode 1). This is semantically correct, not a hack: the only consumer of this interpolator's output is adaptive_rectangular_mappings_weights_via_interpolation_from, which does floor(...) / ceil(...) on the transformed coordinate to pick the 4 corner pixels. That bin-assignment already has zero gradient. The knot-gradient term we drop is precisely the derivative of which-bin-is-selected with respect to the knot positions — a derivative that has no downstream consumer. The smooth sub-bin t_row / t_col contribution, which is the only thing that really feeds into the bilinear weights, flows through the query-point argument of jnp.interp entirely unchanged.

  • Monotonic jitter arange(N) * 1e-7 prevents the query-gradient term (failure mode 2) from seeing a zero knot gap. Because the sorted gaps are non-negative and each consecutive knot receives an extra 1e-7, the minimum gap after jitter is at least 1e-7. The worst-case slope through any bin is therefore bounded by (Δt_max) / JITTER = (1/(N+1)) / 1e-7 ≈ 650 for N ≈ 1.5e4, which is large but finite and numerically harmless. Each knot is shifted by at most (N-1) * JITTER ≈ 1.5e-3 in scaled source-plane units, well below the (source_grid_size - 3) downstream multiplier's sub-pixel sensitivity, so the mapping matrix itself is essentially unchanged in value.
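The two patches can be checked in isolation on a toy knot array. The sketch below mirrors the four-line fix but is not the patched function: stabilise_knots and the sample data are hypothetical, and float64 is enabled as an assumption so that 1e-7 steps are resolved exactly enough to test.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # assumption: float64, so 1e-7 steps are resolved

JITTER = 1e-7


def stabilise_knots(sort_points):
    """Add a strictly increasing offset per row, then block gradients to the knots."""
    n = sort_points.shape[0]
    jitter = jnp.arange(n, dtype=sort_points.dtype) * JITTER
    jitter = jnp.stack([jitter, jitter], axis=1)
    return jax.lax.stop_gradient(sort_points + jitter)


# Toy sorted knots with runs of exact duplicates in both columns (hypothetical data).
sort_points = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.5], [1.0, 0.5], [1.0, 2.0]])
n = sort_points.shape[0]
t = jnp.arange(1, n + 1) / (n + 1)

knots = stabilise_knots(sort_points)
min_gap = jnp.diff(knots, axis=0).min()

# Query gradient inside the first (formerly zero-width) bin of column 0.
slope = jax.grad(lambda q: jnp.interp(q, knots[:, 0], t))(0.5 * JITTER)
slope_bound = (1.0 / (n + 1)) / JITTER

# Gradient with respect to the raw knots is zero because of stop_gradient.
knot_grad = jax.grad(lambda sp: jnp.interp(0.5, stabilise_knots(sp)[:, 0], t))(sort_points)

# The same bound evaluated for the N quoted in the PR text.
pr_bound = (1.0 / (15361 + 1)) / JITTER
```

On this toy input the query slope in a formerly degenerate bin equals the bound Δt / JITTER, and the knot gradient is identically zero, matching the two bullet points above.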

Only the JAX branch is touched. The NumPy path (forward_interp_np) is untouched — it doesn't suffer from this bug because NumPy has no autodiff.

Verification

After the patch, the same probe produces finite, FD-agreeing gradients at every stage:

stage                  JAX grad     FD grad      ratio
grid_over_transformed  -1.93e+00   -2.08e+00     0.93
grid_over_index        -8.52e+02   -6.21e+02     1.37
t_row_t_col            -8.41e+01   -1.75e+02     0.48
weights                -2.80e+02   -5.49e+02     0.51
weights_util           -2.80e+02   -5.49e+02     0.51

The remaining ~0.5–1.4 JAX/FD ratio at downstream stages is expected and not a correctness concern: the sum-of-squares loss used in the probe is piecewise smooth, with jumps at every bin-boundary crossing, so finite differences with h=1e-5 average across crossings that JAX's autodiff correctly reports as zero. This is intrinsic to the bilinear interpolator's discrete bin structure and has nothing to do with the patch.
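The disagreement between autodiff and finite differences at bin boundaries can be reproduced with a one-dimensional toy loss. The function below is a hypothetical stand-in, not the probe's sum-of-squares loss: it combines a floor-based bin index with a smooth sub-bin offset, and the step h=1e-2 is picked so the effect is visible in float32.

```python
import jax
import jax.numpy as jnp

SCALE = 10.0  # hypothetical mesh scale


def toy_loss(x):
    """Bin index squared plus sub-bin offset: jumps at every bin boundary."""
    s = SCALE * x
    idx = jnp.floor(s)  # zero gradient under autodiff
    frac = s - idx      # smooth sub-bin part, slope SCALE
    return idx**2 + frac


def central_fd(f, x, h):
    return (f(x + h) - f(x - h)) / (2 * h)


h = 1e-2
grad_at_boundary = jax.grad(toy_loss)(0.5)     # autodiff sees only the smooth part
fd_at_boundary = central_fd(toy_loss, 0.5, h)  # straddles the jump at x = 0.5
grad_inside = jax.grad(toy_loss)(0.55)
fd_inside = central_fd(toy_loss, 0.55, h)      # stays inside one bin
```

Inside a bin the two estimates agree; when the finite-difference stencil straddles a boundary it picks up the jump divided by 2h, which autodiff never reports.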

pixelization_gradients.py (companion script in autolens_workspace_developer) goes from failing at step 4 to passing steps 4–6 with finite, well-conditioned gradients. Steps 8+ still fail for unrelated NNLS-conditioning reasons.

API Changes

None — internal changes only. No public signatures changed, no added / removed / renamed symbols, no behavioural change on the NumPy path. The JAX-path numerical output of transform(grid) is perturbed by at most ~1.5e-3 per coordinate, which is below the bilinear bin-assignment resolution.

See full details below.

Test Plan

  • python -m pytest test_autoarray/inversion/ — 155 passed
  • python -m pytest test_autoarray/inversion/pixelization/interpolator/ test_autoarray/inversion/pixelization/mappers/ test_autoarray/inversion/pixelization/mesh/ — 21 passed
  • autolens_workspace_developer/jax_profiling/imaging/mapper_grad_isolate.py — previously O(1e21) ratios, now O(1) at every stage
  • autolens_workspace_developer/jax_profiling/imaging/pixelization_gradients.py: steps 4–6 go from failing to PASS

Full API Changes (for automation & release notes)

Removed

  • (none)

Added

  • (none)

Renamed

  • (none)

Changed Signature

  • (none)

Changed Behaviour

  • autoarray.inversion.mesh.interpolator.rectangular.create_transforms: JAX-path transform / inv_transform now internally apply a 1e-7 monotonic jitter + stop_gradient to the sorted knot array. Forward output is perturbed by at most ~1.5e-3 per coordinate in scaled source-plane units. Backward gradient through the knot array is now zero by design. NumPy path is unchanged.

Migration

  • No migration required — the change is internal to the rectangular interpolator and is transparent to every caller.

🤖 Generated with Claude Code

Jammy2211 and others added 2 commits April 14, 2026 23:24
stop_gradient + monotonic jitter on sort_points stabilises
jnp.interp backward through duplicate knots in ray-traced source grids.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Double-where safe-division alternative to the JITTER approach, not
yet wired in; see admin_jammy/prompt/autoarray/rectangular_interp_custom_vjp.md
for the follow-up task that evaluates and switches to it.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
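
For readers unfamiliar with the term, the generic "double-where" safe-division pattern looks roughly like the sketch below. It illustrates the general technique only and is not the code in the second commit: naive_slope, safe_slope, the eps threshold, and the fallback slope of 0 are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def naive_slope(dy, dx):
    # Single where: the forward value is masked, but the backward pass still
    # differentiates dy / dx at dx == 0.
    return jnp.where(dx != 0.0, dy / dx, 0.0)


def safe_slope(dy, dx, eps=1e-12):
    ok = jnp.abs(dx) > eps
    safe_dx = jnp.where(ok, dx, 1.0)         # inner where: denominator is never ~0
    return jnp.where(ok, dy / safe_dx, 0.0)  # outer where: fallback slope of 0


dy = jnp.array([1.0, 1.0])
dx = jnp.array([0.0, 0.5])  # first gap is an exact duplicate

naive_grad = jax.grad(lambda d: naive_slope(dy, d).sum())(dx)
safe_grad = jax.grad(lambda d: safe_slope(dy, d).sum())(dx)
```

The single-where version returns a finite forward value but a NaN gradient for the zero gap, because the masked branch is still differentiated; substituting a harmless denominator before dividing keeps both passes finite.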