diff --git a/distreqx/bijectors/__init__.py b/distreqx/bijectors/__init__.py index 7805142..0377ba3 100644 --- a/distreqx/bijectors/__init__.py +++ b/distreqx/bijectors/__init__.py @@ -7,6 +7,7 @@ from ._block import Block as Block from ._chain import Chain as Chain from ._diag_linear import DiagLinear as DiagLinear +from ._leafwise import Leafwise as Leafwise from ._linear import AbstractLinearBijector as AbstractLinearBijector from ._scalar_affine import ScalarAffine as ScalarAffine from ._shift import Shift as Shift diff --git a/distreqx/bijectors/_leafwise.py b/distreqx/bijectors/_leafwise.py new file mode 100644 index 0000000..38eb466 --- /dev/null +++ b/distreqx/bijectors/_leafwise.py @@ -0,0 +1,120 @@ +"""TreeMap Bijector for applying a pytree of bijectors to a pytree of inputs.""" + +import functools + +import jax +import jax.numpy as jnp +from jaxtyping import PyTree + +from ._bijector import ( + AbstractBijector, + AbstractFwdLogDetJacBijector, + AbstractInvLogDetJacBijector, +) + + +def _is_bijector(node: PyTree) -> bool: + return isinstance(node, AbstractBijector) + + +class Leafwise(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict=True): + """Applies a pytree of bijectors to a pytree of inputs. + + This behaves analogously to TensorFlow Probability's `JointMap`. It allows + applying independent bijectors to a structured input (e.g., a tuple or dict + of arrays) and aggregates the log-determinants across the structure. + """ + + bijectors: PyTree[AbstractBijector] + _is_constant_jacobian: bool + _is_constant_log_det: bool + + def __init__(self, bijectors: PyTree[AbstractBijector]): + """Initializes a TreeMap bijector.""" + leaves = jax.tree_util.tree_leaves(bijectors, is_leaf=_is_bijector) + if not leaves: + raise ValueError("The pytree of bijectors cannot be empty.") + + self.bijectors = bijectors + + is_constant_jacobian = all(b.is_constant_jacobian for b in leaves) + is_constant_log_det = all(b.is_constant_log_det for b in leaves) + + if is_constant_log_det is None: + is_constant_log_det = is_constant_jacobian + if is_constant_jacobian and not is_constant_log_det: + raise ValueError( + "The Jacobian is said to be constant, but its " + "determinant is said not to be, which is impossible." + ) + self._is_constant_jacobian = is_constant_jacobian + self._is_constant_log_det = is_constant_log_det + + def forward(self, x: PyTree) -> PyTree: + """Computes y = f(x).""" + return jax.tree_util.tree_map( + lambda b, v: b.forward(v), self.bijectors, x, is_leaf=_is_bijector + ) + + def inverse(self, y: PyTree) -> PyTree: + """Computes x = f^{-1}(y).""" + return jax.tree_util.tree_map( + lambda b, v: b.inverse(v), self.bijectors, y, is_leaf=_is_bijector + ) + + def forward_and_log_det(self, x: PyTree) -> tuple[PyTree, PyTree]: + """Computes y = f(x) and sum of log|det J(f)(x)|.""" + ys_and_log_dets = jax.tree_util.tree_map( + lambda b, v: b.forward_and_log_det(v), + self.bijectors, + x, + is_leaf=_is_bijector, + ) + + y = jax.tree_util.tree_map( + lambda b, res: res[0], self.bijectors, ys_and_log_dets, is_leaf=_is_bijector + ) + log_dets = jax.tree_util.tree_map( + lambda b, res: res[1], self.bijectors, ys_and_log_dets, is_leaf=_is_bijector + ) + + log_det_leaves = jax.tree_util.tree_leaves(log_dets) + total_log_det = functools.reduce(jnp.add, log_det_leaves) + return y, total_log_det + + def inverse_and_log_det(self, y: PyTree) -> tuple[PyTree, PyTree]: + """Computes x = f^{-1}(y) and sum of log|det J(f^{-1})(y)|.""" + xs_and_log_dets = jax.tree_util.tree_map( + lambda b, v: b.inverse_and_log_det(v), + self.bijectors, + y, + is_leaf=_is_bijector, + ) + + x = jax.tree_util.tree_map( + lambda b, res: res[0], self.bijectors, xs_and_log_dets, is_leaf=_is_bijector + ) + log_dets = jax.tree_util.tree_map( + lambda b, res: res[1], self.bijectors, xs_and_log_dets, is_leaf=_is_bijector + ) + + log_det_leaves = jax.tree_util.tree_leaves(log_dets) + total_log_det = functools.reduce(jnp.add, log_det_leaves) + return x, total_log_det + + def same_as(self, other: AbstractBijector) -> bool: + """Returns True if this bijector is guaranteed to be the same as `other`.""" + if type(other) is Leafwise: + if jax.tree_util.tree_structure( + self.bijectors, is_leaf=_is_bijector + ) != jax.tree_util.tree_structure(other.bijectors, is_leaf=_is_bijector): + return False + + match_tree = jax.tree_util.tree_map( + lambda b1, b2: b1.same_as(b2), + self.bijectors, + other.bijectors, + is_leaf=_is_bijector, + ) + return all(jax.tree_util.tree_leaves(match_tree)) + return False diff --git a/docs/api/bijectors/leafwise.md b/docs/api/bijectors/leafwise.md new file mode 100644 index 0000000..a3e18f1 --- /dev/null +++ b/docs/api/bijectors/leafwise.md @@ -0,0 +1,7 @@ +# Leafwise Bijector + +::: distreqx.bijectors.Leafwise + options: + members: + - __init__ +--- \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index aa5cf4e..d5acb39 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -124,6 +124,7 @@ nav: - 'api/bijectors/shift.md' - 'api/bijectors/sigmoid.md' - 'api/bijectors/tanh.md' + - 'api/bijectors/tree_map.md' - 'api/bijectors/triangular_linear.md' - 'api/bijectors/_bijector.md' - Utilities: diff --git a/tests/leafwise_test.py b/tests/leafwise_test.py new file mode 100644 index 0000000..31788c7 --- /dev/null +++ b/tests/leafwise_test.py @@ -0,0 +1,191 @@ +"""Tests for `leafwise.py`.""" + +from unittest import TestCase + +import jax +import jax.numpy as jnp +import numpy as np +from parameterized import parameterized # type: ignore + +from distreqx.bijectors import Leafwise, Shift, Tanh +from distreqx.bijectors._bijector import AbstractBijector + + +def _is_bijector(node): + return isinstance(node, AbstractBijector) + + +class LeafwiseTest(TestCase): + def test_empty_tree_raises(self): + with self.assertRaisesRegex( + ValueError, "The pytree of bijectors cannot be empty" + ): + Leafwise({}) + with self.assertRaisesRegex( + ValueError, "The pytree of bijectors cannot be empty" + ): + Leafwise([]) + + def test_jacobian_is_constant_property(self): + # All bijectors have constant jacobians + const_bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Shift(jnp.ones((2,)))}) + self.assertTrue(const_bij.is_constant_jacobian) + self.assertTrue(const_bij.is_constant_log_det) + + # Mixed bijectors (Tanh does not have a constant jacobian) + mixed_bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()}) + self.assertFalse(mixed_bij.is_constant_jacobian) + self.assertFalse(mixed_bij.is_constant_log_det) + + @parameterized.expand( + [ + ( + "dict_tree", + {"a": Shift(jnp.array(1.0)), "b": Tanh()}, + {"a": jnp.array(0.0), "b": jnp.array(0.5)}, + ), + ( + "tuple_tree", + (Shift(jnp.array(2.0)), Tanh()), + (jnp.array(0.0), jnp.array(-0.5)), + ), + ( + "nested_tree", + {"a": (Shift(jnp.array(1.0)),), "b": Tanh()}, + {"a": (jnp.array(0.0),), "b": jnp.array(0.5)}, + ), + ] + ) + def test_forward_methods(self, name, bijectors, x): + tree_bij = Leafwise(bijectors) + + y1 = tree_bij.forward(x) + logdet1 = tree_bij.forward_log_det_jacobian(x) + y2, logdet2 = tree_bij.forward_and_log_det(x) + + # Verify tree structures match + self.assertEqual( + jax.tree_util.tree_structure(y1), jax.tree_util.tree_structure(x) + ) + self.assertEqual( + jax.tree_util.tree_structure(y2), jax.tree_util.tree_structure(x) + ) + + # Manually compute expected values via tree.map + expected_y = jax.tree.map( + lambda b, v: b.forward(v), bijectors, x, is_leaf=_is_bijector + ) + expected_logdets_tree = jax.tree.map( + lambda b, v: b.forward_log_det_jacobian(v), + bijectors, + x, + is_leaf=_is_bijector, + ) + expected_logdet = sum(jax.tree_util.tree_leaves(expected_logdets_tree)) + + # Assert shapes and values + jax.tree.map( + lambda res, exp: self.assertEqual(res.shape, exp.shape), y1, expected_y + ) + jax.tree.map( + lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), y1, expected_y + ) + jax.tree.map( + lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), y2, expected_y + ) + + np.testing.assert_allclose(logdet1, expected_logdet, 1e-6) + np.testing.assert_allclose(logdet2, expected_logdet, 1e-6) + + @parameterized.expand( + [ + ( + "dict_tree", + {"a": Shift(jnp.array(1.0)), "b": Tanh()}, + {"a": jnp.array(1.0), "b": jnp.array(0.2)}, + ), + ( + "tuple_tree", + (Shift(jnp.array(2.0)), Tanh()), + (jnp.array(2.0), jnp.array(-0.2)), + ), + ( + "nested_tree", + {"a": (Shift(jnp.array(1.0)),), "b": Tanh()}, + {"a": (jnp.array(1.0),), "b": jnp.array(0.2)}, + ), + ] + ) + def test_inverse_methods(self, name, bijectors, y): + tree_bij = Leafwise(bijectors) + + x1 = tree_bij.inverse(y) + logdet1 = tree_bij.inverse_log_det_jacobian(y) + x2, logdet2 = tree_bij.inverse_and_log_det(y) + + # Verify tree structures match + self.assertEqual( + jax.tree_util.tree_structure(x1), jax.tree_util.tree_structure(y) + ) + self.assertEqual( + jax.tree_util.tree_structure(x2), jax.tree_util.tree_structure(y) + ) + + # Manually compute expected values via tree_map + expected_x = jax.tree.map( + lambda b, v: b.inverse(v), bijectors, y, is_leaf=_is_bijector + ) + expected_logdets_tree = jax.tree.map( + lambda b, v: b.inverse_log_det_jacobian(v), + bijectors, + y, + is_leaf=_is_bijector, + ) + expected_logdet = sum(jax.tree_util.tree_leaves(expected_logdets_tree)) + + # Assert shapes and values + jax.tree.map( + lambda res, exp: self.assertEqual(res.shape, exp.shape), x1, expected_x + ) + jax.tree.map( + lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), x1, expected_x + ) + jax.tree.map( + lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), x2, expected_x + ) + + np.testing.assert_allclose(logdet1, expected_logdet, 1e-6) + np.testing.assert_allclose(logdet2, expected_logdet, 1e-6) + + def test_jittable(self): + @jax.jit + def f(x, b): + return b.forward(x) + + bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()}) + x = {"a": np.zeros((4,)), "b": np.zeros((4,))} + + z = f(x, bij) + + self.assertIsInstance(z, dict) + self.assertIsInstance(z["a"], jnp.ndarray) + self.assertIsInstance(z["b"], jnp.ndarray) + + def test_same_as_itself(self): + bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()}) + self.assertTrue(bij.same_as(bij)) + + def test_not_same_as_others(self): + bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()}) + + # Completely different bijector + other_type = Shift(jnp.zeros((4,))) + self.assertFalse(bij.same_as(other_type)) + + # Same structure, different bijector parameters + different_params = Leafwise({"a": Shift(jnp.zeros((4,))), "b": Tanh()}) + self.assertFalse(bij.same_as(different_params)) + + # Different structure + different_structure = Leafwise({"a": Shift(jnp.ones((4,)))}) + self.assertFalse(bij.same_as(different_structure))