feat(scalarization): Add GeometricMean#716
Conversation
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
PierreQuinton
left a comment
There was a problem hiding this comment.
Once the last comment is resolved, this is read to merge
ValerianRey
left a comment
There was a problem hiding this comment.
Thanks a lot for the PR! Looking very good already. A few comments to resolve and then we can merge IMO.
| if (values <= 0.0).any(): | ||
| return (values * 0.0).sum() |
There was a problem hiding this comment.
Shouldn't we raise an error when a value is negative? Or maybe to be safe, raise an error when a value is < 1e-12 (or some hard-coded stuff, doesn't have to be a parameter of the scalarizer IMO). (This would also need a change to the tests).
There was a problem hiding this comment.
Sorry, I meant < -1e-12, not < 1e-12. I don't think we should raise an error when a loss has a tiny value, I think we should only raise when it's negative. Also, with the new formula, we really need to raise when a value is < 0, not < -1e-12 (because we don't have the fall back to returning zero anymore when it's between -1e-12 and 0).
So I think we should have:
if (values < 0.0).any():
raise ValueError(
"GeometricMean is only defined for strictly positive values.
"Found a negative value in the input."
)
return torch.exp(torch.log(values).mean())This way:
- if a value is negative => error
- if a value is zero => return zero, grad will be nan (which is to be expected with this method, and which is equivalent to the LibMTL impl).
- else, everything works normally.
We should make a new PR to fix this IMO, but before, please tell me if you both agree with this @PierreQuinton @ppraneth.
There was a problem hiding this comment.
Agreed, sorry about merging to fast.
There was a problem hiding this comment.
Will make a pr now
There was a problem hiding this comment.
@PierreQuinton @ValerianRey PR with #718
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
GeometricMean
|
Merging, congrats @ppraneth !! |
New
torchjd.scalarization.GLS: the geometric mean of the input tensor of values, from MultiNet++ (Chennupati et al., CVPRW 2019).Paper formula:
Implementation
Computed in log space for numerical stability:
Mathematically equivalent to the paper's formula:
The direct form
values.prod() ** (1/K)would underflowprodto zero in float32 for realistic loss magnitudes (for example, 100 losses near 0.01). Log space stays in a safe range throughout.Contract
MeanandSum).nanand-infpropagate with no silent clamp, so misuse surfaces immediately.Files
src/torchjd/scalarization/_gmean.pysrc/torchjd/scalarization/__init__.pydocs/source/docs/scalarization/gmean.rstdocs/source/docs/scalarization/index.rsttests/unit/scalarization/test_gmean.pyCHANGELOG.md[Unreleased]entryTest plan
uv run pytest tests/unit/scalarization/test_gmean.py -W error -vuv run pytest tests/unit -W error(full regression)uv run ruff check && uv run ruff format --checkuv run ty checkuv run pre-commit run --all-filesThe new test file covers value correctness, structural shape over all fixture shapes from
_inputs.py(0-dim through 3-D), gradient flow, permutation invariance,nanpropagation on negative input, andrepr/str.