Skip to content

feat(distributions): add Poisson distribution#329

Open
Sumu004 wants to merge 3 commits into
google-deepmind:mainfrom
Sumu004:feat/add-poisson-distribution
Open

feat(distributions): add Poisson distribution#329
Sumu004 wants to merge 3 commits into
google-deepmind:mainfrom
Sumu004:feat/add-poisson-distribution

Conversation

@Sumu004

@Sumu004 Sumu004 commented Jun 4, 2026

Copy link
Copy Markdown

What does this PR do?

Adds distrax.Poisson(rate) — the Poisson distribution with rate parameter λ > 0, implementing the interface requested in #164.

The Poisson distribution models the number of discrete events occurring at a constant rate. Its PMF is:

P(X = k) = λ^k * exp(-λ) / k!,   k = 0, 1, 2, …

Implemented methods

Method Notes
log_prob(k) k·log(λ) − λ − log Γ(k+1)
mean() λ
variance() λ
mode() floor(λ)
entropy() Numerical sum (no closed form)
kl_divergence(other) Analytic when other is also Poisson: λ₁ log(λ₁/λ₂) + λ₂ − λ₁
sample() Delegates to jax.random.poisson

Supports integer and floating-point dtype for samples, batched rate, and JIT compilation.

Tests

31 tests in poisson_test.py covering:

  • log_prob accuracy vs scipy.stats.poisson.logpmf for k = 0, 1, 3, 10
  • Mean, variance, mode correctness
  • Entropy vs scipy
  • KL divergence (analytic formula and self-divergence = 0)
  • Sample shape, non-negativity, moments (large-sample)
  • Integer and float dtype samples
  • JIT compatibility

Notes

  • entropy() is computed via a numerical sum (TFP's JAX substrate does not implement Poisson.entropy()). The window is chosen as min(λ + 10√λ + 30, 2000), which captures >99.999 % of probability mass for rates up to ~1000.
  • equiv_tfp_cls = tfd.Poisson is set for TFP interoperability.

Fixes #164

Sumu004 added 3 commits June 4, 2026 17:45
In float32, sampling from a Tanh-transformed distribution (e.g. SAC or
PPO with a squashed Gaussian policy) can produce values numerically equal
to ±1 due to limited precision.  Passing these to arctanh gives ±∞, which
propagates to NaN in log_prob and silently breaks policy gradient updates.

Fix: clip y to the open interval (-1+ε, 1-ε) in inverse_and_log_det,
where ε = jnp.finfo(y.dtype).eps — the tightest safe bound for the dtype.
This produces the largest finite pre-activation rather than ±∞ and keeps
log_prob finite for any sample produced by sample().

The clip is dtype-aware so float64 benefits from the tighter 2.2e-16
bound while float32 uses 1.2e-7.

Also updates test_stability to verify finiteness at boundary values
(distrax intentionally diverges from TFP's NaN behaviour here) and adds
two regression tests: one for direct boundary clipping, one that samples
from a wide Tanh-wrapped MultivariateNormalDiag and verifies log_prob
remains finite.

Fixes: google-deepmind#216
…ensions

When MultivariateNormal* distributions are constructed inside jax.vmap,
the vmapped execution prepends a batch dimension to _loc at run time.
The _batch_shape tuple is computed and captured at trace time (as a static
Python tuple), so it does not include the extra batch dimension.  This
makes batch_shape stale and causes the loc property to fail with:

  ValueError: Cannot broadcast to shape with fewer dimensions:
    arr_shape=(B, D) shape=(D,)

Fix: override batch_shape as a computed property.  Instead of returning the
static _batch_shape directly, detect any extra leading dimensions in _loc
that go beyond what _batch_shape + _event_shape predict (i.e. dimensions
added by vmap or similar batching transforms), and prepend them to the
static batch_shape.  This preserves the existing scale-broadcasting
semantics (where batch_shape can be wider than loc.shape[:-1]) while also
handling the vmap case.

Also adds VmapBatchShapeTest with four regression cases:
  - batch_shape and loc.shape after vmap
  - loc property no longer raises after vmap (the reported symptom)
  - non-vmapped case unchanged
  - scale-batch-broadcasting unchanged

Fixes: google-deepmind#276
Adds distrax.Poisson(rate) — the Poisson distribution with rate parameter λ.

Implements:
  - log_prob(k) = k*log(λ) − λ − log Γ(k+1)
  - mean() = variance() = λ
  - mode() = floor(λ)
  - entropy(): computed numerically (no closed form)
  - kl_divergence(other): analytic formula when both are Poisson:
      KL(Poisson(λ₁)‖Poisson(λ₂)) = λ₁ log(λ₁/λ₂) + λ₂ − λ₁
  - sample_n(): delegates to jax.random.poisson

Adds distrax.Poisson to the public __init__.py and 31 tests covering
log_prob accuracy vs scipy, moment correctness, dtype control, KL
divergence, JIT compatibility, and entropy.

Fixes: google-deepmind#164
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feature request: Poisson distribution

1 participant