diff --git a/src/torchjd/scalarization/_geometric_mean.py b/src/torchjd/scalarization/_geometric_mean.py index f54d9fca..5c71f6d5 100644 --- a/src/torchjd/scalarization/_geometric_mean.py +++ b/src/torchjd/scalarization/_geometric_mean.py @@ -15,9 +15,9 @@ class GeometricMean(Scalarizer): """ def forward(self, values: Tensor, /) -> Tensor: - if (values < 1e-12).any(): + if (values < 0.0).any(): raise ValueError( - "GeometricMean is only defined for strictly positive values. Found a value " - "below 1e-12 in the input." + "GeometricMean is only defined for strictly positive values." + "Found a negative value in the input." ) return torch.exp(torch.log(values).mean()) diff --git a/tests/unit/scalarization/test_geometric_mean.py b/tests/unit/scalarization/test_geometric_mean.py index 1fb99ab9..a3e89089 100644 --- a/tests/unit/scalarization/test_geometric_mean.py +++ b/tests/unit/scalarization/test_geometric_mean.py @@ -36,16 +36,27 @@ def test_permutation_invariant(losses: Tensor) -> None: @mark.parametrize( - "invalid", + "negative", [ - tensor_([1.0, 0.0]), tensor_([1.0, -1.0]), - tensor_([1.0, 1e-13]), + tensor_([-1e-13, 2.0]), + tensor_([-1.0]), ], ) -def test_raises_on_non_positive_input(invalid: Tensor) -> None: +def test_raises_on_negative_input(negative: Tensor) -> None: with raises(ValueError): - GeometricMean()(invalid) + GeometricMean()(negative) + + +def test_returns_zero_when_a_value_is_zero() -> None: + # log(0) = -inf, so the geometric mean collapses to 0. This matches the LibMTL behavior; + # the gradient is nan, which is expected for this method. + assert GeometricMean()(tensor_([1.0, 0.0])) == 0.0 + + +def test_does_not_raise_on_tiny_positive_input() -> None: + # Tiny but strictly positive values are valid and must not be rejected. + assert GeometricMean()(tensor_([1.0, 1e-13])).isfinite() def test_representations() -> None: