Fix exploding JAX gradient in rectangular interpolator (duplicate sort-knots)#281
Open
Fix exploding JAX gradient in rectangular interpolator (duplicate sort-knots)#281
Conversation
stop_gradient + monotonic jitter on sort_points stabilises jnp.interp backward through duplicate knots in ray-traced source grids. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Double-where safe-division alternative to the JITTER approach — not yet wired in, see admin_jammy/prompt/autoarray/rectangular_interp_custom_vjp.md for the follow-up task that evaluates and switches to it. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes an ~O(1e24) gradient blow-up inside
autoarray.inversion.mesh.interpolator.rectangular.create_transformsthat silently poisoned the entire JAX pixelization likelihood gradient. Downstream of this function,FitImagingwith any rectangular pixelization mesh (RectangularUniform,RectangularAdaptDensity) produced effectively-zero or NaN gradients through the mapping matrix, making the pixelization path unusable for gradient-based optimisation / HMC.After the patch, the mapping matrix, data vector, and curvature matrix all return finite gradients that agree with finite differences up to expected bin-boundary O(1) noise.
What was going wrong — detailed narrative for the reviewer
The rectangular interpolator converts each ray-traced source-plane point into a rank-space coordinate via
and then uses the output to index into a regular
source_grid_size × source_grid_sizemesh with bilinear weights (floor/ceil+ sub-bint_row/t_col).The bug. Ray-traced source-plane grids from realistic lens models contain massive numbers of duplicate coordinates. For an Isothermal lens over a circular mask of radius 3.5 (the standard HST test setup), the probe in
autolens_workspace_developer/jax_profiling/imaging/mapper_grad_isolate.pymeasured:Half of all adjacent pairs in the sorted source grid are at exactly zero gap (floating-point zero — these are genuinely coincident rays, not rounding noise). This breaks
jnp.interp's vjp in two distinct ways, which is the crucial subtlety:Knot-gradient term. The vjp of
jnp.interpw.r.t. its knot arrayxpcontains a division byxp[i+1] - xp[i]. With 7680 exact-zero gaps this is literally0/0, which JAX propagates asO(1e24)cotangents back throughsort_points → source_plane_data_grid → deflections → lens model parameters.Query-gradient term. The vjp w.r.t. the query
xis the local slope,(yp[i+1] - yp[i]) / (xp[i+1] - xp[i]). Even if the knots themselves were frozen, any query landing in a zero-gap bin seesslope = Δt / 0. An isolated one-bin duplicate doesn't trigger this in practice (JAX's indexing dodges it), but long runs of consecutive duplicates, which is exactly what caustics produce, do.Stage-by-stage gradient vs finite-difference before the fix (from
mapper_grad_isolate.py):The blow-up is injected exactly at the
jnp.interpcall and then propagated by linearity through every downstream operation, including the finalinversion.operated_mapping_matrix→D,F, NNLS,log_evidence.The fix
Inside the
xp.__name__.startswith("jax")branch ofcreate_transforms:Two patches, each addressing one of the two failure modes above:
stop_gradient(sort_points)kills the knot-gradient term (failure mode 1). This is semantically correct, not a hack: the only consumer of this interpolator's output isadaptive_rectangular_mappings_weights_via_interpolation_from, which doesfloor(...)/ceil(...)on the transformed coordinate to pick the 4 corner pixels. That bin-assignment already has zero gradient. The knot-gradient term we drop is precisely the derivative of which-bin-is-selected with respect to the knot positions — a derivative that has no downstream consumer. The smooth sub-bint_row/t_colcontribution, which is the only thing that really feeds into the bilinear weights, flows through the query-point argument ofjnp.interpentirely unchanged.Monotonic jitter
arange(N) * 1e-7prevents the query-gradient term (failure mode 2) from seeing a zero knot gap. With the jitter, the minimum gap is guaranteed to be1e-7. The worst-case slope through any bin becomes bounded by(Δt_max) / JITTER = (1/(N+1)) / 1e-7 ≈ 650forN ≈ 1.5e4— large but finite and harmless numerically. The forward interpolation value is perturbed by at mostN * JITTER ≈ 1.5e-3in scaled source-plane units, well below the(source_grid_size - 3)downstream multiplier's sub-pixel sensitivity, so the mapping matrix itself is essentially unchanged in value.Only the JAX branch is touched. The NumPy path (
forward_interp_np) is untouched — it doesn't suffer from this bug because NumPy has no autodiff.Verification
After the patch, the same probe produces finite, FD-agreeing gradients at every stage:
The remaining ~0.5–1.4 JAX/FD ratio at downstream stages is expected and not a correctness concern: the sum-of-squares loss used in the probe is piecewise smooth, with jumps at every bin-boundary crossing, so finite differences with
h=1e-5average across crossings that JAX's autodiff correctly reports as zero. This is intrinsic to the bilinear interpolator's discrete bin structure and has nothing to do with the patch.pixelization_gradients.py(companion script inautolens_workspace_developer) goes from failing at step 4 to passing steps 4–6 with finite, well-conditioned gradients. Step 8+ still fail for unrelated NNLS-conditioning reasons.API Changes
None — internal changes only. No public signatures changed, no added / removed / renamed symbols, no behavioural change on the NumPy path. The JAX-path numerical output of
transform(grid)is perturbed by at most~1.5e-3per coordinate, which is below the bilinear bin-assignment resolution.See full details below.
Test Plan
python -m pytest test_autoarray/inversion/— 155 passedpython -m pytest test_autoarray/inversion/pixelization/interpolator/ test_autoarray/inversion/pixelization/mappers/ test_autoarray/inversion/pixelization/mesh/— 21 passedautolens_workspace_developer/jax_profiling/imaging/mapper_grad_isolate.py— previously O(1e21) ratios, now O(1) at every stageautolens_workspace_developer/jax_profiling/imaging/pixelization_gradients.py— steps 4–6 go from failing to PASSFull API Changes (for automation & release notes)
Removed
Added
Renamed
Changed Signature
Changed Behaviour
autoarray.inversion.mesh.interpolator.rectangular.create_transforms: JAX-pathtransform/inv_transformnow internally apply a1e-7monotonic jitter +stop_gradientto the sorted knot array. Forward output is perturbed by at most ~1.5e-3 per coordinate in scaled source-plane units. Backward gradient through the knot array is now zero by design. NumPy path is unchanged.Migration
🤖 Generated with Claude Code