Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions distreqx/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AbstractSurvivalDistribution as AbstractSurvivalDistribution,
)
from ._gamma import Gamma as Gamma
from ._improper_uniform import ImproperUniform as ImproperUniform
from ._independent import Independent as Independent
from ._logistic import Logistic as Logistic
from ._mixture_same_family import MixtureSameFamily as MixtureSameFamily
Expand Down
82 changes: 82 additions & 0 deletions distreqx/distributions/_improper_uniform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import jax.numpy as jnp
from jaxtyping import Array, Key

from ._distribution import AbstractDistribution


class ImproperUniform(
AbstractDistribution,
strict=True,
):
"""Improper Uniform distribution over the entire real line.

This distribution has an unnormalized probability density of 1 everywhere,
meaning its `log_prob` evaluates to 0. As an improper distribution, it
does not integrate to 1 and cannot be sampled.
"""

shape: tuple[int, ...] = ()

@property
def event_shape(self) -> tuple[int, ...]:
return self.shape

def sample(self, key: Key[Array, ""]) -> Array:
"""Sampling is not defined for improper distributions."""
raise NotImplementedError("Cannot sample from an ImproperUniform distribution.")

def sample_and_log_prob(self, key: Key[Array, ""]) -> tuple[Array, Array]:
raise NotImplementedError("Cannot sample from an ImproperUniform distribution.")

def log_prob(self, value: Array) -> Array:
"""Returns the unnormalized log probability (constant 0.0)."""
return jnp.zeros_like(value)

def prob(self, value: Array) -> Array:
"""Returns the unnormalized probability (constant 1.0)."""
return jnp.ones_like(value)

def entropy(self) -> Array:
"""Entropy of an improper uniform over the reals is infinite."""
return jnp.full(self.shape, jnp.inf)

def icdf(self, value: Array) -> Array:
raise NotImplementedError("icdf is undefined for an improper distribution.")

def log_cdf(self, value: Array) -> Array:
raise NotImplementedError("log_cdf is undefined for an improper distribution.")

def cdf(self, value: Array) -> Array:
raise NotImplementedError("cdf is undefined for an improper distribution.")

def survival_function(self, value: Array) -> Array:
raise NotImplementedError(
"survival_function is undefined for an improper distribution."
)

def log_survival_function(self, value: Array) -> Array:
raise NotImplementedError(
"log_survival_function is undefined for an improper distribution."
)

def mean(self) -> Array:
raise NotImplementedError("Mean is undefined for an improper distribution.")

def median(self) -> Array:
raise NotImplementedError("Median is undefined for an improper distribution.")

def variance(self) -> Array:
raise NotImplementedError("Variance is undefined for an improper distribution.")

def stddev(self) -> Array:
raise NotImplementedError(
"Standard deviation is undefined for an improper distribution."
)

def mode(self) -> Array:
raise NotImplementedError("Mode is undefined for an improper distribution.")

def kl_divergence(self, other_dist, **kwargs) -> Array:
raise NotImplementedError(
"KL divergence is undefined for an improper distribution."
)
4 changes: 4 additions & 0 deletions docs/api/distributions/improper_uniform.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Improper Uniform

::: distreqx.distributions.ImproperUniform
---
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ nav:
- 'api/distributions/categorical.md'
- 'api/distributions/gamma.md'
- 'api/distributions/logistic.md'
- 'api/distributions/improper_uniform.md'
- 'api/distributions/independent.md'
- 'api/distributions/mixture_same_family.md'
- 'api/distributions/uniform.md'
Expand Down
Loading