diff --git a/CHANGELOG.md b/CHANGELOG.md index ab5a536d..225e65d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization,Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a + softmax-projected gradient step on the Gramian, intended to be composed with `autogram.Engine` + in a two-batch training loop. - Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf), diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 13e405cb..66d74570 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -41,6 +41,7 @@ Abstract base classes krum.rst mean.rst mgda.rst + modo.rst nash_mtl.rst pcgrad.rst random.rst diff --git a/docs/source/docs/aggregation/modo.rst b/docs/source/docs/aggregation/modo.rst new file mode 100644 index 00000000..98b8d515 --- /dev/null +++ b/docs/source/docs/aggregation/modo.rst @@ -0,0 +1,7 @@ +:hide-toc: + +MoDo +==== + +.. autoclass:: torchjd.aggregation.MoDoWeighting + :members: __call__, reset diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 1814d320..92bbadec 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -53,6 +53,7 @@ from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting from ._mixins import Stateful +from ._modo import MoDoWeighting from ._nash_mtl import NashMTL from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting @@ -87,6 +88,7 @@ "MeanWeighting", "MGDA", "MGDAWeighting", + "MoDoWeighting", "NashMTL", "PCGrad", "PCGradWeighting", diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py new file mode 100644 index 00000000..d219b44e --- /dev/null +++ b/src/torchjd/aggregation/_modo.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor + +from torchjd.aggregation._mixins import Stateful, _NonDifferentiable +from torchjd.linalg import PSDMatrix + +from ._weighting_bases import _GramianWeighting + + +class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable): + r""" + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] from `Three-Way + Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance + `_ (JMLR 2024), commonly referred + to as MoDo (Multi-Objective gradient with Double sampling). + + Given a Gramian :math:`G`, the weights :math:`\lambda` are updated at each call by a + softmax-projected gradient step: + + .. math:: + + \lambda_{t+1} = \operatorname{softmax}\!\bigl( + \lambda_t - \gamma \cdot (G \lambda_t + \rho \lambda_t) + \bigr) + + The paper specifies hard simplex projection :math:`\Pi_\Delta`; we follow the `official + LibMTL implementation `_ and use + :func:`torch.softmax` as the projection step. + + The state :math:`\lambda_{t-1}` is initialised lazily to the uniform vector + :math:`[1/m, \ldots, 1/m]` on the first forward call once :math:`m` is known, and is reset + automatically when :math:`m`, ``dtype`` or ``device`` of the input Gramian changes. Use + :meth:`reset` to manually restart from uniform weights. + + .. warning:: + MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this + weighting must come from a mini-batch that is independent of the one used for the + subsequent parameter update. The Gramian can be computed efficiently from a batch of + losses using the :class:`~torchjd.autogram.Engine`. See the usage example below. + + :param gamma: Learning rate of the task-weight update. Must be positive. + :param rho: Non-negative :math:`\ell_2` regularisation coefficient. + + .. admonition:: Example + + Train a model using MoDo with two independent mini-batches per step. The first batch + drives the :math:`\lambda` update via the Gramian; the second batch drives the parameter + update via the usual backward pass. + + .. code-block:: python + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import MoDoWeighting + from torchjd.autogram import Engine + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = MoDoWeighting(gamma=0.1, rho=0.0) + engine = Engine(model, batch_dim=0) + + # loader_1 and loader_2 must yield independent draws from the same distribution. + for batch_1, batch_2 in zip(loader_1, loader_2): + input_1, target_1 = batch_1 + input_2, target_2 = batch_2 + + # Step 1: Gramian from batch 1 drives the lambda update. + losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + gramian = engine.compute_gramian(losses_1) + weights = weighting(gramian) + + # Step 2: backward on batch 2 with those weights drives the parameter update. + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + losses_2.backward(weights) + optimizer.step() + optimizer.zero_grad() + """ + + def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None: + super().__init__() + self.gamma = gamma + self.rho = rho + self._lambda: Tensor | None = None + self._state_key: tuple[int, torch.dtype, torch.device] | None = None + + @property + def gamma(self) -> float: + return self._gamma + + @gamma.setter + def gamma(self, value: float) -> None: + if value <= 0.0: + raise ValueError(f"Attribute `gamma` must be positive. Found gamma={value!r}.") + self._gamma = value + + @property + def rho(self) -> float: + return self._rho + + @rho.setter + def rho(self, value: float) -> None: + if value < 0.0: + raise ValueError(f"Attribute `rho` must be non-negative. Found rho={value!r}.") + self._rho = value + + def reset(self) -> None: + """Clears the stored task weights so the next forward starts from uniform.""" + + self._lambda = None + self._state_key = None + + def forward(self, gramian: PSDMatrix, /) -> Tensor: + self._ensure_state(gramian) + lambd = cast(Tensor, self._lambda) + + grad = gramian @ lambd + self._rho * lambd + lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) + + self._lambda = lambd + return lambd + + def _ensure_state(self, gramian: PSDMatrix) -> None: + key = (gramian.shape[0], gramian.dtype, gramian.device) + if self._state_key == key and self._lambda is not None: + return + self._lambda = gramian.new_full((gramian.shape[0],), 1.0 / gramian.shape[0]) + self._state_key = key + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(gamma={self.gamma!r}, rho={self.rho!r})" diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py new file mode 100644 index 00000000..9b9193be --- /dev/null +++ b/tests/unit/aggregation/test_modo.py @@ -0,0 +1,145 @@ +import torch +from pytest import mark, raises +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import randn_, tensor_ + +from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator +from torchjd.aggregation._modo import MoDoWeighting + +from ._asserts import assert_expected_structure +from ._inputs import scaled_matrices, typical_matrices + +gramian_pairs = [ + (GramianWeightedAggregator(MoDoWeighting()), m) for m in typical_matrices + scaled_matrices +] + + +def test_representations() -> None: + W = MoDoWeighting(gamma=0.1, rho=0.05) + assert repr(W) == "MoDoWeighting(gamma=0.1, rho=0.05)" + + +@mark.parametrize(["aggregator", "matrix"], gramian_pairs) +def test_expected_structure_gramian_weighting( + aggregator: GramianWeightedAggregator, matrix: Tensor +) -> None: + assert_expected_structure(aggregator, matrix) + + +def test_reset_restores_first_step_behavior() -> None: + J = randn_((3, 8)) + G = J @ J.T + W = MoDoWeighting(gamma=0.1) + first = W(G) + W(G) + W.reset() + assert_close(first, W(G)) + + +def test_gamma_setter_accepts_valid() -> None: + W = MoDoWeighting() + W.gamma = 0.01 + assert W.gamma == 0.01 + W.gamma = 0.1 + assert W.gamma == 0.1 + W.gamma = 1.0 + assert W.gamma == 1.0 + + +def test_gamma_setter_rejects_non_positive() -> None: + W = MoDoWeighting() + with raises(ValueError, match="gamma"): + W.gamma = 0.0 + with raises(ValueError, match="gamma"): + W.gamma = -0.1 + + +def test_rho_setter_accepts_valid() -> None: + W = MoDoWeighting() + W.rho = 0.0 + assert W.rho == 0.0 + W.rho = 0.1 + assert W.rho == 0.1 + + +def test_rho_setter_rejects_negative() -> None: + W = MoDoWeighting() + with raises(ValueError, match="rho"): + W.rho = -0.1 + + +def test_output_lies_on_simplex() -> None: + """The softmax projection ensures the weights sum to 1 and are non-negative.""" + + J = randn_((4, 10)) + G = J @ J.T + W = MoDoWeighting(gamma=0.1, rho=0.05) + weights = W(G) + assert weights.shape == (4,) + assert (weights >= 0).all() + assert_close(weights.sum(), tensor_(1.0)) + + +def test_small_gamma_stays_near_uniform() -> None: + """With a tiny gamma, one step barely moves lambda from the uniform initialisation.""" + + J = randn_((3, 8)) + G = J @ J.T + m = J.shape[0] + W = MoDoWeighting(gamma=1e-8) + uniform = tensor_([1.0 / m] * m) + assert_close(W(G), uniform, atol=1e-6, rtol=1e-6) + + +def test_update_recurrence() -> None: + """Verify one step of the softmax-projected gradient update by hand.""" + + gamma = 0.1 + rho = 0.05 + J = randn_((3, 8)) + G = J @ J.T + m = J.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + lambda_0 = tensor_([1.0 / m] * m) + grad = G @ lambda_0 + rho * lambda_0 + expected = torch.softmax(lambda_0 - gamma * grad, dim=-1) + + assert_close(W(G), expected) + + +def test_two_consecutive_steps() -> None: + """Verify two consecutive steps of the softmax-projected gradient update.""" + + gamma = 0.1 + rho = 0.0 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G1 = J1 @ J1.T + G2 = J2 @ J2.T + m = J1.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + + lambda_0 = tensor_([1.0 / m] * m) + grad_1 = G1 @ lambda_0 + rho * lambda_0 + lambda_1 = torch.softmax(lambda_0 - gamma * grad_1, dim=-1) + + grad_2 = G2 @ lambda_1 + rho * lambda_1 + lambda_2 = torch.softmax(lambda_1 - gamma * grad_2, dim=-1) + + assert_close(W(G1), lambda_1) + assert_close(W(G2), lambda_2) + + +def test_changing_m_auto_resets() -> None: + """When the number of objectives changes, the state is re-initialised to uniform.""" + + W = MoDoWeighting(gamma=0.1) + W(randn_((3, 8)) @ randn_((3, 8)).T) + # After a state-resetting call with m=2, the first output should equal the uniform step's output. + fresh = MoDoWeighting(gamma=0.1) + J = randn_((2, 8)) + G = J @ J.T + assert_close(W(G), fresh(G))