fix(distrax): Tanh NaN clip + MultivariateNormal vmap batch_shape fix#328
Open
Sumu004 wants to merge 2 commits into
Open
fix(distrax): Tanh NaN clip + MultivariateNormal vmap batch_shape fix#328Sumu004 wants to merge 2 commits into
Sumu004 wants to merge 2 commits into
Conversation
In float32, sampling from a Tanh-transformed distribution (e.g. SAC or PPO with a squashed Gaussian policy) can produce values numerically equal to ±1 due to limited precision. Passing these to arctanh gives ±∞, which propagates to NaN in log_prob and silently breaks policy gradient updates. Fix: clip y to the open interval (-1+ε, 1-ε) in inverse_and_log_det, where ε = jnp.finfo(y.dtype).eps — the tightest safe bound for the dtype. This produces the largest finite pre-activation rather than ±∞ and keeps log_prob finite for any sample produced by sample(). The clip is dtype-aware so float64 benefits from the tighter 2.2e-16 bound while float32 uses 1.2e-7. Also updates test_stability to verify finiteness at boundary values (distrax intentionally diverges from TFP's NaN behaviour here) and adds two regression tests: one for direct boundary clipping, one that samples from a wide Tanh-wrapped MultivariateNormalDiag and verifies log_prob remains finite. Fixes: google-deepmind#216
…ensions
When MultivariateNormal* distributions are constructed inside jax.vmap,
the vmapped execution prepends a batch dimension to _loc at run time.
The _batch_shape tuple is computed and captured at trace time (as a static
Python tuple), so it does not include the extra batch dimension. This
makes batch_shape stale and causes the loc property to fail with:
ValueError: Cannot broadcast to shape with fewer dimensions:
arr_shape=(B, D) shape=(D,)
Fix: override batch_shape as a computed property. Instead of returning the
static _batch_shape directly, detect any extra leading dimensions in _loc
that go beyond what _batch_shape + _event_shape predict (i.e. dimensions
added by vmap or similar batching transforms), and prepend them to the
static batch_shape. This preserves the existing scale-broadcasting
semantics (where batch_shape can be wider than loc.shape[:-1]) while also
handling the vmap case.
Also adds VmapBatchShapeTest with four regression cases:
- batch_shape and loc.shape after vmap
- loc property no longer raises after vmap (the reported symptom)
- non-vmapped case unchanged
- scale-batch-broadcasting unchanged
Fixes: google-deepmind#276
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Two related numerical-correctness fixes in
distrax:Fix 1: Tanh bijector — clip inverse to prevent NaN in log_prob (Fixes #216)
In
float32, sampling from a Tanh-transformed distribution (SAC, PPO with squashed Gaussian) can produce values equal to ±1 due to limited precision.arctanh(±1) = ±∞, propagating to NaN inlog_prob.Fix: clip
yto(-1+ε, 1-ε)whereε = jnp.finfo(y.dtype).epsbefore callingarctanh. Dtype-aware (float64 uses 2.2e-16, float32 uses 1.2e-7).Fix 2: MultivariateNormalFromBijector — batch_shape stale after vmap (Fixes #276)
When distributions are constructed inside
jax.vmap, the vmapped execution prepends a batch dimension to_locat run time. The_batch_shapetuple captured at trace time doesn't include this extra dimension, causingdist.locto raise:Fix: override
batch_shapeas a computed property. Detect extra leading dimensions in_locbeyond what_batch_shape + _event_shapepredict, and prepend them. Preserves scale-broadcasting semantics for the non-vmap case.Tests
Fix 1:
test_stabilityupdated;test_inverse_clips_boundary_values_to_prevent_nan+test_log_prob_finite_at_float32_boundary_samplesadded.Fix 2:
VmapBatchShapeTestwith 4 cases: vmap batch_shape, vmap loc accessible, non-vmapped unchanged, scale-broadcasting unchanged.All 69 Tanh tests and 359 MVN tests pass.