Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions distreqx/bijectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ._block import Block as Block
from ._chain import Chain as Chain
from ._diag_linear import DiagLinear as DiagLinear
from ._leafwise import Leafwise as Leafwise
from ._linear import AbstractLinearBijector as AbstractLinearBijector
from ._scalar_affine import ScalarAffine as ScalarAffine
from ._shift import Shift as Shift
Expand Down
120 changes: 120 additions & 0 deletions distreqx/bijectors/_leafwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""TreeMap Bijector for applying a pytree of bijectors to a pytree of inputs."""

import functools

import jax
import jax.numpy as jnp
from jaxtyping import PyTree

from ._bijector import (
AbstractBijector,
AbstractFwdLogDetJacBijector,
AbstractInvLogDetJacBijector,
)


def _is_bijector(node: PyTree) -> bool:
return isinstance(node, AbstractBijector)


class Leafwise(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict=True):
"""Applies a pytree of bijectors to a pytree of inputs.

This behaves analogously to TensorFlow Probability's `JointMap`. It allows
applying independent bijectors to a structured input (e.g., a tuple or dict
of arrays) and aggregates the log-determinants across the structure.
"""

bijectors: PyTree[AbstractBijector]
_is_constant_jacobian: bool
_is_constant_log_det: bool

def __init__(self, bijectors: PyTree[AbstractBijector]):
"""Initializes a TreeMap bijector."""
leaves = jax.tree_util.tree_leaves(bijectors, is_leaf=_is_bijector)
if not leaves:
raise ValueError("The pytree of bijectors cannot be empty.")

self.bijectors = bijectors

is_constant_jacobian = all(b.is_constant_jacobian for b in leaves)
is_constant_log_det = all(b.is_constant_log_det for b in leaves)

if is_constant_log_det is None:
is_constant_log_det = is_constant_jacobian
if is_constant_jacobian and not is_constant_log_det:
raise ValueError(
"The Jacobian is said to be constant, but its "
"determinant is said not to be, which is impossible."
)
self._is_constant_jacobian = is_constant_jacobian
self._is_constant_log_det = is_constant_log_det

def forward(self, x: PyTree) -> PyTree:
"""Computes y = f(x)."""
return jax.tree_util.tree_map(
lambda b, v: b.forward(v), self.bijectors, x, is_leaf=_is_bijector
)

def inverse(self, y: PyTree) -> PyTree:
"""Computes x = f^{-1}(y)."""
return jax.tree_util.tree_map(
lambda b, v: b.inverse(v), self.bijectors, y, is_leaf=_is_bijector
)

def forward_and_log_det(self, x: PyTree) -> tuple[PyTree, PyTree]:
"""Computes y = f(x) and sum of log|det J(f)(x)|."""
ys_and_log_dets = jax.tree_util.tree_map(
lambda b, v: b.forward_and_log_det(v),
self.bijectors,
x,
is_leaf=_is_bijector,
)

y = jax.tree_util.tree_map(
lambda b, res: res[0], self.bijectors, ys_and_log_dets, is_leaf=_is_bijector
)
log_dets = jax.tree_util.tree_map(
lambda b, res: res[1], self.bijectors, ys_and_log_dets, is_leaf=_is_bijector
)

log_det_leaves = jax.tree_util.tree_leaves(log_dets)
total_log_det = functools.reduce(jnp.add, log_det_leaves)
return y, total_log_det

def inverse_and_log_det(self, y: PyTree) -> tuple[PyTree, PyTree]:
"""Computes x = f^{-1}(y) and sum of log|det J(f^{-1})(y)|."""
xs_and_log_dets = jax.tree_util.tree_map(
lambda b, v: b.inverse_and_log_det(v),
self.bijectors,
y,
is_leaf=_is_bijector,
)

x = jax.tree_util.tree_map(
lambda b, res: res[0], self.bijectors, xs_and_log_dets, is_leaf=_is_bijector
)
log_dets = jax.tree_util.tree_map(
lambda b, res: res[1], self.bijectors, xs_and_log_dets, is_leaf=_is_bijector
)

log_det_leaves = jax.tree_util.tree_leaves(log_dets)
total_log_det = functools.reduce(jnp.add, log_det_leaves)
return x, total_log_det

def same_as(self, other: AbstractBijector) -> bool:
"""Returns True if this bijector is guaranteed to be the same as `other`."""
if type(other) is Leafwise:
if jax.tree_util.tree_structure(
self.bijectors, is_leaf=_is_bijector
) != jax.tree_util.tree_structure(other.bijectors, is_leaf=_is_bijector):
return False

match_tree = jax.tree_util.tree_map(
lambda b1, b2: b1.same_as(b2),
self.bijectors,
other.bijectors,
is_leaf=_is_bijector,
)
return all(jax.tree_util.tree_leaves(match_tree))
return False
7 changes: 7 additions & 0 deletions docs/api/bijectors/leafwise.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Leafwise Bijector

::: distreqx.bijectors.Leafwise
options:
members:
- __init__
---
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ nav:
- 'api/bijectors/shift.md'
- 'api/bijectors/sigmoid.md'
- 'api/bijectors/tanh.md'
- 'api/bijectors/tree_map.md'
- 'api/bijectors/triangular_linear.md'
- 'api/bijectors/_bijector.md'
- Utilities:
Expand Down
191 changes: 191 additions & 0 deletions tests/leafwise_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Tests for `leafwise.py`."""

from unittest import TestCase

import jax
import jax.numpy as jnp
import numpy as np
from parameterized import parameterized # type: ignore

from distreqx.bijectors import Leafwise, Shift, Tanh
from distreqx.bijectors._bijector import AbstractBijector


def _is_bijector(node):
return isinstance(node, AbstractBijector)


class LeafwiseTest(TestCase):
def test_empty_tree_raises(self):
with self.assertRaisesRegex(
ValueError, "The pytree of bijectors cannot be empty"
):
Leafwise({})
with self.assertRaisesRegex(
ValueError, "The pytree of bijectors cannot be empty"
):
Leafwise([])

def test_jacobian_is_constant_property(self):
# All bijectors have constant jacobians
const_bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Shift(jnp.ones((2,)))})
self.assertTrue(const_bij.is_constant_jacobian)
self.assertTrue(const_bij.is_constant_log_det)

# Mixed bijectors (Tanh does not have a constant jacobian)
mixed_bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()})
self.assertFalse(mixed_bij.is_constant_jacobian)
self.assertFalse(mixed_bij.is_constant_log_det)

@parameterized.expand(
[
(
"dict_tree",
{"a": Shift(jnp.array(1.0)), "b": Tanh()},
{"a": jnp.array(0.0), "b": jnp.array(0.5)},
),
(
"tuple_tree",
(Shift(jnp.array(2.0)), Tanh()),
(jnp.array(0.0), jnp.array(-0.5)),
),
(
"nested_tree",
{"a": (Shift(jnp.array(1.0)),), "b": Tanh()},
{"a": (jnp.array(0.0),), "b": jnp.array(0.5)},
),
]
)
def test_forward_methods(self, name, bijectors, x):
tree_bij = Leafwise(bijectors)

y1 = tree_bij.forward(x)
logdet1 = tree_bij.forward_log_det_jacobian(x)
y2, logdet2 = tree_bij.forward_and_log_det(x)

# Verify tree structures match
self.assertEqual(
jax.tree_util.tree_structure(y1), jax.tree_util.tree_structure(x)
)
self.assertEqual(
jax.tree_util.tree_structure(y2), jax.tree_util.tree_structure(x)
)

# Manually compute expected values via tree.map
expected_y = jax.tree.map(
lambda b, v: b.forward(v), bijectors, x, is_leaf=_is_bijector
)
expected_logdets_tree = jax.tree.map(
lambda b, v: b.forward_log_det_jacobian(v),
bijectors,
x,
is_leaf=_is_bijector,
)
expected_logdet = sum(jax.tree_util.tree_leaves(expected_logdets_tree))

# Assert shapes and values
jax.tree.map(
lambda res, exp: self.assertEqual(res.shape, exp.shape), y1, expected_y
)
jax.tree.map(
lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), y1, expected_y
)
jax.tree.map(
lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), y2, expected_y
)

np.testing.assert_allclose(logdet1, expected_logdet, 1e-6)
np.testing.assert_allclose(logdet2, expected_logdet, 1e-6)

@parameterized.expand(
[
(
"dict_tree",
{"a": Shift(jnp.array(1.0)), "b": Tanh()},
{"a": jnp.array(1.0), "b": jnp.array(0.2)},
),
(
"tuple_tree",
(Shift(jnp.array(2.0)), Tanh()),
(jnp.array(2.0), jnp.array(-0.2)),
),
(
"nested_tree",
{"a": (Shift(jnp.array(1.0)),), "b": Tanh()},
{"a": (jnp.array(1.0),), "b": jnp.array(0.2)},
),
]
)
def test_inverse_methods(self, name, bijectors, y):
tree_bij = Leafwise(bijectors)

x1 = tree_bij.inverse(y)
logdet1 = tree_bij.inverse_log_det_jacobian(y)
x2, logdet2 = tree_bij.inverse_and_log_det(y)

# Verify tree structures match
self.assertEqual(
jax.tree_util.tree_structure(x1), jax.tree_util.tree_structure(y)
)
self.assertEqual(
jax.tree_util.tree_structure(x2), jax.tree_util.tree_structure(y)
)

# Manually compute expected values via tree_map
expected_x = jax.tree.map(
lambda b, v: b.inverse(v), bijectors, y, is_leaf=_is_bijector
)
expected_logdets_tree = jax.tree.map(
lambda b, v: b.inverse_log_det_jacobian(v),
bijectors,
y,
is_leaf=_is_bijector,
)
expected_logdet = sum(jax.tree_util.tree_leaves(expected_logdets_tree))

# Assert shapes and values
jax.tree.map(
lambda res, exp: self.assertEqual(res.shape, exp.shape), x1, expected_x
)
jax.tree.map(
lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), x1, expected_x
)
jax.tree.map(
lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), x2, expected_x
)

np.testing.assert_allclose(logdet1, expected_logdet, 1e-6)
np.testing.assert_allclose(logdet2, expected_logdet, 1e-6)

def test_jittable(self):
@jax.jit
def f(x, b):
return b.forward(x)

bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()})
x = {"a": np.zeros((4,)), "b": np.zeros((4,))}

z = f(x, bij)

self.assertIsInstance(z, dict)
self.assertIsInstance(z["a"], jnp.ndarray)
self.assertIsInstance(z["b"], jnp.ndarray)

def test_same_as_itself(self):
bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()})
self.assertTrue(bij.same_as(bij))

def test_not_same_as_others(self):
bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()})

# Completely different bijector
other_type = Shift(jnp.zeros((4,)))
self.assertFalse(bij.same_as(other_type))

# Same structure, different bijector parameters
different_params = Leafwise({"a": Shift(jnp.zeros((4,))), "b": Tanh()})
self.assertFalse(bij.same_as(different_params))

# Different structure
different_structure = Leafwise({"a": Shift(jnp.ones((4,)))})
self.assertFalse(bij.same_as(different_structure))
Loading