diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9204433a..ab5a536d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,13 @@ changelog does not include internal changes that do not affect the user.
 
 ## [Unreleased]
 
+### Added
+
+- Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature
+  Aggregation and Geometric Loss Strategy for Multi-Task
+  Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf),
+  a `Scalarizer` that returns the geometric mean of the input tensor of values.
+
 ## [0.12.0] - 2026-05-28
 
 ### Added
diff --git a/docs/source/docs/scalarization/geometric_mean.rst b/docs/source/docs/scalarization/geometric_mean.rst
new file mode 100644
index 00000000..82622860
--- /dev/null
+++ b/docs/source/docs/scalarization/geometric_mean.rst
@@ -0,0 +1,7 @@
+:hide-toc:
+
+GeometricMean
+=============
+
+.. autoclass:: torchjd.scalarization.GeometricMean
+    :members: __call__
diff --git a/docs/source/docs/scalarization/index.rst b/docs/source/docs/scalarization/index.rst
index 11381bfa..8fd87dc8 100644
--- a/docs/source/docs/scalarization/index.rst
+++ b/docs/source/docs/scalarization/index.rst
@@ -16,6 +16,7 @@ Abstract base class
     :maxdepth: 1
 
     constant.rst
+    geometric_mean.rst
     mean.rst
     random.rst
     sum.rst
diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/scalarization/__init__.py
index bf82aa11..337d38ca 100644
--- a/src/torchjd/scalarization/__init__.py
+++ b/src/torchjd/scalarization/__init__.py
@@ -20,15 +20,10 @@
 """
 
 from ._constant import Constant
+from ._geometric_mean import GeometricMean
 from ._mean import Mean
 from ._random import Random
 from ._scalarizer_base import Scalarizer
 from ._sum import Sum
 
-__all__ = [
-    "Constant",
-    "Mean",
-    "Random",
-    "Scalarizer",
-    "Sum",
-]
+__all__ = ["Constant", "GeometricMean", "Mean", "Random", "Scalarizer", "Sum"]
diff --git a/src/torchjd/scalarization/_geometric_mean.py b/src/torchjd/scalarization/_geometric_mean.py
new file mode 100644
index 00000000..f54d9fca
--- /dev/null
+++ b/src/torchjd/scalarization/_geometric_mean.py
@@ -0,0 +1,28 @@
+import torch
+from torch import Tensor
+
+from ._scalarizer_base import Scalarizer
+
+
+class GeometricMean(Scalarizer):
+    """
+    :class:`~torchjd.scalarization.Scalarizer` that returns the geometric mean of the input tensor
+    of values, as studied in `MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss
+    Strategy for Multi-Task Learning
+    <https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf>`_.
+
+    This method is also known as GLS (Geometric Loss Strategy).
+    """
+
+    def forward(self, values: Tensor, /) -> Tensor:
+        # The result is computed through ``torch.log``, so entries below the 1e-12 threshold
+        # (zero and negative values included) are rejected up front with a ``ValueError``.
+        # NOTE(review): NaN entries compare False against the threshold and are therefore not
+        # rejected here -- confirm that letting them propagate to the result is intended.
+        if (values < 1e-12).any():
+            raise ValueError(
+                "GeometricMean is only defined for strictly positive values. Found a value "
+                "below 1e-12 in the input."
+            )
+        # Geometric mean, evaluated as the exponential of the mean of the logarithms.
+        return torch.exp(torch.log(values).mean())
diff --git a/tests/unit/scalarization/test_geometric_mean.py b/tests/unit/scalarization/test_geometric_mean.py
new file mode 100644
index 00000000..1fb99ab9
--- /dev/null
+++ b/tests/unit/scalarization/test_geometric_mean.py
@@ -0,0 +1,54 @@
+import torch
+from pytest import mark, raises
+from torch import Tensor
+from utils.tensors import rand_, tensor_
+
+from torchjd.scalarization import GeometricMean
+
+from ._asserts import (
+    assert_grad_flow,
+    assert_permutation_invariant,
+    assert_returns_scalar,
+)
+from ._inputs import shapes
+
+positive_inputs: list[Tensor] = [rand_(shape) + 1 for shape in shapes]
+
+
+def test_value() -> None:
+    losses = tensor_([1.0, 2.0, 4.0])
+    torch.testing.assert_close(GeometricMean()(losses), tensor_(2.0))
+
+
+@mark.parametrize("losses", positive_inputs)
+def test_expected_structure(losses: Tensor) -> None:
+    assert_returns_scalar(GeometricMean(), losses)
+
+
+@mark.parametrize("losses", positive_inputs)
+def test_grad_flow(losses: Tensor) -> None:
+    assert_grad_flow(GeometricMean(), losses)
+
+
+@mark.parametrize("losses", positive_inputs)
+def test_permutation_invariant(losses: Tensor) -> None:
+    assert_permutation_invariant(GeometricMean(), losses)
+
+
+@mark.parametrize(
+    "invalid",
+    [
+        tensor_([1.0, 0.0]),
+        tensor_([1.0, -1.0]),
+        tensor_([1.0, 1e-13]),
+    ],
+)
+def test_raises_on_non_positive_input(invalid: Tensor) -> None:
+    with raises(ValueError):
+        GeometricMean()(invalid)
+
+
+def test_representations() -> None:
+    s = GeometricMean()
+    assert repr(s) == "GeometricMean()"
+    assert str(s) == "GeometricMean"