Skip to content

feat(scalarization): Add scalarization package#701

Open
ppraneth wants to merge 16 commits into
SimplexLab:mainfrom
ppraneth:scalarization
Open

feat(scalarization): Add scalarization package#701
ppraneth wants to merge 16 commits into
SimplexLab:mainfrom
ppraneth:scalarization

Conversation

@ppraneth
Copy link
Copy Markdown

Closes #666.

Summary

Adds a new torchjd.scalarization package providing simple baselines against which aggregators can be compared. Includes:

  • Scalarizer: abstract base class, inherits from nn.Module.
  • Mean: returns losses.mean().
  • Sum: returns losses.sum().
  • Constant(weights): returns (weights * losses).sum(). Validates weights.shape == losses.shape at call time.
  • Random: combines losses with positive random weights summing to 1 (RLW, Algorithm 2 of arXiv 2111.10603).

All scalarizers accept loss tensors of any shape (including 0-dim) and return a 0-dim scalar.

Design decisions (confirmed with maintainers )

  1. Short names mirroring the aggregation package, no Scalarizer suffix.
  2. Scalarizer inherits from nn.Module for consistency with Aggregator and Weighting, and to leave room for trainable scalarizers later.
  3. No Combiner generalization for now.
  4. Arbitrary-shape input, 0-dim output.
  5. Constant uses option (a): weights.shape must equal losses.shape.

The Stateful mixin stays in torchjd.aggregation._mixins for now. When the first stateful scalarizer lands, it can be moved to torchjd._mixins so both packages share it. A comment in scalarization/__init__.py records this.

Test plan

  • uv run pytest tests/unit/scalarization -W error -v passes.
  • uv run pytest tests/unit/scalarization --cov=src/torchjd/scalarization shows full coverage.
  • uv run pytest tests/unit -W error passes (3019 passed, 66 skipped, 33 xfailed).
  • PYTEST_TORCH_DTYPE=float64 uv run pytest tests/unit/scalarization -W error passes.
  • uv run ruff format --check and uv run ruff check pass.
  • uv run ty check passes.
  • uv run pre-commit run --all-files passes.

ppraneth added 3 commits May 26, 2026 19:03
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth ppraneth requested a review from a team as a code owner May 27, 2026 02:30
@PierreQuinton PierreQuinton added cc: feat Conventional commit type for new features. package: scalarization labels May 27, 2026
@github-actions github-actions Bot changed the title feat: Add scalarization package feat(scalarization): Add scalarization package May 27, 2026
@github-actions github-actions Bot changed the title feat: Add scalarization package feat(scalarization): Add scalarization package May 27, 2026
PierreQuinton

This comment was marked as resolved.

Comment thread src/torchjd/scalarization/__init__.py Outdated
ppraneth and others added 7 commits May 27, 2026 09:57
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tyvm for this very good PR.
Got a few nitpicks to fix, and then we can merge!

Comment thread docs/source/docs/scalarization/index.rst Outdated
Comment thread tests/unit/scalarization/test_constant.py Outdated
Comment thread tests/unit/scalarization/_inputs.py Outdated
Comment thread src/torchjd/scalarization/_scalarizer_base.py
Comment thread src/torchjd/scalarization/__init__.py Outdated
@ValerianRey
Copy link
Copy Markdown
Contributor

Also @ValerianRey should we define a code owner for the package or leave it as the default (maintainers) and then after maybe a few PR from @ppraneth we can make him code owner?

Yes, I think we can leave it as default and let @ppraneth become code owner if they are interested and add some non-trivial scalarizers in future PRs.

ppraneth added 3 commits May 27, 2026 23:55
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth ppraneth requested a review from ValerianRey May 27, 2026 19:01
@ppraneth
Copy link
Copy Markdown
Author

Yes, I think we can leave it as default and let @ppraneth become code owner if they are interested and add some non-trivial scalarizers in future PRs.

Thank you so much! I'd love to become a code owner one day and I'm really excited to contribute to this project

Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.
Still need to rename typical_inputs to non-scalar inputs (per Pierre's comment).
Also I forgot to mention that we need a changelog entry (I can do it).

@PierreQuinton I think we should make a release with this.

Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth
Copy link
Copy Markdown
Author

@ValerianRey Done, I have renamed typical_inputs to non_scalar_inputs.

@ValerianRey
Copy link
Copy Markdown
Contributor

Thanks! Will do the changelog + merge tomorrow.

Comment thread tests/unit/scalarization/test_mean.py Outdated
assert_returns_scalar(Mean(), losses)


@mark.parametrize("losses", non_scalar_inputs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we use all_inputs essentially everywhere?

Copy link
Copy Markdown
Author

@ppraneth ppraneth May 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, Switched all the parametrizations across test_mean.py, test_sum.py, and test_random.py to use all_inputs. With that done, non_scalar_inputs had no remaining use, so I removed it from _inputs.py

Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth ppraneth requested a review from PierreQuinton May 28, 2026 05:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: feat Conventional commit type for new features. package: scalarization

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add scalarization package

3 participants