Skip to content

feat(scalarization): Add GeometricMean#716

Merged
PierreQuinton merged 8 commits into
SimplexLab:mainfrom
ppraneth:scalarization-2
May 29, 2026
Merged

feat(scalarization): Add GeometricMean#716
PierreQuinton merged 8 commits into
SimplexLab:mainfrom
ppraneth:scalarization-2

Conversation

@ppraneth
Copy link
Copy Markdown
Contributor

@ppraneth ppraneth commented May 29, 2026

New torchjd.scalarization.GLS: the geometric mean of the input tensor of values, from MultiNet++ (Chennupati et al., CVPRW 2019).

Paper formula:

$$L_{\text{GLS}} = \left(\prod_{i=1}^{K} L_i\right)^{1/K}$$

Implementation

Computed in log space for numerical stability:

def forward(self, values: Tensor, /) -> Tensor:
    return torch.exp(torch.log(values).mean())

Mathematically equivalent to the paper's formula:

$$\exp!\left(\tfrac{1}{K}\sum_i \log L_i\right) = \left(\prod_i L_i\right)^{1/K}$$

The direct form values.prod() ** (1/K) would underflow prod to zero in float32 for realistic loss magnitudes (for example, 100 losses near 0.01). Log space stays in a safe range throughout.

Contract

  • Accepts a tensor of any shape; reduces over all elements (same convention as Mean and Sum).
  • Undefined on non-positive inputs by design. nan and -inf propagate with no silent clamp, so misuse surfaces immediately.

Files

File Purpose
src/torchjd/scalarization/_gmean.py New class
src/torchjd/scalarization/__init__.py Export
docs/source/docs/scalarization/gmean.rst Doc page
docs/source/docs/scalarization/index.rst Toctree entry
tests/unit/scalarization/test_gmean.py Unit tests
CHANGELOG.md [Unreleased] entry

Test plan

  • uv run pytest tests/unit/scalarization/test_gmean.py -W error -v
  • uv run pytest tests/unit -W error (full regression)
  • uv run ruff check && uv run ruff format --check
  • uv run ty check
  • uv run pre-commit run --all-files

The new test file covers value correctness, structural shape over all fixture shapes from _inputs.py (0-dim through 3-D), gradient flow, permutation invariance, nan propagation on negative input, and repr/str.

ppraneth added 3 commits May 28, 2026 19:41
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth ppraneth requested a review from a team as a code owner May 29, 2026 01:28
@PierreQuinton PierreQuinton added cc: feat Conventional commit type for new features. package: scalarization labels May 29, 2026
Copy link
Copy Markdown
Contributor

@PierreQuinton PierreQuinton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice, thanks for the PR!

I think I would prefer the one of the names

  • GMean
  • GeomMean
  • GeometricMean

They certainly didn't invent the geometric mean in the paper, they just used it. Moreover the name I propose are more explicit.

Comment thread src/torchjd/scalarization/_geometric_mean.py
Comment thread src/torchjd/scalarization/_gls.py Outdated
Comment thread CHANGELOG.md Outdated
ppraneth and others added 3 commits May 29, 2026 13:41
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth ppraneth requested a review from PierreQuinton May 29, 2026 08:44
@ppraneth ppraneth changed the title feat(scalarization): Add GLS feat(scalarization): Add GMean May 29, 2026
Copy link
Copy Markdown
Contributor

@PierreQuinton PierreQuinton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once the last comment is resolved, this is read to merge

Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the PR! Looking very good already. A few comments to resolve and then we can merge IMO.

Comment thread docs/source/docs/scalarization/gmean.rst Outdated
Comment thread src/torchjd/scalarization/_gmean.py Outdated
Comment on lines +16 to +17
if (values <= 0.0).any():
return (values * 0.0).sum()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we raise an error when a value is negative? Or maybe to be safe, raise an error when a value is < 1e-12 (or some hard-coded stuff, doesn't have to be a parameter of the scalarizer IMO). (This would also need a change to the tests).

Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey May 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant < -1e-12, not < 1e-12. I don't think we should raise an error when a loss has a tiny value, I think we should only raise when it's negative. Also, with the new formula, we really need to raise when a value is < 0, not < -1e-12 (because we don't have the fall back to returning zero anymore when it's between -1e-12 and 0).

So I think we should have:

if (values < 0.0).any():
    raise ValueError(
        "GeometricMean is only defined for strictly positive values. 
        "Found a negative value in the input."
    )

return torch.exp(torch.log(values).mean())

This way:

  • if a value is negative => error
  • if a value is zero => return zero, grad will be nan (which is to be expected with this method, and which is equivalent to the LibMTL impl).
  • else, everything works normally.

We should make a new PR to fix this IMO, but before, please tell me if you both agree with this @PierreQuinton @ppraneth.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, sorry about merging to fast.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will make a pr now

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment thread src/torchjd/scalarization/_geometric_mean.py
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth ppraneth requested a review from ValerianRey May 29, 2026 10:35
@PierreQuinton PierreQuinton changed the title feat(scalarization): Add GMean feat(scalarization): Add GeometricMean May 29, 2026
@PierreQuinton
Copy link
Copy Markdown
Contributor

Merging, congrats @ppraneth !!

@PierreQuinton PierreQuinton merged commit 47ad975 into SimplexLab:main May 29, 2026
17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: feat Conventional commit type for new features. package: scalarization

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants