Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/constant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

Constant
========

.. autoclass:: torchjd.scalarization.Constant
    :members: __call__
21 changes: 21 additions & 0 deletions docs/source/docs/scalarization/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
scalarization
=============

.. automodule:: torchjd.scalarization
    :no-members:

Abstract base class
-------------------

.. autoclass:: torchjd.scalarization.Scalarizer
    :members: __call__


.. toctree::
    :hidden:
    :maxdepth: 1

    constant.rst
    mean.rst
    random.rst
    sum.rst
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/mean.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

Mean
====

.. autoclass:: torchjd.scalarization.Mean
    :members: __call__
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/random.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

Random
======

.. autoclass:: torchjd.scalarization.Random
    :members: __call__
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/sum.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

Sum
===

.. autoclass:: torchjd.scalarization.Sum
    :members: __call__
5 changes: 5 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ Jacobian descent is the aggregator, which maps the Jacobian to an optimization s
:doc:`Aggregation <docs/aggregation/index>`, we provide an overview of the various aggregators
available in TorchJD, and their corresponding weightings.

For comparison against simple baselines, the :doc:`Scalarization <docs/scalarization/index>`
package provides scalarizers that combine a tensor of losses into a single scalar loss, allowing
standard gradient descent to be used.

A straightforward application of Jacobian descent is multi-task learning, in which the vector of
per-task losses has to be minimized. To start using TorchJD for multi-task learning, follow our
:doc:`MTL example <examples/mtl>`.
Expand Down Expand Up @@ -70,4 +74,5 @@ TorchJD is open-source, under MIT License. The source code is available on
docs/autogram/index.rst
docs/autojac/index.rst
docs/aggregation/index.rst
docs/scalarization/index.rst
docs/linalg/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@ def vector_to_str(vector: Tensor) -> str:

weights_str = ", ".join([f"{value:.2f}".rstrip("0") for value in vector])
return weights_str


def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
"""Returns a suffix string containing the representation of the optional preference vector."""

if pref_vector is None:
return ""
return f"([{vector_to_str(pref_vector)}])"
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import torch
from torch import Tensor

from torchjd._vector_str import pref_vector_to_str_suffix
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._utils.pref_vector import pref_vector_to_weighting
from ._weighting_bases import _GramianWeighting

SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"]
Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import torch
from torch import Tensor

from torchjd._vector_str import pref_vector_to_str_suffix
from torchjd.linalg import Matrix

from ._aggregator_bases import Aggregator
from ._mixins import _NonDifferentiable
from ._sum import SumWeighting
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._utils.pref_vector import pref_vector_to_weighting


# Non-differentiable: the pseudoinverse and the normalization are not differentiable in this context.
Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/_constant.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from torch import Tensor

from torchjd._vector_str import vector_to_str

from ._aggregator_bases import WeightedAggregator
from ._utils.str import vector_to_str
from ._weighting_bases import _MatrixWeighting


Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from torch import Tensor

from torchjd._linalg import DualConeProjector, projector_or_default
from torchjd._vector_str import pref_vector_to_str_suffix
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._mixins import _NonDifferentiable
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._utils.pref_vector import pref_vector_to_weighting
from ._weighting_bases import _GramianWeighting


Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from torch import Tensor

from torchjd._linalg import DualConeProjector, projector_or_default
from torchjd._vector_str import pref_vector_to_str_suffix
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._mixins import _NonDifferentiable
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._utils.pref_vector import pref_vector_to_weighting
from ._weighting_bases import _GramianWeighting


Expand Down
10 changes: 0 additions & 10 deletions src/torchjd/aggregation/_utils/pref_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from torchjd.aggregation._weighting_bases import Weighting
from torchjd.linalg import Matrix

from .str import vector_to_str


def pref_vector_to_weighting(
pref_vector: Tensor | None,
Expand All @@ -24,11 +22,3 @@ def pref_vector_to_weighting(
f"{pref_vector.ndim}`.",
)
return ConstantWeighting(pref_vector)


def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
"""Returns a suffix string containing the representation of the optional preference vector."""

if pref_vector is None:
return ""
return f"([{vector_to_str(pref_vector)}])"
34 changes: 34 additions & 0 deletions src/torchjd/scalarization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
A :class:`~torchjd.scalarization.Scalarizer` reduces a tensor of values of any shape into a single
scalar value. Scalarization is the simple baseline against which
:class:`Aggregators <torchjd.aggregation.Aggregator>` are compared: instead of combining the
per-loss gradients via the Jacobian or its Gramian, a :class:`~torchjd.scalarization.Scalarizer`
combines the losses directly, and a standard call to :meth:`~torch.Tensor.backward` on the
resulting scalar produces the gradient.

The following example shows how to use :class:`~torchjd.scalarization.Mean` to combine a vector of
losses into a single scalar loss.

>>> from torch import tensor
>>> from torchjd.scalarization import Mean
>>>
>>> scalarizer = Mean()
>>> losses = tensor([1.0, 2.0, 3.0])
>>> loss = scalarizer(losses)
>>> loss
tensor(2.)
"""

from ._constant import Constant
from ._mean import Mean
from ._random import Random
from ._scalarizer_base import Scalarizer
from ._sum import Sum

# Names exported by the package: the abstract base class and the concrete scalarizers.
__all__ = ["Constant", "Mean", "Random", "Scalarizer", "Sum"]
35 changes: 35 additions & 0 deletions src/torchjd/scalarization/_constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from torch import Tensor

from torchjd._vector_str import pref_vector_to_str_suffix

from ._scalarizer_base import Scalarizer


class Constant(Scalarizer):
    """
    :class:`~torchjd.scalarization.Scalarizer` that returns the weighted sum of the input tensor
    of values, using fixed weights chosen at construction.

    :param weights: The weights to apply to the values. Must have the same shape as the values
        passed at call time.
    """

    def __init__(self, weights: Tensor) -> None:
        super().__init__()
        self.weights = weights

    def forward(self, values: Tensor, /) -> Tensor:
        # Reject any shape mismatch up front rather than letting the product below broadcast.
        if values.shape == self.weights.shape:
            weighted_values = self.weights * values
            return weighted_values.sum()

        raise ValueError(
            f"Parameter `values` should have shape {tuple(self.weights.shape)} (matching the "
            f"shape of the weights). Found `values.shape = {tuple(values.shape)}`.",
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(weights={repr(self.weights)})"

    def __str__(self) -> str:
        if self.weights.ndim != 1:
            # Weights with any other number of dimensions are summarized by their shape only.
            return f"{self.__class__.__name__}(weights of shape {tuple(self.weights.shape)})"
        # One-dimensional weights are rendered through the preference-vector suffix helper.
        return f"{self.__class__.__name__}{pref_vector_to_str_suffix(self.weights)}"
12 changes: 12 additions & 0 deletions src/torchjd/scalarization/_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from torch import Tensor

from ._scalarizer_base import Scalarizer


class Mean(Scalarizer):
    """
    :class:`~torchjd.scalarization.Scalarizer` that averages all the elements of the input tensor
    of values.
    """

    def forward(self, values: Tensor, /) -> Tensor:
        # No `dim` argument: the mean is taken over every element.
        average = values.mean()
        return average
19 changes: 19 additions & 0 deletions src/torchjd/scalarization/_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch
from torch import Tensor
from torch.nn.functional import softmax

from ._scalarizer_base import Scalarizer


class Random(Scalarizer):
    """
    :class:`~torchjd.scalarization.Scalarizer` that returns a weighted sum of the input tensor of
    values, where the weights are freshly sampled at each call, positive, and sum to 1. This
    follows Algorithm 2 of `Reasonable Effectiveness of Random Weighting: A Litmus Test for
    Multi-Task Learning <https://arxiv.org/pdf/2111.10603.pdf>`_.
    """

    def forward(self, values: Tensor, /) -> Tensor:
        # Draw one standard-normal logit per value, on the device and in the dtype of the input.
        logits = torch.randn(values.numel(), device=values.device, dtype=values.dtype)
        # The softmax over all logits makes the weights positive and summing to 1.
        normalized = softmax(logits, dim=-1)
        weights = normalized.reshape(values.shape)
        weighted_values = weights * values
        return weighted_values.sum()
31 changes: 31 additions & 0 deletions src/torchjd/scalarization/_scalarizer_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod

from torch import Tensor, nn


class Scalarizer(nn.Module, ABC):
    """
    Abstract base class for all scalarizers. Reduces a tensor of values of any shape into a single
    scalar value.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, values: Tensor, /) -> Tensor:
        """
        Computes the scalarization of the input tensor. Must be implemented by subclasses.

        :param values: The tensor of values to scalarize.
        """

    def __call__(self, values: Tensor, /) -> Tensor:
        """
        Computes the scalar value from the input tensor of values and applies all registered hooks.

        :param values: The tensor of values to scalarize. May be of any shape.
        """
        # Delegating to `nn.Module.__call__` (rather than calling `forward` directly) is what runs
        # the registered hooks; the override only narrows the signature and documents it.
        return super().__call__(values)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def __str__(self) -> str:
        return self.__class__.__name__
12 changes: 12 additions & 0 deletions src/torchjd/scalarization/_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from torch import Tensor

from ._scalarizer_base import Scalarizer


class Sum(Scalarizer):
"""
:class:`~torchjd.scalarization.Scalarizer` that returns the sum of the input tensor of values.
"""

def forward(self, values: Tensor, /) -> Tensor:
return values.sum()
Empty file.
27 changes: 27 additions & 0 deletions tests/unit/scalarization/_asserts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import torch
from torch import Tensor
from utils.tensors import randperm_

from torchjd.scalarization import Scalarizer


def assert_returns_scalar(scalarizer: Scalarizer, losses: Tensor) -> None:
    """
    Checks that applying ``scalarizer`` to ``losses`` gives a finite, zero-dimensional tensor.

    :param scalarizer: The scalarizer under test.
    :param losses: The tensor of values to pass to the scalarizer.
    :raises AssertionError: If the output has at least one dimension or is not finite.
    """
    result = scalarizer(losses)
    # The dimension check comes first: the finiteness check below relies on a single element.
    assert result.dim() == 0
    assert result.isfinite()


def assert_grad_flow(scalarizer: Scalarizer, losses: Tensor) -> None:
    """
    Checks that back-propagating through ``scalarizer`` gives a finite gradient on its input.

    :param scalarizer: The scalarizer under test.
    :param losses: The tensor of values; a detached copy requiring grad is fed to the scalarizer.
    :raises AssertionError: If no gradient reaches the input or if the gradient is not finite.
    """
    # Detach first so that the input is a leaf tensor, whose `.grad` is filled by `backward`.
    tracked_losses = losses.detach().requires_grad_()
    scalarizer(tracked_losses).backward()

    gradient = tracked_losses.grad
    assert gradient is not None
    assert gradient.isfinite().all()


def assert_permutation_invariant(scalarizer: Scalarizer, losses: Tensor) -> None:
    """
    Checks that shuffling the elements of ``losses`` does not change the output of ``scalarizer``.

    :param scalarizer: The scalarizer under test.
    :param losses: The tensor of values; its elements are shuffled while its shape is kept.
    :raises AssertionError: If the two outputs are not close according to
        ``torch.testing.assert_close``.
    """
    reference = scalarizer(losses)

    # Shuffle in flattened form so that elements can move across all dimensions, then restore
    # the original shape.
    flattened = losses.flatten()
    shuffle = randperm_(flattened.numel())
    shuffled_losses = flattened[shuffle].reshape(losses.shape)

    torch.testing.assert_close(reference, scalarizer(shuffled_losses))
5 changes: 5 additions & 0 deletions tests/unit/scalarization/_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from torch import Tensor
from utils.tensors import randn_

# Shapes of the test inputs, from the empty shape up to three dimensions.
shapes: list[list[int]] = [[], [5], [3, 4], [2, 3, 4]]
# One tensor per shape, in the same order, each built by `randn_`.
all_inputs: list[Tensor] = list(map(randn_, shapes))
Loading
Loading