feat(scalarization): Add scalarization package #701
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
ValerianRey
left a comment
Tyvm for this very good PR.
Got a few nitpicks to fix, and then we can merge!
Yes, I think we can leave it as default and let @ppraneth become code owner if they are interested and add some non-trivial scalarizers in future PRs.
Thank you so much! I'd love to become a code owner one day, and I'm really excited to contribute to this project.
ValerianRey
left a comment
LGTM.
Still need to rename typical_inputs to non_scalar_inputs (per Pierre's comment).
Also I forgot to mention that we need a changelog entry (I can do it).
@PierreQuinton I think we should make a release with this.
@ValerianRey Done, I have renamed
Thanks! Will do the changelog + merge tomorrow.
    assert_returns_scalar(Mean(), losses)


    @mark.parametrize("losses", non_scalar_inputs)
Can't we use all_inputs essentially everywhere?
Done. I switched all the parametrizations across test_mean.py, test_sum.py, and test_random.py to use all_inputs. With that done, non_scalar_inputs had no remaining use, so I removed it from _inputs.py.
Closes #666.
Summary
Adds a new torchjd.scalarization package providing simple baselines against which aggregators can be compared. Includes:

- Scalarizer: abstract base class, inherits from nn.Module.
- Mean: returns losses.mean().
- Sum: returns losses.sum().
- Constant(weights): returns (weights * losses).sum(). Validates weights.shape == losses.shape at call time.
- Random: combines losses with positive random weights summing to 1 (RLW, Algorithm 2 of arXiv 2111.10603).

All scalarizers accept loss tensors of any shape (including 0-dim) and return a 0-dim scalar.
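To make the list above concrete, here is a minimal sketch of how such scalarizers could look in plain PyTorch. It is an illustration written from the description only, not the code merged in this PR: the error message, storing the weights as a plain attribute, and sampling the Random weights as a softmax over standard-normal draws are all assumptions.

```python
from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn


class Scalarizer(nn.Module, ABC):
    """Maps a tensor of losses of any shape to a 0-dim scalar."""

    @abstractmethod
    def forward(self, losses: Tensor) -> Tensor: ...


class Mean(Scalarizer):
    def forward(self, losses: Tensor) -> Tensor:
        return losses.mean()


class Sum(Scalarizer):
    def forward(self, losses: Tensor) -> Tensor:
        return losses.sum()


class Constant(Scalarizer):
    def __init__(self, weights: Tensor) -> None:
        super().__init__()
        # Assumption: weights are kept as a plain attribute, not a buffer.
        self.weights = weights

    def forward(self, losses: Tensor) -> Tensor:
        # Shapes are compared at call time; no broadcasting is allowed.
        if self.weights.shape != losses.shape:
            raise ValueError(
                f"weights shape {tuple(self.weights.shape)} does not match "
                f"losses shape {tuple(losses.shape)}"
            )
        return (self.weights * losses).sum()


class Random(Scalarizer):
    def forward(self, losses: Tensor) -> Tensor:
        # Assumption: positive weights summing to 1 are obtained by applying
        # a softmax to standard-normal samples, one sample per loss.
        raw = torch.randn(losses.numel(), dtype=losses.dtype, device=losses.device)
        weights = torch.softmax(raw, dim=0).reshape(losses.shape)
        return (weights * losses).sum()


if __name__ == "__main__":
    losses = torch.tensor([1.0, 2.0, 3.0])
    for scalarizer in (Mean(), Sum(), Constant(torch.ones(3) / 3), Random()):
        # Every result is a 0-dim tensor, so .item() is always valid.
        print(type(scalarizer).__name__, scalarizer(losses).item())
```

Because each class is an nn.Module, an instance is called like a function, and a 0-dim input such as torch.tensor(4.0) follows the same code path as a vector or matrix of losses.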
Design decisions (confirmed with maintainers)

- No Scalarizer suffix.
- Scalarizer inherits from nn.Module for consistency with Aggregator and Weighting, and to leave room for trainable scalarizers later.
- No Combiner generalization for now.
- Constant uses option (a): weights.shape must equal losses.shape.

The Stateful mixin stays in torchjd.aggregation._mixins for now. When the first stateful scalarizer lands, it can be moved to torchjd._mixins so both packages share it. A comment in scalarization/__init__.py records this.

Test plan

- uv run pytest tests/unit/scalarization -W error -v passes.
- uv run pytest tests/unit/scalarization --cov=src/torchjd/scalarization shows full coverage.
- uv run pytest tests/unit -W error passes (3019 passed, 66 skipped, 33 xfailed).
- PYTEST_TORCH_DTYPE=float64 uv run pytest tests/unit/scalarization -W error passes.
- uv run ruff format --check and uv run ruff check pass.
- uv run ty check passes.
- uv run pre-commit run --all-files passes.