Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions distreqx/bijectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ._scalar_affine import ScalarAffine as ScalarAffine
from ._shift import Shift as Shift
from ._sigmoid import Sigmoid as Sigmoid
from ._split import Split as Split
from ._tanh import Tanh as Tanh
from ._triangular_linear import TriangularLinear as TriangularLinear
from ._unconstrained_affine import UnconstrainedAffine as UnconstrainedAffine
4 changes: 2 additions & 2 deletions distreqx/bijectors/_bijector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod

import equinox as eqx
from jaxtyping import Array, PyTree
from jaxtyping import PyTree

from .._custom_meta import AbstractStrictModule

Expand Down Expand Up @@ -59,7 +59,7 @@ def forward_and_log_det(self, x: PyTree) -> tuple[PyTree, PyTree]:
)

@abstractmethod
def inverse_and_log_det(self, y: Array) -> tuple[PyTree, PyTree]:
def inverse_and_log_det(self, y: PyTree) -> tuple[PyTree, PyTree]:
r"""Computes $x = f^{-1}(y)$ and $\log|\det J(f^{-1})(y)|$."""
raise NotImplementedError(
f"Bijector {self.name} does not implement `inverse_and_log_det`."
Expand Down
70 changes: 70 additions & 0 deletions distreqx/bijectors/_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Union

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

from ._bijector import (
AbstractBijector,
AbstractForwardInverseBijector,
AbstractFwdLogDetJacBijector,
AbstractInvLogDetJacBijector,
)


class Split(
AbstractForwardInverseBijector,
AbstractInvLogDetJacBijector,
AbstractFwdLogDetJacBijector,
strict=True,
):
"""A bijector that splits a single array into a tuple of arrays along an axis.

This operates as a wrapper around `jax.numpy.split`.
"""

indices_or_sections: Union[int, tuple[int, ...]] = eqx.field(static=True)
axis: int = eqx.field(static=True)

_is_constant_jacobian: bool = True
_is_constant_log_det: bool = True

def __init__(
self,
indices_or_sections: Union[int, tuple[int, ...], list[int]],
axis: int = -1,
):
"""Initializes a Split bijector.

**Arguments:**

- `indices_or_sections`: If an integer `N`, the array will be divided into
`N` equal arrays along axis. If a tuple/list of sorted integers, the entries
indicate where along axis the array is split.
- `axis`: The axis along which to split. Defaults to -1 (last axis).
"""
# Ensure lists are converted to tuples so they remain hashable for JAX JIT
if isinstance(indices_or_sections, list):
indices_or_sections = tuple(indices_or_sections)

self.indices_or_sections = indices_or_sections
self.axis = axis

def forward_and_log_det(self, x: Array) -> tuple[tuple[Array, ...], Array]:
"""Computes y = tuple(split(x)) and log|det J(f)(x)| = 0."""
y = tuple(jnp.split(x, self.indices_or_sections, axis=self.axis))
return y, jnp.zeros((), dtype=x.dtype)

def inverse_and_log_det(self, y: tuple[Array, ...]) -> tuple[Array, Array]:
"""Computes x = concatenate(y) and log|det J(f^{-1})(y)| = 0."""
x = jnp.concatenate(y, axis=self.axis)
dtype = y[0].dtype if y else jnp.float32
return x, jnp.zeros((), dtype=dtype)

def same_as(self, other: AbstractBijector) -> bool:
"""Returns True if this bijector is guaranteed to be the same as `other`."""
return (
type(other) is Split
and self.indices_or_sections == other.indices_or_sections
and self.axis == other.axis
)
7 changes: 7 additions & 0 deletions docs/api/bijectors/split.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Split Bijector

::: distreqx.bijectors.Split
options:
members:
- __init__
---
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ nav:
- 'api/bijectors/scalar_affine.md'
- 'api/bijectors/shift.md'
- 'api/bijectors/sigmoid.md'
- 'api/bijectors/split.md'
- 'api/bijectors/tanh.md'
- 'api/bijectors/triangular_linear.md'
- 'api/bijectors/_bijector.md'
Expand Down
101 changes: 101 additions & 0 deletions tests/split_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from unittest import TestCase

import jax

jax.config.update("jax_enable_x64", True)

import equinox as eqx
import jax.numpy as jnp
import numpy as np
from parameterized import parameterized # type: ignore

from distreqx.bijectors import Split


class SplitTest(TestCase):
def setUp(self):
# Bijector 1: Split into 3 equal sections
self.bij_equal = Split(indices_or_sections=3, axis=-1)

# Bijector 2: Split at specific indices
# (elements up to index 2, from 2 to 5, and from 5 onward)
self.bij_indices = Split(indices_or_sections=(2, 5), axis=-1)

def assertion_fn(self, rtol=1e-5):
return lambda x, y: np.testing.assert_allclose(x, y, rtol=rtol)

@parameterized.expand([("float32", jnp.float32), ("float64", jnp.float64)])
def test_forward_and_log_det_equal_sections(self, name, dtype):
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)
y, log_det = self.bij_equal.forward_and_log_det(x)

self.assertIsInstance(y, tuple)
self.assertEqual(len(y), 3)
self.assertion_fn()(y[0], jnp.array([1.0, 2.0], dtype=dtype))
self.assertion_fn()(y[1], jnp.array([3.0, 4.0], dtype=dtype))
self.assertion_fn()(y[2], jnp.array([5.0, 6.0], dtype=dtype))

self.assertEqual(log_det.shape, ())
self.assertEqual(log_det, 0.0)
self.assertEqual(log_det.dtype, dtype)

@parameterized.expand([("float32", jnp.float32), ("float64", jnp.float64)])
def test_forward_and_log_det_indices(self, name, dtype):
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=dtype)
y, log_det = self.bij_indices.forward_and_log_det(x)

self.assertIsInstance(y, tuple)
self.assertEqual(len(y), 3)
self.assertion_fn()(y[0], jnp.array([1.0, 2.0], dtype=dtype)) # x[:2]
self.assertion_fn()(y[1], jnp.array([3.0, 4.0, 5.0], dtype=dtype)) # x[2:5]
self.assertion_fn()(y[2], jnp.array([6.0, 7.0], dtype=dtype)) # x[5:]

@parameterized.expand([("float32", jnp.float32), ("float64", jnp.float64)])
def test_inverse_and_log_det(self, name, dtype):
y = (
jnp.array([1.0, 2.0], dtype=dtype),
jnp.array([3.0, 4.0], dtype=dtype),
jnp.array([5.0, 6.0], dtype=dtype),
)
x, log_det = self.bij_equal.inverse_and_log_det(y)

self.assertion_fn()(x, jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype))

self.assertEqual(log_det.shape, ())
self.assertEqual(log_det, 0.0)
self.assertEqual(x.dtype, dtype)
self.assertEqual(log_det.dtype, dtype)

def test_jittable(self):
@eqx.filter_jit
def f_forward(bij, x):
return bij.forward_and_log_det(x)

@eqx.filter_jit
def f_inverse(bij, y):
return bij.inverse_and_log_det(y)

x = jnp.ones((6,))
y, log_det_fwd = f_forward(self.bij_equal, x)

self.assertIsInstance(y, tuple)
self.assertIsInstance(log_det_fwd, jax.Array)

x_reconstructed, log_det_inv = f_inverse(self.bij_equal, y)
self.assertIsInstance(x_reconstructed, jax.Array)
self.assertIsInstance(log_det_inv, jax.Array)

def test_same_as(self):
same_bij = Split(indices_or_sections=3, axis=-1)
diff_bij_1 = Split(indices_or_sections=4, axis=-1)
diff_bij_2 = Split(indices_or_sections=3, axis=0)

self.assertTrue(self.bij_equal.same_as(same_bij))
self.assertFalse(self.bij_equal.same_as(diff_bij_1))
self.assertFalse(self.bij_equal.same_as(diff_bij_2))

def test_list_to_tuple_conversion(self):
# A list should be converted to a tuple upon initialization for hashability
bij_list = Split(indices_or_sections=[2, 5])
self.assertIsInstance(bij_list.indices_or_sections, tuple)
self.assertTrue(self.bij_indices.same_as(bij_list))
Loading