Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization,Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a
softmax-projected gradient step on the Gramian, intended to be composed with `autogram.Engine`
in a two-batch training loop.
- Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature
Aggregation and Geometric Loss Strategy for Multi-Task
Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf),
Expand Down
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Abstract base classes
krum.rst
mean.rst
mgda.rst
modo.rst
nash_mtl.rst
pcgrad.rst
random.rst
Expand Down
7 changes: 7 additions & 0 deletions docs/source/docs/aggregation/modo.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

MoDo
====

.. autoclass:: torchjd.aggregation.MoDoWeighting
:members: __call__, reset
2 changes: 2 additions & 0 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from ._mean import Mean, MeanWeighting
from ._mgda import MGDA, MGDAWeighting
from ._mixins import Stateful
from ._modo import MoDoWeighting
from ._nash_mtl import NashMTL
from ._pcgrad import PCGrad, PCGradWeighting
from ._random import Random, RandomWeighting
Expand Down Expand Up @@ -87,6 +88,7 @@
"MeanWeighting",
"MGDA",
"MGDAWeighting",
"MoDoWeighting",
"NashMTL",
"PCGrad",
"PCGradWeighting",
Expand Down
138 changes: 138 additions & 0 deletions src/torchjd/aggregation/_modo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from __future__ import annotations

from typing import cast

import torch
from torch import Tensor

from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
from torchjd.linalg import PSDMatrix

from ._weighting_bases import _GramianWeighting


class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/opencode:Plan is the inheritance order correct here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it matched _gradvac.py, and also the warning in docstring for _NonDifferntiable states "Placing this mixin before the primary base will cause it to shadow the primary class's call signature in generated documentation."

So yes, I believe it is

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] from `Three-Way
Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance
<https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf>`_ (JMLR 2024), commonly referred
to as MoDo (Multi-Objective gradient with Double sampling).

Given a Gramian :math:`G`, the weights :math:`\lambda` are updated at each call by a
softmax-projected gradient step:

.. math::

\lambda_{t+1} = \operatorname{softmax}\!\bigl(
\lambda_t - \gamma \cdot (G \lambda_t + \rho \lambda_t)
\bigr)

The paper specifies hard simplex projection :math:`\Pi_\Delta`; we follow the `official
LibMTL implementation <https://github.com/median-research-group/LibMTL>`_ and use
:func:`torch.softmax` as the projection step.

The state :math:`\lambda_{t-1}` is initialised lazily to the uniform vector
:math:`[1/m, \ldots, 1/m]` on the first forward call once :math:`m` is known, and is reset
automatically when :math:`m`, ``dtype`` or ``device`` of the input Gramian changes. Use
:meth:`reset` to manually restart from uniform weights.

.. warning::
MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this
weighting must come from a mini-batch that is independent of the one used for the
subsequent parameter update. The Gramian can be computed efficiently from a batch of
losses using the :class:`~torchjd.autogram.Engine`. See the usage example below.

:param gamma: Learning rate of the task-weight update. Must be positive.
:param rho: Non-negative :math:`\ell_2` regularisation coefficient.

.. admonition:: Example

Train a model using MoDo with two independent mini-batches per step. The first batch
drives the :math:`\lambda` update via the Gramian; the second batch drives the parameter
update via the usual backward pass.

Comment thread
PierreQuinton marked this conversation as resolved.
.. code-block:: python

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import MoDoWeighting
from torchjd.autogram import Engine

model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
optimizer = SGD(model.parameters())
criterion = MSELoss(reduction="none")
weighting = MoDoWeighting(gamma=0.1, rho=0.0)
engine = Engine(model, batch_dim=0)

# loader_1 and loader_2 must yield independent draws from the same distribution.
for batch_1, batch_2 in zip(loader_1, loader_2):
input_1, target_1 = batch_1
input_2, target_2 = batch_2

# Step 1: Gramian from batch 1 drives the lambda update.
losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
gramian = engine.compute_gramian(losses_1)
weights = weighting(gramian)

# Step 2: backward on batch 2 with those weights drives the parameter update.
losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
losses_2.backward(weights)
optimizer.step()
optimizer.zero_grad()
"""

def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None:
super().__init__()
self.gamma = gamma
self.rho = rho
self._lambda: Tensor | None = None
self._state_key: tuple[int, torch.dtype, torch.device] | None = None

@property
def gamma(self) -> float:
return self._gamma

@gamma.setter
def gamma(self, value: float) -> None:
if value <= 0.0:
raise ValueError(f"Attribute `gamma` must be positive. Found gamma={value!r}.")
self._gamma = value

@property
def rho(self) -> float:
return self._rho

@rho.setter
def rho(self, value: float) -> None:
if value < 0.0:
raise ValueError(f"Attribute `rho` must be non-negative. Found rho={value!r}.")
self._rho = value

def reset(self) -> None:
"""Clears the stored task weights so the next forward starts from uniform."""

self._lambda = None
self._state_key = None

def forward(self, gramian: PSDMatrix, /) -> Tensor:
self._ensure_state(gramian)
lambd = cast(Tensor, self._lambda)

grad = gramian @ lambd + self._rho * lambd
lambd = torch.softmax(lambd - self._gamma * grad, dim=-1)

self._lambda = lambd
return lambd

def _ensure_state(self, gramian: PSDMatrix) -> None:
key = (gramian.shape[0], gramian.dtype, gramian.device)
if self._state_key == key and self._lambda is not None:
return
self._lambda = gramian.new_full((gramian.shape[0],), 1.0 / gramian.shape[0])
self._state_key = key

def __repr__(self) -> str:
return f"{self.__class__.__name__}(gamma={self.gamma!r}, rho={self.rho!r})"
145 changes: 145 additions & 0 deletions tests/unit/aggregation/test_modo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import torch
from pytest import mark, raises
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import randn_, tensor_

from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator
from torchjd.aggregation._modo import MoDoWeighting

from ._asserts import assert_expected_structure
from ._inputs import scaled_matrices, typical_matrices

gramian_pairs = [
(GramianWeightedAggregator(MoDoWeighting()), m) for m in typical_matrices + scaled_matrices
]


def test_representations() -> None:
W = MoDoWeighting(gamma=0.1, rho=0.05)
assert repr(W) == "MoDoWeighting(gamma=0.1, rho=0.05)"


@mark.parametrize(["aggregator", "matrix"], gramian_pairs)
def test_expected_structure_gramian_weighting(
aggregator: GramianWeightedAggregator, matrix: Tensor
) -> None:
assert_expected_structure(aggregator, matrix)


def test_reset_restores_first_step_behavior() -> None:
J = randn_((3, 8))
G = J @ J.T
W = MoDoWeighting(gamma=0.1)
first = W(G)
W(G)
W.reset()
assert_close(first, W(G))


def test_gamma_setter_accepts_valid() -> None:
W = MoDoWeighting()
W.gamma = 0.01
assert W.gamma == 0.01
W.gamma = 0.1
assert W.gamma == 0.1
W.gamma = 1.0
assert W.gamma == 1.0


def test_gamma_setter_rejects_non_positive() -> None:
W = MoDoWeighting()
with raises(ValueError, match="gamma"):
W.gamma = 0.0
with raises(ValueError, match="gamma"):
W.gamma = -0.1


def test_rho_setter_accepts_valid() -> None:
W = MoDoWeighting()
W.rho = 0.0
assert W.rho == 0.0
W.rho = 0.1
assert W.rho == 0.1


def test_rho_setter_rejects_negative() -> None:
W = MoDoWeighting()
with raises(ValueError, match="rho"):
W.rho = -0.1


def test_output_lies_on_simplex() -> None:
"""The softmax projection ensures the weights sum to 1 and are non-negative."""

J = randn_((4, 10))
G = J @ J.T
W = MoDoWeighting(gamma=0.1, rho=0.05)
weights = W(G)
assert weights.shape == (4,)
assert (weights >= 0).all()
assert_close(weights.sum(), tensor_(1.0))


def test_small_gamma_stays_near_uniform() -> None:
"""With a tiny gamma, one step barely moves lambda from the uniform initialisation."""

J = randn_((3, 8))
G = J @ J.T
m = J.shape[0]
W = MoDoWeighting(gamma=1e-8)
uniform = tensor_([1.0 / m] * m)
assert_close(W(G), uniform, atol=1e-6, rtol=1e-6)


def test_update_recurrence() -> None:
"""Verify one step of the softmax-projected gradient update by hand."""

gamma = 0.1
rho = 0.05
J = randn_((3, 8))
G = J @ J.T
m = J.shape[0]

W = MoDoWeighting(gamma=gamma, rho=rho)
lambda_0 = tensor_([1.0 / m] * m)
grad = G @ lambda_0 + rho * lambda_0
expected = torch.softmax(lambda_0 - gamma * grad, dim=-1)

assert_close(W(G), expected)


def test_two_consecutive_steps() -> None:
"""Verify two consecutive steps of the softmax-projected gradient update."""

gamma = 0.1
rho = 0.0
J1 = randn_((3, 8))
J2 = randn_((3, 8))
G1 = J1 @ J1.T
G2 = J2 @ J2.T
m = J1.shape[0]

W = MoDoWeighting(gamma=gamma, rho=rho)

lambda_0 = tensor_([1.0 / m] * m)
grad_1 = G1 @ lambda_0 + rho * lambda_0
lambda_1 = torch.softmax(lambda_0 - gamma * grad_1, dim=-1)

grad_2 = G2 @ lambda_1 + rho * lambda_1
lambda_2 = torch.softmax(lambda_1 - gamma * grad_2, dim=-1)

assert_close(W(G1), lambda_1)
assert_close(W(G2), lambda_2)


def test_changing_m_auto_resets() -> None:
"""When the number of objectives changes, the state is re-initialised to uniform."""

W = MoDoWeighting(gamma=0.1)
W(randn_((3, 8)) @ randn_((3, 8)).T)
# After a state-resetting call with m=2, the first output should equal the uniform step's output.
fresh = MoDoWeighting(gamma=0.1)
J = randn_((2, 8))
G = J @ J.T
assert_close(W(G), fresh(G))
Loading