From ed51cede9f54c31e84aaaa31ab7bf7236edbfd08 Mon Sep 17 00:00:00 2001 From: Gary Allen Date: Thu, 30 Apr 2026 18:15:12 +0200 Subject: [PATCH 1/5] Add TreeMap --- distreqx/bijectors/__init__.py | 1 + distreqx/bijectors/_treemap.py | 120 ++++++++++++++++++++ tests/treemap_test.py | 195 +++++++++++++++++++++++++++++++++ 3 files changed, 316 insertions(+) create mode 100644 distreqx/bijectors/_treemap.py create mode 100644 tests/treemap_test.py diff --git a/distreqx/bijectors/__init__.py b/distreqx/bijectors/__init__.py index 7805142..4b333da 100644 --- a/distreqx/bijectors/__init__.py +++ b/distreqx/bijectors/__init__.py @@ -12,5 +12,6 @@ from ._shift import Shift as Shift from ._sigmoid import Sigmoid as Sigmoid from ._tanh import Tanh as Tanh +from ._treemap import TreeMap as TreeMap from ._triangular_linear import TriangularLinear as TriangularLinear from ._unconstrained_affine import UnconstrainedAffine as UnconstrainedAffine diff --git a/distreqx/bijectors/_treemap.py b/distreqx/bijectors/_treemap.py new file mode 100644 index 0000000..0e8dab3 --- /dev/null +++ b/distreqx/bijectors/_treemap.py @@ -0,0 +1,120 @@ +"""TreeMap Bijector for applying a pytree of bijectors to a pytree of inputs.""" + +import functools + +import jax +import jax.numpy as jnp +from jaxtyping import PyTree + +from ._bijector import ( + AbstractBijector, + AbstractFwdLogDetJacBijector, + AbstractInvLogDetJacBijector, +) + + +def _is_bijector(node: PyTree) -> bool: + return isinstance(node, AbstractBijector) + + +class TreeMap(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict=True): + """Applies a pytree of bijectors to a pytree of inputs. + + This behaves analogously to TensorFlow Probability's `JointMap`. It allows + applying independent bijectors to a structured input (e.g., a tuple or dict + of arrays) and aggregates the log-determinants across the structure. + """ + + bijectors: PyTree + _is_constant_jacobian: bool + _is_constant_log_det: bool + + def __init__(self, bijectors: PyTree): + """Initializes a TreeMap bijector.""" + leaves = jax.tree_util.tree_leaves(bijectors, is_leaf=_is_bijector) + if not leaves: + raise ValueError("The pytree of bijectors cannot be empty.") + + self.bijectors = bijectors + + is_constant_jacobian = all(b.is_constant_jacobian for b in leaves) + is_constant_log_det = all(b.is_constant_log_det for b in leaves) + + if is_constant_log_det is None: + is_constant_log_det = is_constant_jacobian + if is_constant_jacobian and not is_constant_log_det: + raise ValueError( + "The Jacobian is said to be constant, but its " + "determinant is said not to be, which is impossible." + ) + self._is_constant_jacobian = is_constant_jacobian + self._is_constant_log_det = is_constant_log_det + + def forward(self, x: PyTree) -> PyTree: + """Computes y = f(x).""" + return jax.tree_util.tree_map( + lambda b, v: b.forward(v), self.bijectors, x, is_leaf=_is_bijector + ) + + def inverse(self, y: PyTree) -> PyTree: + """Computes x = f^{-1}(y).""" + return jax.tree_util.tree_map( + lambda b, v: b.inverse(v), self.bijectors, y, is_leaf=_is_bijector + ) + + def forward_and_log_det(self, x: PyTree) -> tuple[PyTree, PyTree]: + """Computes y = f(x) and sum of log|det J(f)(x)|.""" + ys_and_log_dets = jax.tree_util.tree_map( + lambda b, v: b.forward_and_log_det(v), + self.bijectors, + x, + is_leaf=_is_bijector, + ) + + y = jax.tree_util.tree_map( + lambda b, res: res[0], self.bijectors, ys_and_log_dets, is_leaf=_is_bijector + ) + log_dets = jax.tree_util.tree_map( + lambda b, res: res[1], self.bijectors, ys_and_log_dets, is_leaf=_is_bijector + ) + + log_det_leaves = jax.tree_util.tree_leaves(log_dets) + total_log_det = functools.reduce(jnp.add, log_det_leaves) + return y, total_log_det + + def inverse_and_log_det(self, y: PyTree) -> tuple[PyTree, PyTree]: + """Computes x = f^{-1}(y) and sum of log|det J(f^{-1})(y)|.""" + xs_and_log_dets = jax.tree_util.tree_map( + lambda b, v: b.inverse_and_log_det(v), + self.bijectors, + y, + is_leaf=_is_bijector, + ) + + x = jax.tree_util.tree_map( + lambda b, res: res[0], self.bijectors, xs_and_log_dets, is_leaf=_is_bijector + ) + log_dets = jax.tree_util.tree_map( + lambda b, res: res[1], self.bijectors, xs_and_log_dets, is_leaf=_is_bijector + ) + + log_det_leaves = jax.tree_util.tree_leaves(log_dets) + total_log_det = functools.reduce(jnp.add, log_det_leaves) + return x, total_log_det + + def same_as(self, other: AbstractBijector) -> bool: + """Returns True if this bijector is guaranteed to be the same as `other`.""" + if type(other) is TreeMap: + if jax.tree_util.tree_structure( + self.bijectors, is_leaf=_is_bijector + ) != jax.tree_util.tree_structure(other.bijectors, is_leaf=_is_bijector): + return False + + match_tree = jax.tree_util.tree_map( + lambda b1, b2: b1.same_as(b2), + self.bijectors, + other.bijectors, + is_leaf=_is_bijector, + ) + return all(jax.tree_util.tree_leaves(match_tree)) + return False diff --git a/tests/treemap_test.py b/tests/treemap_test.py new file mode 100644 index 0000000..3b180b5 --- /dev/null +++ b/tests/treemap_test.py @@ -0,0 +1,195 @@ +"""Tests for `tree_map.py`.""" + +from unittest import TestCase + +import jax +import jax.numpy as jnp +import numpy as np +from parameterized import parameterized # type: ignore + +from distreqx.bijectors import Shift, Tanh, TreeMap +from distreqx.bijectors._bijector import AbstractBijector + + +def _is_bijector(node): + return isinstance(node, AbstractBijector) + + +class TreeMapTest(TestCase): + def test_empty_tree_raises(self): + with self.assertRaisesRegex( + ValueError, "The pytree of bijectors cannot be empty" + ): + TreeMap({}) + with self.assertRaisesRegex( + ValueError, "The pytree of bijectors cannot be empty" + ): + TreeMap([]) + + def test_jacobian_is_constant_property(self): + # All bijectors have constant jacobians + const_bij = TreeMap({"a": Shift(jnp.ones((4,))), "b": Shift(jnp.ones((2,)))}) + self.assertTrue(const_bij.is_constant_jacobian) + self.assertTrue(const_bij.is_constant_log_det) + + # Mixed bijectors (Tanh does not have a constant jacobian) + mixed_bij = TreeMap({"a": Shift(jnp.ones((4,))), "b": Tanh()}) + self.assertFalse(mixed_bij.is_constant_jacobian) + self.assertFalse(mixed_bij.is_constant_log_det) + + @parameterized.expand( + [ + ( + "dict_tree", + {"a": Shift(jnp.array(1.0)), "b": Tanh()}, + {"a": jnp.array(0.0), "b": jnp.array(0.5)}, + ), + ( + "tuple_tree", + (Shift(jnp.array(2.0)), Tanh()), + (jnp.array(0.0), jnp.array(-0.5)), + ), + ( + "nested_tree", + {"a": (Shift(jnp.array(1.0)),), "b": Tanh()}, + {"a": (jnp.array(0.0),), "b": jnp.array(0.5)}, + ), + ] + ) + def test_forward_methods(self, name, bijectors, x): + tree_bij = TreeMap(bijectors) + + y1 = tree_bij.forward(x) + logdet1 = tree_bij.forward_log_det_jacobian(x) + y2, logdet2 = tree_bij.forward_and_log_det(x) + + # Verify tree structures match + self.assertEqual( + jax.tree_util.tree_structure(y1), jax.tree_util.tree_structure(x) + ) + self.assertEqual( + jax.tree_util.tree_structure(y2), jax.tree_util.tree_structure(x) + ) + + # Manually compute expected values via tree_map + # using is_leaf to stop at bijectors + expected_y = jax.tree_util.tree_map( + lambda b, v: b.forward(v), bijectors, x, is_leaf=_is_bijector + ) + expected_logdets_tree = jax.tree_util.tree_map( + lambda b, v: b.forward_log_det_jacobian(v), + bijectors, + x, + is_leaf=_is_bijector, + ) + expected_logdet = sum(jax.tree_util.tree_leaves(expected_logdets_tree)) + + # Assert shapes and values + jax.tree_util.tree_map( + lambda res, exp: self.assertEqual(res.shape, exp.shape), y1, expected_y + ) + jax.tree_util.tree_map( + lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), y1, expected_y + ) + jax.tree_util.tree_map( + lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), y2, expected_y + ) + + np.testing.assert_allclose(logdet1, expected_logdet, 1e-6) + np.testing.assert_allclose(logdet2, expected_logdet, 1e-6) + + @parameterized.expand( + [ + ( + "dict_tree", + {"a": Shift(jnp.array(1.0)), "b": Tanh()}, + {"a": jnp.array(1.0), "b": jnp.array(0.2)}, + ), + ( + "tuple_tree", + (Shift(jnp.array(2.0)), Tanh()), + (jnp.array(2.0), jnp.array(-0.2)), + ), + ( + "nested_tree", + {"a": (Shift(jnp.array(1.0)),), "b": Tanh()}, + {"a": (jnp.array(1.0),), "b": jnp.array(0.2)}, + ), + ] + ) + def test_inverse_methods(self, name, bijectors, y): + tree_bij = TreeMap(bijectors) + + x1 = tree_bij.inverse(y) + logdet1 = tree_bij.inverse_log_det_jacobian(y) + x2, logdet2 = tree_bij.inverse_and_log_det(y) + + # Verify tree structures match + self.assertEqual( + jax.tree_util.tree_structure(x1), jax.tree_util.tree_structure(y) + ) + self.assertEqual( + jax.tree_util.tree_structure(x2), jax.tree_util.tree_structure(y) + ) + + # Manually compute expected values via tree_map + # using is_leaf to stop at bijectors + expected_x = jax.tree_util.tree_map( + lambda b, v: b.inverse(v), bijectors, y, is_leaf=_is_bijector + ) + expected_logdets_tree = jax.tree_util.tree_map( + lambda b, v: b.inverse_log_det_jacobian(v), + bijectors, + y, + is_leaf=_is_bijector, + ) + expected_logdet = sum(jax.tree_util.tree_leaves(expected_logdets_tree)) + + # Assert shapes and values + jax.tree_util.tree_map( + lambda res, exp: self.assertEqual(res.shape, exp.shape), x1, expected_x + ) + jax.tree_util.tree_map( + lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), x1, expected_x + ) + jax.tree_util.tree_map( + lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), x2, expected_x + ) + + np.testing.assert_allclose(logdet1, expected_logdet, 1e-6) + np.testing.assert_allclose(logdet2, expected_logdet, 1e-6) + + def test_jittable(self): + @jax.jit + def f(x, b): + return b.forward(x) + + bij = TreeMap({"a": Shift(jnp.ones((4,))), "b": Tanh()}) + x = {"a": np.zeros((4,)), "b": np.zeros((4,))} + + z = f(x, bij) + + self.assertIsInstance(z, dict) + self.assertIsInstance(z["a"], jnp.ndarray) + self.assertIsInstance(z["b"], jnp.ndarray) + + def test_same_as_itself(self): + bij = TreeMap({"a": Shift(jnp.ones((4,))), "b": Tanh()}) + # Distreqx bijectors often evaluate False for different object + # instances, so checking the exact same instance is the correct test. + self.assertTrue(bij.same_as(bij)) + + def test_not_same_as_others(self): + bij = TreeMap({"a": Shift(jnp.ones((4,))), "b": Tanh()}) + + # Completely different bijector + other_type = Shift(jnp.zeros((4,))) + self.assertFalse(bij.same_as(other_type)) + + # Same structure, different bijector parameters + different_params = TreeMap({"a": Shift(jnp.zeros((4,))), "b": Tanh()}) + self.assertFalse(bij.same_as(different_params)) + + # Different structure + different_structure = TreeMap({"a": Shift(jnp.ones((4,)))}) + self.assertFalse(bij.same_as(different_structure)) From 927b4be65a10859fdaf706bd58fc35484150e3d2 Mon Sep 17 00:00:00 2001 From: Gary Allen Date: Thu, 30 Apr 2026 18:27:09 +0200 Subject: [PATCH 2/5] Updated docs --- distreqx/bijectors/_treemap.py | 4 ++-- docs/api/bijectors/tree_map.md | 7 +++++++ mkdocs.yml | 1 + 3 files changed, 10 insertions(+), 2 deletions(-) create mode 100644 docs/api/bijectors/tree_map.md diff --git a/distreqx/bijectors/_treemap.py b/distreqx/bijectors/_treemap.py index 0e8dab3..f3c4244 100644 --- a/distreqx/bijectors/_treemap.py +++ b/distreqx/bijectors/_treemap.py @@ -25,11 +25,11 @@ class TreeMap(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict of arrays) and aggregates the log-determinants across the structure. """ - bijectors: PyTree + bijectors: PyTree[AbstractBijector] _is_constant_jacobian: bool _is_constant_log_det: bool - def __init__(self, bijectors: PyTree): + def __init__(self, bijectors: PyTree[AbstractBijector]): """Initializes a TreeMap bijector.""" leaves = jax.tree_util.tree_leaves(bijectors, is_leaf=_is_bijector) if not leaves: diff --git a/docs/api/bijectors/tree_map.md b/docs/api/bijectors/tree_map.md new file mode 100644 index 0000000..8cc0e9f --- /dev/null +++ b/docs/api/bijectors/tree_map.md @@ -0,0 +1,7 @@ +# Tree Map Bijector + +::: distreqx.bijectors.TreeMap + options: + members: + - __init__ +--- \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index aa5cf4e..d5acb39 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -124,6 +124,7 @@ nav: - 'api/bijectors/shift.md' - 'api/bijectors/sigmoid.md' - 'api/bijectors/tanh.md' + - 'api/bijectors/tree_map.md' - 'api/bijectors/triangular_linear.md' - 'api/bijectors/_bijector.md' - Utilities: From 0b8614f4feacadd53741f4eb0b3e7f1e5bdc4a2f Mon Sep 17 00:00:00 2001 From: Gary Allen Date: Thu, 30 Apr 2026 22:13:55 +0200 Subject: [PATCH 3/5] Updated TreeMap to safely pass None to align with Equinox partitioning --- distreqx/bijectors/_treemap.py | 38 +++++++++++++++++++++++++++------- tests/treemap_test.py | 9 +++----- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/distreqx/bijectors/_treemap.py b/distreqx/bijectors/_treemap.py index f3c4244..54ddfce 100644 --- a/distreqx/bijectors/_treemap.py +++ b/distreqx/bijectors/_treemap.py @@ -23,6 +23,7 @@ class TreeMap(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict This behaves analogously to TensorFlow Probability's `JointMap`. It allows applying independent bijectors to a structured input (e.g., a tuple or dict of arrays) and aggregates the log-determinants across the structure. + `None` values in the bijector pytree act as identity transformations. """ bijectors: PyTree[AbstractBijector] @@ -31,9 +32,15 @@ class TreeMap(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict def __init__(self, bijectors: PyTree[AbstractBijector]): """Initializes a TreeMap bijector.""" - leaves = jax.tree_util.tree_leaves(bijectors, is_leaf=_is_bijector) + leaves = [ + b + for b in jax.tree_util.tree_leaves(bijectors, is_leaf=_is_bijector) + if b is not None + ] if not leaves: - raise ValueError("The pytree of bijectors cannot be empty.") + raise ValueError( + "The pytree of bijectors must contain at least one valid bijector." + ) self.bijectors = bijectors @@ -53,19 +60,27 @@ def __init__(self, bijectors: PyTree[AbstractBijector]): def forward(self, x: PyTree) -> PyTree: """Computes y = f(x).""" return jax.tree_util.tree_map( - lambda b, v: b.forward(v), self.bijectors, x, is_leaf=_is_bijector + lambda b, v: b.forward(v) if b is not None else v, + self.bijectors, + x, + is_leaf=_is_bijector, ) def inverse(self, y: PyTree) -> PyTree: """Computes x = f^{-1}(y).""" return jax.tree_util.tree_map( - lambda b, v: b.inverse(v), self.bijectors, y, is_leaf=_is_bijector + lambda b, v: b.inverse(v) if b is not None else v, + self.bijectors, + y, + is_leaf=_is_bijector, ) def forward_and_log_det(self, x: PyTree) -> tuple[PyTree, PyTree]: """Computes y = f(x) and sum of log|det J(f)(x)|.""" ys_and_log_dets = jax.tree_util.tree_map( - lambda b, v: b.forward_and_log_det(v), + lambda b, v: ( + b.forward_and_log_det(v) if b is not None else (v, jnp.array(0.0)) + ), self.bijectors, x, is_leaf=_is_bijector, @@ -85,7 +100,9 @@ def forward_and_log_det(self, x: PyTree) -> tuple[PyTree, PyTree]: def inverse_and_log_det(self, y: PyTree) -> tuple[PyTree, PyTree]: """Computes x = f^{-1}(y) and sum of log|det J(f^{-1})(y)|.""" xs_and_log_dets = jax.tree_util.tree_map( - lambda b, v: b.inverse_and_log_det(v), + lambda b, v: ( + b.inverse_and_log_det(v) if b is not None else (v, jnp.array(0.0)) + ), self.bijectors, y, is_leaf=_is_bijector, @@ -110,8 +127,15 @@ def same_as(self, other: AbstractBijector) -> bool: ) != jax.tree_util.tree_structure(other.bijectors, is_leaf=_is_bijector): return False + def _check_same(b1, b2): + if b1 is None and b2 is None: + return True + if b1 is None or b2 is None: + return False + return b1.same_as(b2) + match_tree = jax.tree_util.tree_map( - lambda b1, b2: b1.same_as(b2), + _check_same, self.bijectors, other.bijectors, is_leaf=_is_bijector, diff --git a/tests/treemap_test.py b/tests/treemap_test.py index 3b180b5..fa38e59 100644 --- a/tests/treemap_test.py +++ b/tests/treemap_test.py @@ -17,13 +17,10 @@ def _is_bijector(node): class TreeMapTest(TestCase): def test_empty_tree_raises(self): - with self.assertRaisesRegex( - ValueError, "The pytree of bijectors cannot be empty" - ): + msg = "must contain at least one valid bijector" + with self.assertRaisesRegex(ValueError, msg): TreeMap({}) - with self.assertRaisesRegex( - ValueError, "The pytree of bijectors cannot be empty" - ): + with self.assertRaisesRegex(ValueError, msg): TreeMap([]) def test_jacobian_is_constant_property(self): From 358e93e5157695883021db231a8996ce72281c51 Mon Sep 17 00:00:00 2001 From: Gary Allen Date: Thu, 30 Apr 2026 22:41:17 +0200 Subject: [PATCH 4/5] Revert "Updated TreeMap to safely pass None to align with Equinox partitioning" This reverts commit 0b8614f4feacadd53741f4eb0b3e7f1e5bdc4a2f. --- distreqx/bijectors/_treemap.py | 38 +++++++--------------------------- tests/treemap_test.py | 9 +++++--- 2 files changed, 13 insertions(+), 34 deletions(-) diff --git a/distreqx/bijectors/_treemap.py b/distreqx/bijectors/_treemap.py index 54ddfce..f3c4244 100644 --- a/distreqx/bijectors/_treemap.py +++ b/distreqx/bijectors/_treemap.py @@ -23,7 +23,6 @@ class TreeMap(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict This behaves analogously to TensorFlow Probability's `JointMap`. It allows applying independent bijectors to a structured input (e.g., a tuple or dict of arrays) and aggregates the log-determinants across the structure. - `None` values in the bijector pytree act as identity transformations. """ bijectors: PyTree[AbstractBijector] @@ -32,15 +31,9 @@ class TreeMap(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict def __init__(self, bijectors: PyTree[AbstractBijector]): """Initializes a TreeMap bijector.""" - leaves = [ - b - for b in jax.tree_util.tree_leaves(bijectors, is_leaf=_is_bijector) - if b is not None - ] + leaves = jax.tree_util.tree_leaves(bijectors, is_leaf=_is_bijector) if not leaves: - raise ValueError( - "The pytree of bijectors must contain at least one valid bijector." - ) + raise ValueError("The pytree of bijectors cannot be empty.") self.bijectors = bijectors @@ -60,27 +53,19 @@ def __init__(self, bijectors: PyTree[AbstractBijector]): def forward(self, x: PyTree) -> PyTree: """Computes y = f(x).""" return jax.tree_util.tree_map( - lambda b, v: b.forward(v) if b is not None else v, - self.bijectors, - x, - is_leaf=_is_bijector, + lambda b, v: b.forward(v), self.bijectors, x, is_leaf=_is_bijector ) def inverse(self, y: PyTree) -> PyTree: """Computes x = f^{-1}(y).""" return jax.tree_util.tree_map( - lambda b, v: b.inverse(v) if b is not None else v, - self.bijectors, - y, - is_leaf=_is_bijector, + lambda b, v: b.inverse(v), self.bijectors, y, is_leaf=_is_bijector ) def forward_and_log_det(self, x: PyTree) -> tuple[PyTree, PyTree]: """Computes y = f(x) and sum of log|det J(f)(x)|.""" ys_and_log_dets = jax.tree_util.tree_map( - lambda b, v: ( - b.forward_and_log_det(v) if b is not None else (v, jnp.array(0.0)) - ), + lambda b, v: b.forward_and_log_det(v), self.bijectors, x, is_leaf=_is_bijector, @@ -100,9 +85,7 @@ def forward_and_log_det(self, x: PyTree) -> tuple[PyTree, PyTree]: def inverse_and_log_det(self, y: PyTree) -> tuple[PyTree, PyTree]: """Computes x = f^{-1}(y) and sum of log|det J(f^{-1})(y)|.""" xs_and_log_dets = jax.tree_util.tree_map( - lambda b, v: ( - b.inverse_and_log_det(v) if b is not None else (v, jnp.array(0.0)) - ), + lambda b, v: b.inverse_and_log_det(v), self.bijectors, y, is_leaf=_is_bijector, @@ -127,15 +110,8 @@ def same_as(self, other: AbstractBijector) -> bool: ) != jax.tree_util.tree_structure(other.bijectors, is_leaf=_is_bijector): return False - def _check_same(b1, b2): - if b1 is None and b2 is None: - return True - if b1 is None or b2 is None: - return False - return b1.same_as(b2) - match_tree = jax.tree_util.tree_map( - _check_same, + lambda b1, b2: b1.same_as(b2), self.bijectors, other.bijectors, is_leaf=_is_bijector, diff --git a/tests/treemap_test.py b/tests/treemap_test.py index fa38e59..3b180b5 100644 --- a/tests/treemap_test.py +++ b/tests/treemap_test.py @@ -17,10 +17,13 @@ def _is_bijector(node): class TreeMapTest(TestCase): def test_empty_tree_raises(self): - msg = "must contain at least one valid bijector" - with self.assertRaisesRegex(ValueError, msg): + with self.assertRaisesRegex( + ValueError, "The pytree of bijectors cannot be empty" + ): TreeMap({}) - with self.assertRaisesRegex(ValueError, msg): + with self.assertRaisesRegex( + ValueError, "The pytree of bijectors cannot be empty" + ): TreeMap([]) def test_jacobian_is_constant_property(self): From 67913de74393948fe761f48234739c3ffc917a63 Mon Sep 17 00:00:00 2001 From: Gary Allen Date: Fri, 8 May 2026 11:33:54 +0200 Subject: [PATCH 5/5] Renamed TreeMap to Leafwise --- distreqx/bijectors/__init__.py | 2 +- .../bijectors/{_treemap.py => _leafwise.py} | 4 +- .../bijectors/{tree_map.md => leafwise.md} | 4 +- tests/{treemap_test.py => leafwise_test.py} | 54 +++++++++---------- 4 files changed, 30 insertions(+), 34 deletions(-) rename distreqx/bijectors/{_treemap.py => _leafwise.py} (97%) rename docs/api/bijectors/{tree_map.md => leafwise.md} (52%) rename tests/{treemap_test.py => leafwise_test.py} (79%) diff --git a/distreqx/bijectors/__init__.py b/distreqx/bijectors/__init__.py index 4b333da..0377ba3 100644 --- a/distreqx/bijectors/__init__.py +++ b/distreqx/bijectors/__init__.py @@ -7,11 +7,11 @@ from ._block import Block as Block from ._chain import Chain as Chain from ._diag_linear import DiagLinear as DiagLinear +from ._leafwise import Leafwise as Leafwise from ._linear import AbstractLinearBijector as AbstractLinearBijector from ._scalar_affine import ScalarAffine as ScalarAffine from ._shift import Shift as Shift from ._sigmoid import Sigmoid as Sigmoid from ._tanh import Tanh as Tanh -from ._treemap import TreeMap as TreeMap from ._triangular_linear import TriangularLinear as TriangularLinear from ._unconstrained_affine import UnconstrainedAffine as UnconstrainedAffine diff --git a/distreqx/bijectors/_treemap.py b/distreqx/bijectors/_leafwise.py similarity index 97% rename from distreqx/bijectors/_treemap.py rename to distreqx/bijectors/_leafwise.py index f3c4244..38eb466 100644 --- a/distreqx/bijectors/_treemap.py +++ b/distreqx/bijectors/_leafwise.py @@ -17,7 +17,7 @@ def _is_bijector(node: PyTree) -> bool: return isinstance(node, AbstractBijector) -class TreeMap(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict=True): +class Leafwise(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict=True): """Applies a pytree of bijectors to a pytree of inputs. This behaves analogously to TensorFlow Probability's `JointMap`. It allows @@ -104,7 +104,7 @@ def inverse_and_log_det(self, y: PyTree) -> tuple[PyTree, PyTree]: def same_as(self, other: AbstractBijector) -> bool: """Returns True if this bijector is guaranteed to be the same as `other`.""" - if type(other) is TreeMap: + if type(other) is Leafwise: if jax.tree_util.tree_structure( self.bijectors, is_leaf=_is_bijector ) != jax.tree_util.tree_structure(other.bijectors, is_leaf=_is_bijector): diff --git a/docs/api/bijectors/tree_map.md b/docs/api/bijectors/leafwise.md similarity index 52% rename from docs/api/bijectors/tree_map.md rename to docs/api/bijectors/leafwise.md index 8cc0e9f..a3e18f1 100644 --- a/docs/api/bijectors/tree_map.md +++ b/docs/api/bijectors/leafwise.md @@ -1,6 +1,6 @@ -# Tree Map Bijector +# Leafwise Bijector -::: distreqx.bijectors.TreeMap +::: distreqx.bijectors.Leafwise options: members: - __init__ diff --git a/tests/treemap_test.py b/tests/leafwise_test.py similarity index 79% rename from tests/treemap_test.py rename to tests/leafwise_test.py index 3b180b5..31788c7 100644 --- a/tests/treemap_test.py +++ b/tests/leafwise_test.py @@ -1,4 +1,4 @@ -"""Tests for `tree_map.py`.""" +"""Tests for `leafwise.py`.""" from unittest import TestCase @@ -7,7 +7,7 @@ import numpy as np from parameterized import parameterized # type: ignore -from distreqx.bijectors import Shift, Tanh, TreeMap +from distreqx.bijectors import Leafwise, Shift, Tanh from distreqx.bijectors._bijector import AbstractBijector @@ -15,25 +15,25 @@ def _is_bijector(node): return isinstance(node, AbstractBijector) -class TreeMapTest(TestCase): +class LeafwiseTest(TestCase): def test_empty_tree_raises(self): with self.assertRaisesRegex( ValueError, "The pytree of bijectors cannot be empty" ): - TreeMap({}) + Leafwise({}) with self.assertRaisesRegex( ValueError, "The pytree of bijectors cannot be empty" ): - TreeMap([]) + Leafwise([]) def test_jacobian_is_constant_property(self): # All bijectors have constant jacobians - const_bij = TreeMap({"a": Shift(jnp.ones((4,))), "b": Shift(jnp.ones((2,)))}) + const_bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Shift(jnp.ones((2,)))}) self.assertTrue(const_bij.is_constant_jacobian) self.assertTrue(const_bij.is_constant_log_det) # Mixed bijectors (Tanh does not have a constant jacobian) - mixed_bij = TreeMap({"a": Shift(jnp.ones((4,))), "b": Tanh()}) + mixed_bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()}) self.assertFalse(mixed_bij.is_constant_jacobian) self.assertFalse(mixed_bij.is_constant_log_det) @@ -57,7 +57,7 @@ def test_jacobian_is_constant_property(self): ] ) def test_forward_methods(self, name, bijectors, x): - tree_bij = TreeMap(bijectors) + tree_bij = Leafwise(bijectors) y1 = tree_bij.forward(x) logdet1 = tree_bij.forward_log_det_jacobian(x) @@ -71,12 +71,11 @@ def test_forward_methods(self, name, bijectors, x): jax.tree_util.tree_structure(y2), jax.tree_util.tree_structure(x) ) - # Manually compute expected values via tree_map - # using is_leaf to stop at bijectors - expected_y = jax.tree_util.tree_map( + # Manually compute expected values via tree.map + expected_y = jax.tree.map( lambda b, v: b.forward(v), bijectors, x, is_leaf=_is_bijector ) - expected_logdets_tree = jax.tree_util.tree_map( + expected_logdets_tree = jax.tree.map( lambda b, v: b.forward_log_det_jacobian(v), bijectors, x, @@ -85,13 +84,13 @@ def test_forward_methods(self, name, bijectors, x): expected_logdet = sum(jax.tree_util.tree_leaves(expected_logdets_tree)) # Assert shapes and values - jax.tree_util.tree_map( + jax.tree.map( lambda res, exp: self.assertEqual(res.shape, exp.shape), y1, expected_y ) - jax.tree_util.tree_map( + jax.tree.map( lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), y1, expected_y ) - jax.tree_util.tree_map( + jax.tree.map( lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), y2, expected_y ) @@ -118,7 +117,7 @@ def test_forward_methods(self, name, bijectors, x): ] ) def test_inverse_methods(self, name, bijectors, y): - tree_bij = TreeMap(bijectors) + tree_bij = Leafwise(bijectors) x1 = tree_bij.inverse(y) logdet1 = tree_bij.inverse_log_det_jacobian(y) @@ -133,11 +132,10 @@ def test_inverse_methods(self, name, bijectors, y): ) # Manually compute expected values via tree_map - # using is_leaf to stop at bijectors - expected_x = jax.tree_util.tree_map( + expected_x = jax.tree.map( lambda b, v: b.inverse(v), bijectors, y, is_leaf=_is_bijector ) - expected_logdets_tree = jax.tree_util.tree_map( + expected_logdets_tree = jax.tree.map( lambda b, v: b.inverse_log_det_jacobian(v), bijectors, y, @@ -146,13 +144,13 @@ def test_inverse_methods(self, name, bijectors, y): expected_logdet = sum(jax.tree_util.tree_leaves(expected_logdets_tree)) # Assert shapes and values - jax.tree_util.tree_map( + jax.tree.map( lambda res, exp: self.assertEqual(res.shape, exp.shape), x1, expected_x ) - jax.tree_util.tree_map( + jax.tree.map( lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), x1, expected_x ) - jax.tree_util.tree_map( + jax.tree.map( lambda res, exp: np.testing.assert_allclose(res, exp, 1e-6), x2, expected_x ) @@ -164,7 +162,7 @@ def test_jittable(self): def f(x, b): return b.forward(x) - bij = TreeMap({"a": Shift(jnp.ones((4,))), "b": Tanh()}) + bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()}) x = {"a": np.zeros((4,)), "b": np.zeros((4,))} z = f(x, bij) @@ -174,22 +172,20 @@ def f(x, b): self.assertIsInstance(z["b"], jnp.ndarray) def test_same_as_itself(self): - bij = TreeMap({"a": Shift(jnp.ones((4,))), "b": Tanh()}) - # Distreqx bijectors often evaluate False for different object - # instances, so checking the exact same instance is the correct test. + bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()}) self.assertTrue(bij.same_as(bij)) def test_not_same_as_others(self): - bij = TreeMap({"a": Shift(jnp.ones((4,))), "b": Tanh()}) + bij = Leafwise({"a": Shift(jnp.ones((4,))), "b": Tanh()}) # Completely different bijector other_type = Shift(jnp.zeros((4,))) self.assertFalse(bij.same_as(other_type)) # Same structure, different bijector parameters - different_params = TreeMap({"a": Shift(jnp.zeros((4,))), "b": Tanh()}) + different_params = Leafwise({"a": Shift(jnp.zeros((4,))), "b": Tanh()}) self.assertFalse(bij.same_as(different_params)) # Different structure - different_structure = TreeMap({"a": Shift(jnp.ones((4,)))}) + different_structure = Leafwise({"a": Shift(jnp.ones((4,)))}) self.assertFalse(bij.same_as(different_structure))