Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ changelog does not include internal changes that do not affect the user.

## [Unreleased]

### Added

- Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature
Aggregation and Geometric Loss Strategy for Multi-Task
Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf),
a `Scalarizer` that returns the geometric mean of the input tensor of values.

## [0.12.0] - 2026-05-28

### Added
Expand Down
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/geometric_mean.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

GeometricMean
=============

.. autoclass:: torchjd.scalarization.GeometricMean
:members: __call__
1 change: 1 addition & 0 deletions docs/source/docs/scalarization/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Abstract base class
:maxdepth: 1

constant.rst
geometric_mean.rst
mean.rst
random.rst
sum.rst
9 changes: 2 additions & 7 deletions src/torchjd/scalarization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,10 @@
"""

from ._constant import Constant
from ._geometric_mean import GeometricMean
from ._mean import Mean
from ._random import Random
from ._scalarizer_base import Scalarizer
from ._sum import Sum

__all__ = [
"Constant",
"Mean",
"Random",
"Scalarizer",
"Sum",
]
__all__ = ["Constant", "GeometricMean", "Mean", "Random", "Scalarizer", "Sum"]
23 changes: 23 additions & 0 deletions src/torchjd/scalarization/_geometric_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch
from torch import Tensor

from ._scalarizer_base import Scalarizer


class GeometricMean(Scalarizer):
"""
:class:`~torchjd.scalarization.Scalarizer` that returns the geometric mean of the input tensor
of values, as studied in `MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss
Strategy for Multi-Task Learning
<https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf>`_.

This method is also known as GLS (Geometric Loss Strategy).
"""
Comment thread
ppraneth marked this conversation as resolved.

def forward(self, values: Tensor, /) -> Tensor:
if (values < 1e-12).any():
raise ValueError(
"GeometricMean is only defined for strictly positive values. Found a value "
"below 1e-12 in the input."
)
return torch.exp(torch.log(values).mean())
Comment thread
ValerianRey marked this conversation as resolved.
54 changes: 54 additions & 0 deletions tests/unit/scalarization/test_geometric_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
from pytest import mark, raises
from torch import Tensor
from utils.tensors import rand_, tensor_

from torchjd.scalarization import GeometricMean

from ._asserts import (
assert_grad_flow,
assert_permutation_invariant,
assert_returns_scalar,
)
from ._inputs import shapes

positive_inputs: list[Tensor] = [rand_(shape) + 1 for shape in shapes]


def test_value() -> None:
losses = tensor_([1.0, 2.0, 4.0])
torch.testing.assert_close(GeometricMean()(losses), tensor_(2.0))


@mark.parametrize("losses", positive_inputs)
def test_expected_structure(losses: Tensor) -> None:
assert_returns_scalar(GeometricMean(), losses)


@mark.parametrize("losses", positive_inputs)
def test_grad_flow(losses: Tensor) -> None:
assert_grad_flow(GeometricMean(), losses)


@mark.parametrize("losses", positive_inputs)
def test_permutation_invariant(losses: Tensor) -> None:
assert_permutation_invariant(GeometricMean(), losses)


@mark.parametrize(
"invalid",
[
tensor_([1.0, 0.0]),
tensor_([1.0, -1.0]),
tensor_([1.0, 1e-13]),
],
)
def test_raises_on_non_positive_input(invalid: Tensor) -> None:
with raises(ValueError):
GeometricMean()(invalid)


def test_representations() -> None:
s = GeometricMean()
assert repr(s) == "GeometricMean()"
assert str(s) == "GeometricMean"
Loading