From 7d05c650731e03c12c7de17819553182c5a5f045 Mon Sep 17 00:00:00 2001 From: Gary Allen Date: Thu, 30 Apr 2026 21:31:30 +0200 Subject: [PATCH 1/3] Updated docs --- distreqx/distributions/__init__.py | 1 + distreqx/distributions/_improper_uniform.py | 87 +++++++++++++++++++++ docs/api/distributions/improper_uniform.md | 4 + mkdocs.yml | 1 + 4 files changed, 93 insertions(+) create mode 100644 distreqx/distributions/_improper_uniform.py create mode 100644 docs/api/distributions/improper_uniform.md diff --git a/distreqx/distributions/__init__.py b/distreqx/distributions/__init__.py index 06b8aa3..7cdb660 100644 --- a/distreqx/distributions/__init__.py +++ b/distreqx/distributions/__init__.py @@ -10,6 +10,7 @@ AbstractSurvivalDistribution as AbstractSurvivalDistribution, ) from ._gamma import Gamma as Gamma +from ._improper_uniform import ImproperUniform as ImproperUniform from ._independent import Independent as Independent from ._logistic import Logistic as Logistic from ._mixture_same_family import MixtureSameFamily as MixtureSameFamily diff --git a/distreqx/distributions/_improper_uniform.py b/distreqx/distributions/_improper_uniform.py new file mode 100644 index 0000000..e2226c9 --- /dev/null +++ b/distreqx/distributions/_improper_uniform.py @@ -0,0 +1,87 @@ +import jax.numpy as jnp +from jaxtyping import Array, Key + +from ._distribution import AbstractDistribution + + +class ImproperUniform( + AbstractDistribution, + strict=True, +): + """Improper Uniform distribution over the entire real line. + + This distribution has an unnormalized probability density of 1 everywhere, + meaning its `log_prob` evaluates to 0. As an improper distribution, it + does not integrate to 1 and cannot be sampled from. + """ + + shape: tuple[int, ...] = () + + @property + def event_shape(self) -> tuple[int, ...]: + return self.shape + + def sample(self, key: Key[Array, ""]) -> Array: + """Sampling is not defined for improper distributions.""" + raise NotImplementedError("Cannot sample from an ImproperUniform distribution.") + + def sample_and_log_prob(self, key: Key[Array, ""]) -> tuple[Array, Array]: + raise NotImplementedError("Cannot sample from an ImproperUniform distribution.") + + def log_prob(self, value: Array) -> Array: + """Returns the unnormalized log probability (constant 0.0).""" + return jnp.zeros_like(value) + + def prob(self, value: Array) -> Array: + """Returns the unnormalized probability (constant 1.0).""" + return jnp.ones_like(value) + + def entropy(self) -> Array: + """Entropy of an improper uniform over the reals is infinite.""" + return jnp.full(self.shape, jnp.inf) + + def icdf(self, value: Array) -> Array: + raise NotImplementedError("icdf is undefined for an improper distribution.") + + def log_cdf(self, value: Array) -> Array: + raise NotImplementedError("log_cdf is undefined for an improper distribution.") + + def cdf(self, value: Array) -> Array: + raise NotImplementedError("cdf is undefined for an improper distribution.") + + def survival_function(self, value: Array) -> Array: + raise NotImplementedError( + "survival_function is undefined for an improper distribution." + ) + + def log_survival_function(self, value: Array) -> Array: + raise NotImplementedError( + "log_survival_function is undefined for an improper distribution." + ) + + def mean(self) -> Array: + raise NotImplementedError("Mean is undefined for an improper distribution.") + + def median(self) -> Array: + raise NotImplementedError("Median is undefined for an improper distribution.") + + def variance(self) -> Array: + raise NotImplementedError("Variance is undefined for an improper distribution.") + + def stddev(self) -> Array: + raise NotImplementedError( + "Standard deviation is undefined for an improper distribution." + ) + + def mode(self) -> Array: + raise NotImplementedError("Mode is undefined for an improper distribution.") + + def kl_divergence(self, other_dist, **kwargs) -> Array: + raise NotImplementedError( + "KL divergence is undefined for an improper distribution." + ) + + def cross_entropy(self, other_dist, **kwargs) -> Array: + raise NotImplementedError( + "Cross entropy is undefined for an improper distribution." + ) diff --git a/docs/api/distributions/improper_uniform.md b/docs/api/distributions/improper_uniform.md new file mode 100644 index 0000000..b096c2b --- /dev/null +++ b/docs/api/distributions/improper_uniform.md @@ -0,0 +1,4 @@ +# Improper Uniform + +::: distreqx.distributions.ImproperUniform +--- \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index aa5cf4e..d63bc07 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -111,6 +111,7 @@ nav: - 'api/distributions/categorical.md' - 'api/distributions/gamma.md' - 'api/distributions/logistic.md' + - 'api/distributions/improper_uniform.md' - 'api/distributions/independent.md' - 'api/distributions/mixture_same_family.md' - 'api/distributions/uniform.md' From aaaf26e5d95576a928bccee2b1325b59bf6ee036 Mon Sep 17 00:00:00 2001 From: Gary Allen Date: Fri, 1 May 2026 08:36:26 +0200 Subject: [PATCH 2/3] Removed cross entropy override --- distreqx/distributions/_improper_uniform.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/distreqx/distributions/_improper_uniform.py b/distreqx/distributions/_improper_uniform.py index e2226c9..9a70f7f 100644 --- a/distreqx/distributions/_improper_uniform.py +++ b/distreqx/distributions/_improper_uniform.py @@ -80,8 +80,3 @@ def kl_divergence(self, other_dist, **kwargs) -> Array: raise NotImplementedError( "KL divergence is undefined for an improper distribution." ) - - def cross_entropy(self, other_dist, **kwargs) -> Array: - raise NotImplementedError( - "Cross entropy is undefined for an improper distribution." - ) From 9dcfa111a0c488820e0421b9a9b2b33e92d572fb Mon Sep 17 00:00:00 2001 From: Gary Allen Date: Thu, 7 May 2026 08:17:54 +0200 Subject: [PATCH 3/3] Add ImproperUniform --- distreqx/distributions/_improper_uniform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distreqx/distributions/_improper_uniform.py b/distreqx/distributions/_improper_uniform.py index 9a70f7f..eb73470 100644 --- a/distreqx/distributions/_improper_uniform.py +++ b/distreqx/distributions/_improper_uniform.py @@ -12,7 +12,7 @@ class ImproperUniform( This distribution has an unnormalized probability density of 1 everywhere, meaning its `log_prob` evaluates to 0. As an improper distribution, it - does not integrate to 1 and cannot be sampled from. + does not integrate to 1 and cannot be sampled. """ shape: tuple[int, ...] = ()