feat(distributions): add Poisson distribution#329
Open
Sumu004 wants to merge 3 commits into
Open
Conversation
In float32, sampling from a Tanh-transformed distribution (e.g. SAC or PPO with a squashed Gaussian policy) can produce values numerically equal to ±1 due to limited precision. Passing these to arctanh gives ±∞, which propagates to NaN in log_prob and silently breaks policy gradient updates. Fix: clip y to the open interval (-1+ε, 1-ε) in inverse_and_log_det, where ε = jnp.finfo(y.dtype).eps — the tightest safe bound for the dtype. This produces the largest finite pre-activation rather than ±∞ and keeps log_prob finite for any sample produced by sample(). The clip is dtype-aware so float64 benefits from the tighter 2.2e-16 bound while float32 uses 1.2e-7. Also updates test_stability to verify finiteness at boundary values (distrax intentionally diverges from TFP's NaN behaviour here) and adds two regression tests: one for direct boundary clipping, one that samples from a wide Tanh-wrapped MultivariateNormalDiag and verifies log_prob remains finite. Fixes: google-deepmind#216
…ensions
When MultivariateNormal* distributions are constructed inside jax.vmap,
the vmapped execution prepends a batch dimension to _loc at run time.
The _batch_shape tuple is computed and captured at trace time (as a static
Python tuple), so it does not include the extra batch dimension. This
makes batch_shape stale and causes the loc property to fail with:
ValueError: Cannot broadcast to shape with fewer dimensions:
arr_shape=(B, D) shape=(D,)
Fix: override batch_shape as a computed property. Instead of returning the
static _batch_shape directly, detect any extra leading dimensions in _loc
that go beyond what _batch_shape + _event_shape predict (i.e. dimensions
added by vmap or similar batching transforms), and prepend them to the
static batch_shape. This preserves the existing scale-broadcasting
semantics (where batch_shape can be wider than loc.shape[:-1]) while also
handling the vmap case.
Also adds VmapBatchShapeTest with four regression cases:
- batch_shape and loc.shape after vmap
- loc property no longer raises after vmap (the reported symptom)
- non-vmapped case unchanged
- scale-batch-broadcasting unchanged
Fixes: google-deepmind#276
Adds distrax.Poisson(rate) — the Poisson distribution with rate parameter λ.
Implements:
- log_prob(k) = k*log(λ) − λ − log Γ(k+1)
- mean() = variance() = λ
- mode() = floor(λ)
- entropy(): computed numerically (no closed form)
- kl_divergence(other): analytic formula when both are Poisson:
KL(Poisson(λ₁)‖Poisson(λ₂)) = λ₁ log(λ₁/λ₂) + λ₂ − λ₁
- sample_n(): delegates to jax.random.poisson
Adds distrax.Poisson to the public __init__.py and 31 tests covering
log_prob accuracy vs scipy, moment correctness, dtype control, KL
divergence, JIT compatibility, and entropy.
Fixes: google-deepmind#164
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Adds
distrax.Poisson(rate)— the Poisson distribution with rate parameter λ > 0, implementing the interface requested in #164.The Poisson distribution models the number of discrete events occurring at a constant rate. Its PMF is:
Implemented methods
log_prob(k)k·log(λ) − λ − log Γ(k+1)mean()variance()mode()floor(λ)entropy()kl_divergence(other)otheris also Poisson:λ₁ log(λ₁/λ₂) + λ₂ − λ₁sample()jax.random.poissonSupports integer and floating-point
dtypefor samples, batchedrate, and JIT compilation.Tests
31 tests in
poisson_test.pycovering:log_probaccuracy vsscipy.stats.poisson.logpmffor k = 0, 1, 3, 10Notes
entropy()is computed via a numerical sum (TFP's JAX substrate does not implementPoisson.entropy()). The window is chosen asmin(λ + 10√λ + 30, 2000), which captures >99.999 % of probability mass for rates up to ~1000.equiv_tfp_cls = tfd.Poissonis set for TFP interoperability.Fixes #164