fix(distrax): Tanh NaN clip + MultivariateNormal vmap batch_shape fix #328

Open
Sumu004 wants to merge 2 commits into google-deepmind:main from Sumu004:fix/tanh-bijector-nan-clip

Sumu004 commented Jun 4, 2026

What does this PR do?

Two related numerical-correctness fixes in distrax:


Fix 1: Tanh bijector — clip inverse to prevent NaN in log_prob (Fixes #216)

In float32, sampling from a Tanh-transformed distribution (SAC, PPO with squashed Gaussian) can produce values equal to ±1 due to limited precision. arctanh(±1) = ±∞, propagating to NaN in log_prob.

Fix: clip y to [-1+ε, 1-ε], where ε = jnp.finfo(y.dtype).eps, before calling arctanh. The bound is dtype-aware: ε ≈ 2.2e-16 for float64 and ≈ 1.2e-7 for float32.
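
A minimal sketch of the clipping idea, not the code in this PR: `clipped_arctanh` is a hypothetical helper name used only for illustration, and the only library calls assumed are `jnp.finfo`, `jnp.clip` and `jnp.arctanh`.

```python
import jax.numpy as jnp


def clipped_arctanh(y):
    """Inverse of tanh with inputs clipped away from +/-1 (illustrative helper)."""
    # Machine epsilon of the input dtype, kept in that dtype so the bounds
    # do not promote the computation to a wider type.
    eps = jnp.asarray(jnp.finfo(y.dtype).eps, dtype=y.dtype)
    return jnp.arctanh(jnp.clip(y, -1 + eps, 1 - eps))


y = jnp.array([-1.0, 0.0, 1.0], dtype=jnp.float32)
unclipped = jnp.arctanh(y)    # non-finite at both boundary values
clipped = clipped_arctanh(y)  # finite everywhere
```

For float32 the clipped boundary maps to roughly ±8.3, since arctanh(1 - 2^-23) ≈ 12·ln 2.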


Fix 2: MultivariateNormalFromBijector — batch_shape stale after vmap (Fixes #276)

When distributions are constructed inside jax.vmap, the vmapped execution prepends a batch dimension to _loc at run time. The _batch_shape tuple captured at trace time doesn't include this extra dimension, causing dist.loc to raise:

ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(B, D) shape=(D,)

Fix: override batch_shape as a computed property. Detect extra leading dimensions in _loc beyond what _batch_shape + _event_shape predict, and prepend them. Preserves scale-broadcasting semantics for the non-vmap case.
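
The staleness described above can be reproduced without distrax. The sketch below uses a hypothetical `ToyDist` pytree (an assumption for illustration, not a distrax class) whose `loc` is an array leaf and whose `batch_shape` is static auxiliary data, which is one way a shape tuple can end up frozen while the leaf gains a dimension under `jax.vmap`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ToyDist:
    """Hypothetical stand-in: `loc` is a leaf, `batch_shape` is static metadata."""

    def __init__(self, loc, batch_shape=None):
        self.loc = loc
        # Computed once at construction; reused verbatim when unflattening.
        self.batch_shape = tuple(loc.shape[:-1]) if batch_shape is None else batch_shape

    def tree_flatten(self):
        return (self.loc,), self.batch_shape  # (leaves, static aux data)

    @classmethod
    def tree_unflatten(cls, batch_shape, children):
        return cls(children[0], batch_shape)


def make(loc):
    return ToyDist(loc)  # inside vmap, loc has shape (3,)


dists = jax.vmap(make)(jnp.zeros((5, 3)))
stale = dists.batch_shape          # static tuple recorded inside vmap
actual = dists.loc.shape[:-1]      # what the batched leaf implies
```

Here `stale` is `()` while `actual` is `(5,)`, which mirrors the mismatch behind the broadcast error quoted above.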


Tests

Fix 1: test_stability updated; test_inverse_clips_boundary_values_to_prevent_nan + test_log_prob_finite_at_float32_boundary_samples added.

Fix 2: VmapBatchShapeTest with 4 cases: vmap batch_shape, vmap loc accessible, non-vmapped unchanged, scale-broadcasting unchanged.

All 69 Tanh tests and 359 MVN tests pass.

Sumu004 added 2 commits June 4, 2026 17:45
In float32, sampling from a Tanh-transformed distribution (e.g. SAC or
PPO with a squashed Gaussian policy) can produce values numerically equal
to ±1 due to limited precision.  Passing these to arctanh gives ±∞, which
propagates to NaN in log_prob and silently breaks policy gradient updates.

Fix: clip y to the closed interval [-1+ε, 1-ε] in inverse_and_log_det,
where ε = jnp.finfo(y.dtype).eps, a tight dtype-specific bound. This
yields a large finite pre-activation rather than ±∞ and keeps log_prob
finite for any sample produced by sample().

The clip is dtype-aware so float64 benefits from the tighter 2.2e-16
bound while float32 uses 1.2e-7.

Also updates test_stability to verify finiteness at boundary values
(distrax intentionally diverges from TFP's NaN behaviour here) and adds
two regression tests: one for direct boundary clipping, one that samples
from a wide Tanh-wrapped MultivariateNormalDiag and verifies log_prob
remains finite.
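
A rough, distrax-free sketch of what such a sampling regression check might look like. It assumes a scalar Normal as a stand-in for the MultivariateNormalDiag used in the PR's test, and `squashed_normal_log_prob` is a hypothetical helper. The log-determinant uses the identity log(1 - tanh²x) = 2(log 2 - x - softplus(-2x)).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def squashed_normal_log_prob(y, loc, scale):
    """log_prob of y = tanh(x), x ~ Normal(loc, scale), with a clipped inverse."""
    eps = jnp.asarray(jnp.finfo(y.dtype).eps, dtype=y.dtype)
    x = jnp.arctanh(jnp.clip(y, -1 + eps, 1 - eps))
    # Stable form of log|d tanh(x)/dx| = log(1 - tanh(x)^2).
    log_det = 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))
    return norm.logpdf(x, loc, scale) - log_det


key = jax.random.PRNGKey(0)
loc, scale = 0.0, 10.0  # wide base distribution so tanh saturates often
x = loc + scale * jax.random.normal(key, (10_000,), dtype=jnp.float32)
y = jnp.tanh(x)

num_saturated = int(jnp.sum(jnp.abs(y) == 1.0))  # samples exactly at +/-1
log_probs = squashed_normal_log_prob(y, loc, scale)
all_finite = bool(jnp.all(jnp.isfinite(log_probs)))
```

With a scale this wide a large share of the float32 samples land exactly on ±1, so the check exercises the clipped branch rather than only interior values.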

Fixes: google-deepmind#216
…ensions

When MultivariateNormal* distributions are constructed inside jax.vmap,
the vmapped execution prepends a batch dimension to _loc at run time.
The _batch_shape tuple is computed and captured at trace time (as a static
Python tuple), so it does not include the extra batch dimension.  This
makes batch_shape stale and causes the loc property to fail with:

  ValueError: Cannot broadcast to shape with fewer dimensions:
    arr_shape=(B, D) shape=(D,)

Fix: override batch_shape as a computed property.  Instead of returning the
static _batch_shape directly, detect any extra leading dimensions in _loc
that go beyond what _batch_shape + _event_shape predict (i.e. dimensions
added by vmap or similar batching transforms), and prepend them to the
static batch_shape.  This preserves the existing scale-broadcasting
semantics (where batch_shape can be wider than loc.shape[:-1]) while also
handling the vmap case.
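
The detection rule in the paragraph above can be sketched as a small shape helper. `effective_batch_shape` is a hypothetical name and this is one reading of the description, not the PR's actual property implementation.

```python
import jax.numpy as jnp


def effective_batch_shape(loc, static_batch_shape, event_shape):
    """Prepend any leading dims of `loc` that the static shapes do not explain."""
    expected_ndim = len(static_batch_shape) + len(event_shape)
    # Dimensions beyond the static prediction are treated as added by vmap.
    num_extra = max(loc.ndim - expected_ndim, 0)
    return tuple(loc.shape[:num_extra]) + tuple(static_batch_shape)


# vmap case: loc gained a leading dimension of size 5.
vmapped = effective_batch_shape(jnp.zeros((5, 3)), (), (3,))    # (5,)
# Non-vmapped case: nothing to prepend.
plain = effective_batch_shape(jnp.zeros((3,)), (), (3,))        # ()
# Scale-broadcasting case: batch_shape is wider than loc.shape[:-1].
broadcast = effective_batch_shape(jnp.zeros((3,)), (4,), (3,))  # (4,)
```

One limitation of this sketch: when `loc` has fewer dimensions than the static shapes predict, a dimension added by vmap is indistinguishable from the broadcast batch dimension, so a rule based on `loc` alone cannot detect it.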

Also adds VmapBatchShapeTest with four regression cases:
  - batch_shape and loc.shape after vmap
  - loc property no longer raises after vmap (the reported symptom)
  - non-vmapped case unchanged
  - scale-batch-broadcasting unchanged

Fixes: google-deepmind#276
Sumu004 changed the title from "fix(tanh): clip inverse to (-1+eps, 1-eps) to prevent NaN in log_prob" to "fix(distrax): Tanh NaN clip + MultivariateNormal vmap batch_shape fix" on Jun 4, 2026

Development

Successfully merging this pull request may close these issues.

MultivariateNormalDiag vmap issue
nan in MultivariateNormalDiag log prob
