From 31df351cf2d53d29ff3eae99a498095d5d785aea Mon Sep 17 00:00:00 2001 From: Gary Allen Date: Thu, 30 Apr 2026 15:16:03 +0200 Subject: [PATCH] Update inverse_and_log_det typehint --- distreqx/bijectors/__init__.py | 1 + distreqx/bijectors/_bijector.py | 4 +- distreqx/bijectors/_split.py | 70 ++++++++++++++++++++++ docs/api/bijectors/split.md | 7 +++ mkdocs.yml | 1 + tests/split_test.py | 101 ++++++++++++++++++++++++++++++++ 6 files changed, 182 insertions(+), 2 deletions(-) create mode 100644 distreqx/bijectors/_split.py create mode 100644 docs/api/bijectors/split.md create mode 100644 tests/split_test.py diff --git a/distreqx/bijectors/__init__.py b/distreqx/bijectors/__init__.py index 7805142..3d93974 100644 --- a/distreqx/bijectors/__init__.py +++ b/distreqx/bijectors/__init__.py @@ -11,6 +11,7 @@ from ._scalar_affine import ScalarAffine as ScalarAffine from ._shift import Shift as Shift from ._sigmoid import Sigmoid as Sigmoid +from ._split import Split as Split from ._tanh import Tanh as Tanh from ._triangular_linear import TriangularLinear as TriangularLinear from ._unconstrained_affine import UnconstrainedAffine as UnconstrainedAffine diff --git a/distreqx/bijectors/_bijector.py b/distreqx/bijectors/_bijector.py index 40680b6..e8185b7 100644 --- a/distreqx/bijectors/_bijector.py +++ b/distreqx/bijectors/_bijector.py @@ -1,7 +1,7 @@ from abc import abstractmethod import equinox as eqx -from jaxtyping import Array, PyTree +from jaxtyping import PyTree from .._custom_meta import AbstractStrictModule @@ -59,7 +59,7 @@ def forward_and_log_det(self, x: PyTree) -> tuple[PyTree, PyTree]: ) @abstractmethod - def inverse_and_log_det(self, y: Array) -> tuple[PyTree, PyTree]: + def inverse_and_log_det(self, y: PyTree) -> tuple[PyTree, PyTree]: r"""Computes $x = f^{-1}(y)$ and $\log|\det J(f^{-1})(y)|$.""" raise NotImplementedError( f"Bijector {self.name} does not implement `inverse_and_log_det`." diff --git a/distreqx/bijectors/_split.py b/distreqx/bijectors/_split.py new file mode 100644 index 0000000..84593fe --- /dev/null +++ b/distreqx/bijectors/_split.py @@ -0,0 +1,70 @@ +from typing import Union + +import equinox as eqx +import jax.numpy as jnp +from jaxtyping import Array + +from ._bijector import ( + AbstractBijector, + AbstractForwardInverseBijector, + AbstractFwdLogDetJacBijector, + AbstractInvLogDetJacBijector, +) + + +class Split( + AbstractForwardInverseBijector, + AbstractInvLogDetJacBijector, + AbstractFwdLogDetJacBijector, + strict=True, +): + """A bijector that splits a single array into a tuple of arrays along an axis. + + This operates as a wrapper around `jax.numpy.split`. + """ + + indices_or_sections: Union[int, tuple[int, ...]] = eqx.field(static=True) + axis: int = eqx.field(static=True) + + _is_constant_jacobian: bool = True + _is_constant_log_det: bool = True + + def __init__( + self, + indices_or_sections: Union[int, tuple[int, ...], list[int]], + axis: int = -1, + ): + """Initializes a Split bijector. + + **Arguments:** + + - `indices_or_sections`: If an integer `N`, the array will be divided into + `N` equal arrays along axis. If a tuple/list of sorted integers, the entries + indicate where along axis the array is split. + - `axis`: The axis along which to split. Defaults to -1 (last axis). + """ + # Ensure lists are converted to tuples so they remain hashable for JAX JIT + if isinstance(indices_or_sections, list): + indices_or_sections = tuple(indices_or_sections) + + self.indices_or_sections = indices_or_sections + self.axis = axis + + def forward_and_log_det(self, x: Array) -> tuple[tuple[Array, ...], Array]: + """Computes y = tuple(split(x)) and log|det J(f)(x)| = 0.""" + y = tuple(jnp.split(x, self.indices_or_sections, axis=self.axis)) + return y, jnp.zeros((), dtype=x.dtype) + + def inverse_and_log_det(self, y: tuple[Array, ...]) -> tuple[Array, Array]: + """Computes x = concatenate(y) and log|det J(f^{-1})(y)| = 0.""" + x = jnp.concatenate(y, axis=self.axis) + dtype = y[0].dtype if y else jnp.float32 + return x, jnp.zeros((), dtype=dtype) + + def same_as(self, other: AbstractBijector) -> bool: + """Returns True if this bijector is guaranteed to be the same as `other`.""" + return ( + type(other) is Split + and self.indices_or_sections == other.indices_or_sections + and self.axis == other.axis + ) diff --git a/docs/api/bijectors/split.md b/docs/api/bijectors/split.md new file mode 100644 index 0000000..969a6a0 --- /dev/null +++ b/docs/api/bijectors/split.md @@ -0,0 +1,7 @@ +# Split Bijector + +::: distreqx.bijectors.Split + options: + members: + - __init__ +--- \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index aa5cf4e..e5190e1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -123,6 +123,7 @@ nav: - 'api/bijectors/scalar_affine.md' - 'api/bijectors/shift.md' - 'api/bijectors/sigmoid.md' + - 'api/bijectors/split.md' - 'api/bijectors/tanh.md' - 'api/bijectors/triangular_linear.md' - 'api/bijectors/_bijector.md' diff --git a/tests/split_test.py b/tests/split_test.py new file mode 100644 index 0000000..fb44e94 --- /dev/null +++ b/tests/split_test.py @@ -0,0 +1,101 @@ +from unittest import TestCase + +import jax + +jax.config.update("jax_enable_x64", True) + +import equinox as eqx +import jax.numpy as jnp +import numpy as np +from parameterized import parameterized # type: ignore + +from distreqx.bijectors import Split + + +class SplitTest(TestCase): + def setUp(self): + # Bijector 1: Split into 3 equal sections + self.bij_equal = Split(indices_or_sections=3, axis=-1) + + # Bijector 2: Split at specific indices + # (elements up to index 2, from 2 to 5, and from 5 onward) + self.bij_indices = Split(indices_or_sections=(2, 5), axis=-1) + + def assertion_fn(self, rtol=1e-5): + return lambda x, y: np.testing.assert_allclose(x, y, rtol=rtol) + + @parameterized.expand([("float32", jnp.float32), ("float64", jnp.float64)]) + def test_forward_and_log_det_equal_sections(self, name, dtype): + x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype) + y, log_det = self.bij_equal.forward_and_log_det(x) + + self.assertIsInstance(y, tuple) + self.assertEqual(len(y), 3) + self.assertion_fn()(y[0], jnp.array([1.0, 2.0], dtype=dtype)) + self.assertion_fn()(y[1], jnp.array([3.0, 4.0], dtype=dtype)) + self.assertion_fn()(y[2], jnp.array([5.0, 6.0], dtype=dtype)) + + self.assertEqual(log_det.shape, ()) + self.assertEqual(log_det, 0.0) + self.assertEqual(log_det.dtype, dtype) + + @parameterized.expand([("float32", jnp.float32), ("float64", jnp.float64)]) + def test_forward_and_log_det_indices(self, name, dtype): + x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=dtype) + y, log_det = self.bij_indices.forward_and_log_det(x) + + self.assertIsInstance(y, tuple) + self.assertEqual(len(y), 3) + self.assertion_fn()(y[0], jnp.array([1.0, 2.0], dtype=dtype)) # x[:2] + self.assertion_fn()(y[1], jnp.array([3.0, 4.0, 5.0], dtype=dtype)) # x[2:5] + self.assertion_fn()(y[2], jnp.array([6.0, 7.0], dtype=dtype)) # x[5:] + + @parameterized.expand([("float32", jnp.float32), ("float64", jnp.float64)]) + def test_inverse_and_log_det(self, name, dtype): + y = ( + jnp.array([1.0, 2.0], dtype=dtype), + jnp.array([3.0, 4.0], dtype=dtype), + jnp.array([5.0, 6.0], dtype=dtype), + ) + x, log_det = self.bij_equal.inverse_and_log_det(y) + + self.assertion_fn()(x, jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)) + + self.assertEqual(log_det.shape, ()) + self.assertEqual(log_det, 0.0) + self.assertEqual(x.dtype, dtype) + self.assertEqual(log_det.dtype, dtype) + + def test_jittable(self): + @eqx.filter_jit + def f_forward(bij, x): + return bij.forward_and_log_det(x) + + @eqx.filter_jit + def f_inverse(bij, y): + return bij.inverse_and_log_det(y) + + x = jnp.ones((6,)) + y, log_det_fwd = f_forward(self.bij_equal, x) + + self.assertIsInstance(y, tuple) + self.assertIsInstance(log_det_fwd, jax.Array) + + x_reconstructed, log_det_inv = f_inverse(self.bij_equal, y) + self.assertIsInstance(x_reconstructed, jax.Array) + self.assertIsInstance(log_det_inv, jax.Array) + + def test_same_as(self): + same_bij = Split(indices_or_sections=3, axis=-1) + diff_bij_1 = Split(indices_or_sections=4, axis=-1) + diff_bij_2 = Split(indices_or_sections=3, axis=0) + + self.assertTrue(self.bij_equal.same_as(same_bij)) + self.assertFalse(self.bij_equal.same_as(diff_bij_1)) + self.assertFalse(self.bij_equal.same_as(diff_bij_2)) + + def test_list_to_tuple_conversion(self): + # A list should be converted to a tuple upon initialization for hashability + bij_list = Split(indices_or_sections=[2, 5]) + self.assertIsInstance(bij_list.indices_or_sections, tuple) + self.assertTrue(self.bij_indices.same_as(bij_list))