From b981d098ef3828e94880f8e10233921a810e45f5 Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 28 May 2026 23:47:35 -0400 Subject: [PATCH 1/2] feat(aggregation): Add MoDoWeighting --- CHANGELOG.md | 6 + docs/source/docs/aggregation/index.rst | 1 + docs/source/docs/aggregation/modo.rst | 7 ++ src/torchjd/aggregation/__init__.py | 2 + src/torchjd/aggregation/_modo.py | 143 +++++++++++++++++++++++ tests/unit/aggregation/test_modo.py | 153 +++++++++++++++++++++++++ 6 files changed, 312 insertions(+) create mode 100644 docs/source/docs/aggregation/modo.rst create mode 100644 src/torchjd/aggregation/_modo.py create mode 100644 tests/unit/aggregation/test_modo.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9204433a..7ca370c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Added + +- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a + softmax-projected gradient step on the Gramian, intended to be composed with `autogram.Engine` + in a two-batch training loop. + ## [0.12.0] - 2026-05-28 ### Added diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 13e405cb..66d74570 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -41,6 +41,7 @@ Abstract base classes krum.rst mean.rst mgda.rst + modo.rst nash_mtl.rst pcgrad.rst random.rst diff --git a/docs/source/docs/aggregation/modo.rst b/docs/source/docs/aggregation/modo.rst new file mode 100644 index 00000000..98b8d515 --- /dev/null +++ b/docs/source/docs/aggregation/modo.rst @@ -0,0 +1,7 @@ +:hide-toc: + +MoDo +==== + +.. 
autoclass:: torchjd.aggregation.MoDoWeighting + :members: __call__, reset diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 1814d320..92bbadec 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -53,6 +53,7 @@ from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting from ._mixins import Stateful +from ._modo import MoDoWeighting from ._nash_mtl import NashMTL from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting @@ -87,6 +88,7 @@ "MeanWeighting", "MGDA", "MGDAWeighting", + "MoDoWeighting", "NashMTL", "PCGrad", "PCGradWeighting", diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py new file mode 100644 index 00000000..d24629b3 --- /dev/null +++ b/src/torchjd/aggregation/_modo.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor + +from torchjd.aggregation._mixins import Stateful +from torchjd.linalg import PSDMatrix + +from ._weighting_bases import _GramianWeighting + + +class MoDoWeighting(_GramianWeighting, Stateful): + r""" + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] implementing the + task-weight update from `Three-Way Trade-Off in Multi-Objective Learning: Optimization, + Generalization and Conflict-Avoidance `_ + (JMLR 2024), commonly referred to as MoDo (Multi-Objective gradient with Double sampling). + + At each call, the weights :math:`\lambda` are updated by a projected gradient step on + :math:`\lambda^\top G \lambda + \rho \|\lambda\|^2` where :math:`G = G_1 G_1^\top` is the + Gramian of the first mini-batch's Jacobian: + + .. 
math:: + + \lambda_{t+1} = \operatorname{softmax}\!\bigl( + \lambda_t - \gamma \cdot (G \lambda_t + \rho \lambda_t) + \bigr) + + The paper specifies hard simplex projection :math:`\Pi_\Delta`; we follow the `official + LibMTL implementation `_ and use + :func:`torch.softmax` as the projection step. + + The state :math:`\lambda_{t-1}` is initialised lazily to the uniform vector + :math:`[1/m, \ldots, 1/m]` on the first forward call once :math:`m` is known, and is reset + automatically when :math:`m`, ``dtype`` or ``device`` of the input Gramian changes. Use + :meth:`reset` to manually restart the smoothing from uniform weights. + + .. warning:: + MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this + weighting must come from a mini-batch that is independent of the one used for the + subsequent parameter update. See the usage example below. + + :param gamma: Learning rate of the task-weight update. Must be positive. + :param rho: Non-negative :math:`\ell_2` regularisation coefficient. + + .. admonition:: Example + + Train a model using MoDo with two independent mini-batches per step. The first batch + drives the :math:`\lambda` update via the Gramian; the second batch drives the parameter + update via the usual backward pass. + + .. code-block:: python + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import MoDoWeighting + from torchjd.autogram import Engine + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = MoDoWeighting(gamma=0.1, rho=0.0) + engine = Engine(model, batch_dim=0) + + # loader_1 and loader_2 must yield independent draws from the same distribution. + for batch_1, batch_2 in zip(loader_1, loader_2): + input_1, target_1 = batch_1 + input_2, target_2 = batch_2 + + # Step 1: Gramian from batch 1 drives the lambda update. 
+ losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + gramian = engine.compute_gramian(losses_1) + weights = weighting(gramian) + + # Step 2: backward on batch 2 with those weights drives the parameter update. + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + losses_2.backward(weights) + optimizer.step() + optimizer.zero_grad() + """ + + def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None: + super().__init__() + self.gamma = gamma + self.rho = rho + self._lambda: Tensor | None = None + self._state_key: tuple[int, torch.dtype, torch.device] | None = None + + @property + def gamma(self) -> float: + return self._gamma + + @gamma.setter + def gamma(self, value: float) -> None: + if value <= 0.0: + raise ValueError(f"Attribute `gamma` must be positive. Found gamma={value!r}.") + self._gamma = value + + @property + def rho(self) -> float: + return self._rho + + @rho.setter + def rho(self, value: float) -> None: + if value < 0.0: + raise ValueError(f"Attribute `rho` must be non-negative. 
Found rho={value!r}.") + self._rho = value + + def reset(self) -> None: + """Clears the stored task weights so the next forward starts from uniform.""" + + self._lambda = None + self._state_key = None + + def forward(self, gramian: PSDMatrix, /) -> Tensor: + m = gramian.shape[0] + if m == 0: + return gramian.new_empty((0,)) + + self._ensure_state(gramian) + lambd = cast(Tensor, self._lambda) + + with torch.no_grad(): + grad = gramian @ lambd + self._rho * lambd + lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) + + self._lambda = lambd + return lambd + + def _ensure_state(self, gramian: PSDMatrix) -> None: + key = (gramian.shape[0], gramian.dtype, gramian.device) + if self._state_key == key and self._lambda is not None: + return + self._lambda = gramian.new_full((gramian.shape[0],), 1.0 / gramian.shape[0]) + self._state_key = key + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(gamma={self.gamma!r}, rho={self.rho!r})" diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py new file mode 100644 index 00000000..af80fbab --- /dev/null +++ b/tests/unit/aggregation/test_modo.py @@ -0,0 +1,153 @@ +import torch +from pytest import mark, raises +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import randn_, tensor_ + +from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator +from torchjd.aggregation._modo import MoDoWeighting + +from ._asserts import assert_expected_structure +from ._inputs import scaled_matrices, typical_matrices + +gramian_pairs = [ + (GramianWeightedAggregator(MoDoWeighting()), m) for m in typical_matrices + scaled_matrices +] + + +def test_representations() -> None: + W = MoDoWeighting(gamma=0.1, rho=0.05) + assert repr(W) == "MoDoWeighting(gamma=0.1, rho=0.05)" + + +@mark.parametrize(["aggregator", "matrix"], gramian_pairs) +def test_expected_structure_gramian_weighting( + aggregator: GramianWeightedAggregator, matrix: Tensor +) -> None: 
+ assert_expected_structure(aggregator, matrix) + + +def test_reset_restores_first_step_behavior() -> None: + J = randn_((3, 8)) + G = J @ J.T + W = MoDoWeighting(gamma=0.1) + first = W(G) + W(G) + W.reset() + assert_close(first, W(G)) + + +def test_gamma_setter_accepts_valid() -> None: + W = MoDoWeighting() + W.gamma = 0.01 + assert W.gamma == 0.01 + W.gamma = 0.1 + assert W.gamma == 0.1 + W.gamma = 1.0 + assert W.gamma == 1.0 + + +def test_gamma_setter_rejects_non_positive() -> None: + W = MoDoWeighting() + with raises(ValueError, match="gamma"): + W.gamma = 0.0 + with raises(ValueError, match="gamma"): + W.gamma = -0.1 + + +def test_rho_setter_accepts_valid() -> None: + W = MoDoWeighting() + W.rho = 0.0 + assert W.rho == 0.0 + W.rho = 0.1 + assert W.rho == 0.1 + + +def test_rho_setter_rejects_negative() -> None: + W = MoDoWeighting() + with raises(ValueError, match="rho"): + W.rho = -0.1 + + +def test_output_lies_on_simplex() -> None: + """The softmax projection ensures the weights sum to 1 and are non-negative.""" + + J = randn_((4, 10)) + G = J @ J.T + W = MoDoWeighting(gamma=0.1, rho=0.05) + weights = W(G) + assert weights.shape == (4,) + assert (weights >= 0).all() + assert_close(weights.sum(), tensor_(1.0)) + + +def test_small_gamma_stays_near_uniform() -> None: + """With a tiny gamma, one step barely moves lambda from the uniform initialisation.""" + + J = randn_((3, 8)) + G = J @ J.T + m = J.shape[0] + W = MoDoWeighting(gamma=1e-8) + uniform = tensor_([1.0 / m] * m) + assert_close(W(G), uniform, atol=1e-6, rtol=1e-6) + + +def test_update_recurrence() -> None: + """Verify one step of the softmax-projected gradient update by hand.""" + + gamma = 0.1 + rho = 0.05 + J = randn_((3, 8)) + G = J @ J.T + m = J.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + lambda_0 = tensor_([1.0 / m] * m) + grad = G @ lambda_0 + rho * lambda_0 + expected = torch.softmax(lambda_0 - gamma * grad, dim=-1) + + assert_close(W(G), expected) + + +def 
test_two_consecutive_steps() -> None: + """Verify two consecutive steps of the softmax-projected gradient update.""" + + gamma = 0.1 + rho = 0.0 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G1 = J1 @ J1.T + G2 = J2 @ J2.T + m = J1.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + + lambda_0 = tensor_([1.0 / m] * m) + grad_1 = G1 @ lambda_0 + rho * lambda_0 + lambda_1 = torch.softmax(lambda_0 - gamma * grad_1, dim=-1) + + grad_2 = G2 @ lambda_1 + rho * lambda_1 + lambda_2 = torch.softmax(lambda_1 - gamma * grad_2, dim=-1) + + assert_close(W(G1), lambda_1) + assert_close(W(G2), lambda_2) + + +def test_changing_m_auto_resets() -> None: + """When the number of objectives changes, the state is re-initialised to uniform.""" + + W = MoDoWeighting(gamma=0.1) + W(randn_((3, 8)) @ randn_((3, 8)).T) + # After a state-resetting call with m=2, the first output should equal the uniform step's output. + fresh = MoDoWeighting(gamma=0.1) + J = randn_((2, 8)) + G = J @ J.T + assert_close(W(G), fresh(G)) + + +def test_zero_rows() -> None: + """A (0, 0) Gramian yields an empty weight vector.""" + + W = MoDoWeighting() + weights = W(tensor_([]).reshape(0, 0)) + assert weights.shape == (0,) From b416fbadffeb90f33761c32faa0598a4309e5731 Mon Sep 17 00:00:00 2001 From: Khush Date: Fri, 29 May 2026 10:29:45 -0400 Subject: [PATCH 2/2] refactor(aggregation): Address review feedback on MoDoWeighting --- src/torchjd/aggregation/_modo.py | 31 ++++++++++++----------------- tests/unit/aggregation/test_modo.py | 8 -------- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py index d24629b3..d219b44e 100644 --- a/src/torchjd/aggregation/_modo.py +++ b/src/torchjd/aggregation/_modo.py @@ -5,23 +5,22 @@ import torch from torch import Tensor -from torchjd.aggregation._mixins import Stateful +from torchjd.aggregation._mixins import Stateful, _NonDifferentiable from torchjd.linalg import PSDMatrix from ._weighting_bases 
import _GramianWeighting -class MoDoWeighting(_GramianWeighting, Stateful): +class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable): r""" :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] implementing the - task-weight update from `Three-Way Trade-Off in Multi-Objective Learning: Optimization, - Generalization and Conflict-Avoidance `_ - (JMLR 2024), commonly referred to as MoDo (Multi-Objective gradient with Double sampling). + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] from `Three-Way + Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance + <https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf>`_ (JMLR 2024), commonly referred + to as MoDo (Multi-Objective gradient with Double sampling). - At each call, the weights :math:`\lambda` are updated by a projected gradient step on - :math:`\lambda^\top G \lambda + \rho \|\lambda\|^2` where :math:`G = G_1 G_1^\top` is the - Gramian of the first mini-batch's Jacobian: + Given a Gramian :math:`G`, the weights :math:`\lambda` are updated at each call by a + softmax-projected gradient step: .. math:: @@ -36,12 +35,13 @@ class MoDoWeighting(_GramianWeighting, Stateful): The state :math:`\lambda_{t-1}` is initialised lazily to the uniform vector :math:`[1/m, \ldots, 1/m]` on the first forward call once :math:`m` is known, and is reset automatically when :math:`m`, ``dtype`` or ``device`` of the input Gramian changes. Use - :meth:`reset` to manually restart the smoothing from uniform weights. + :meth:`reset` to manually restart from uniform weights. .. warning:: MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this weighting must come from a mini-batch that is independent of the one used for the - subsequent parameter update. 
The Gramian can be computed efficiently from a batch of + losses using the :class:`~torchjd.autogram.Engine`. See the usage example below. :param gamma: Learning rate of the task-weight update. Must be positive. :param rho: Non-negative :math:`\ell_2` regularisation coefficient. @@ -118,16 +118,11 @@ def reset(self) -> None: self._state_key = None def forward(self, gramian: PSDMatrix, /) -> Tensor: - m = gramian.shape[0] - if m == 0: - return gramian.new_empty((0,)) - self._ensure_state(gramian) lambd = cast(Tensor, self._lambda) - with torch.no_grad(): - grad = gramian @ lambd + self._rho * lambd - lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) + grad = gramian @ lambd + self._rho * lambd + lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) self._lambda = lambd return lambd diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py index af80fbab..9b9193be 100644 --- a/tests/unit/aggregation/test_modo.py +++ b/tests/unit/aggregation/test_modo.py @@ -143,11 +143,3 @@ def test_changing_m_auto_resets() -> None: J = randn_((2, 8)) G = J @ J.T assert_close(W(G), fresh(G)) - - -def test_zero_rows() -> None: - """A (0, 0) Gramian yields an empty weight vector.""" - - W = MoDoWeighting() - weights = W(tensor_([]).reshape(0, 0)) - assert weights.shape == (0,)