-
Notifications
You must be signed in to change notification settings - Fork 17
feat(scalarization): Add scalarization package #701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ppraneth
wants to merge
17
commits into
SimplexLab:main
Choose a base branch
from
ppraneth:scalarization
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
ac7efc5
Scaffold scalarization
ppraneth d273b51
add test cases scalarization
ppraneth 8e3068d
minor edit fixes
ppraneth 632c795
Update src/torchjd/scalarization/_random.py
ppraneth cb4f4e5
Update src/torchjd/scalarization/_random.py
ppraneth 6921068
feedback changes
ppraneth ec4b137
minor fix
ppraneth 586840c
docs add
ppraneth d274c93
docs add
ppraneth 80257e8
Merge branch 'main' into scalarization
ppraneth 54adee3
Merge branch 'main' into scalarization
ppraneth 0eeabae
fix part 1
ppraneth 4d496fa
fix part 2
ppraneth 67426b8
rename typical_inputs to non_scalar_inputs
ppraneth e0bc683
Merge branch 'main' into scalarization
ValerianRey 1f10059
rename non_scalar_inputs with all_inputs
ppraneth 1693fb4
fix _inputs.py
ppraneth File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| :hide-toc: | ||
|
|
||
| Constant | ||
| ======== | ||
|
|
||
| .. autoclass:: torchjd.scalarization.Constant | ||
| :members: __call__ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| scalarization | ||
| ============= | ||
|
|
||
| .. automodule:: torchjd.scalarization | ||
| :no-members: | ||
|
|
||
| Abstract base class | ||
| ------------------- | ||
|
|
||
| .. autoclass:: torchjd.scalarization.Scalarizer | ||
| :members: __call__ | ||
|
|
||
|
|
||
| .. toctree:: | ||
| :hidden: | ||
| :maxdepth: 1 | ||
|
|
||
| constant.rst | ||
| mean.rst | ||
| random.rst | ||
| sum.rst |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| :hide-toc: | ||
|
|
||
| Mean | ||
| ==== | ||
|
|
||
| .. autoclass:: torchjd.scalarization.Mean | ||
| :members: __call__ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| :hide-toc: | ||
|
|
||
| Random | ||
| ====== | ||
|
|
||
| .. autoclass:: torchjd.scalarization.Random | ||
| :members: __call__ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| :hide-toc: | ||
|
|
||
| Sum | ||
| === | ||
|
|
||
| .. autoclass:: torchjd.scalarization.Sum | ||
| :members: __call__ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| """ | ||
| A :class:`~torchjd.scalarization.Scalarizer` reduces a tensor of values of any shape into a single | ||
| scalar value. This is the simple baseline | ||
| against which :class:`Aggregators <torchjd.aggregation.Aggregator>` are compared: instead of | ||
| combining the per-loss gradients via the Jacobian or its Gramian, a | ||
| :class:`~torchjd.scalarization.Scalarizer` combines the losses directly, and a standard call to | ||
| :meth:`~torch.Tensor.backward` produces the gradient. | ||
|
|
||
| The following example shows how to use :class:`~torchjd.scalarization.Mean` to combine a vector of | ||
| losses into a single scalar loss. | ||
|
|
||
| >>> from torch import tensor | ||
| >>> from torchjd.scalarization import Mean | ||
| >>> | ||
| >>> scalarizer = Mean() | ||
| >>> losses = tensor([1.0, 2.0, 3.0]) | ||
| >>> loss = scalarizer(losses) | ||
| >>> loss | ||
| tensor(2.) | ||
| """ | ||
|
|
||
| from ._constant import Constant | ||
| from ._mean import Mean | ||
| from ._random import Random | ||
| from ._scalarizer_base import Scalarizer | ||
| from ._sum import Sum | ||
|
|
||
| __all__ = [ | ||
| "Constant", | ||
| "Mean", | ||
| "Random", | ||
| "Scalarizer", | ||
| "Sum", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| from torch import Tensor | ||
|
|
||
| from torchjd._vector_str import pref_vector_to_str_suffix | ||
|
|
||
| from ._scalarizer_base import Scalarizer | ||
|
|
||
|
|
||
| class Constant(Scalarizer): | ||
| """ | ||
| :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values with | ||
| constant, pre-determined weights. | ||
|
|
||
| :param weights: The weights to apply to the values. Must have the same shape as the values | ||
| passed at call time. | ||
| """ | ||
|
|
||
| def __init__(self, weights: Tensor) -> None: | ||
| super().__init__() | ||
| self.weights = weights | ||
|
|
||
| def forward(self, values: Tensor, /) -> Tensor: | ||
| if values.shape != self.weights.shape: | ||
| raise ValueError( | ||
| f"Parameter `values` should have shape {tuple(self.weights.shape)} (matching the " | ||
| f"shape of the weights). Found `values.shape = {tuple(values.shape)}`.", | ||
| ) | ||
| return (self.weights * values).sum() | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}(weights={repr(self.weights)})" | ||
|
|
||
| def __str__(self) -> str: | ||
| if self.weights.ndim == 1: | ||
| return f"{self.__class__.__name__}{pref_vector_to_str_suffix(self.weights)}" | ||
| return f"{self.__class__.__name__}(weights of shape {tuple(self.weights.shape)})" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| from torch import Tensor | ||
|
|
||
| from ._scalarizer_base import Scalarizer | ||
|
|
||
|
|
||
| class Mean(Scalarizer): | ||
| """ | ||
| :class:`~torchjd.scalarization.Scalarizer` that returns the mean of the input tensor of values. | ||
| """ | ||
|
|
||
| def forward(self, values: Tensor, /) -> Tensor: | ||
| return values.mean() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| import torch | ||
| from torch import Tensor | ||
| from torch.nn.functional import softmax | ||
|
|
||
| from ._scalarizer_base import Scalarizer | ||
|
|
||
|
|
||
| class Random(Scalarizer): | ||
| """ | ||
| :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values with | ||
| positive random weights summing to 1, as defined in Algorithm 2 of `Reasonable Effectiveness of | ||
| Random Weighting: A Litmus Test for Multi-Task Learning | ||
| <https://arxiv.org/pdf/2111.10603.pdf>`_. | ||
| """ | ||
|
|
||
| def forward(self, values: Tensor, /) -> Tensor: | ||
| flat = torch.randn(values.numel(), device=values.device, dtype=values.dtype) | ||
| weights = softmax(flat, dim=-1).reshape(values.shape) | ||
| return (weights * values).sum() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| from abc import ABC, abstractmethod | ||
|
|
||
| from torch import Tensor, nn | ||
|
|
||
|
|
||
| class Scalarizer(nn.Module, ABC): | ||
| """ | ||
| Abstract base class for all scalarizers. Reduces a tensor of values of any shape into a single | ||
| scalar value. | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| super().__init__() | ||
|
|
||
| @abstractmethod | ||
| def forward(self, values: Tensor, /) -> Tensor: | ||
| """Computes the scalarization from input tensor.""" | ||
|
|
||
| def __call__(self, values: Tensor, /) -> Tensor: | ||
| """ | ||
| Computes the scalar value from the input tensor of values and applies all registered hooks. | ||
|
|
||
| :param values: The tensor of values to scalarize. May be of any shape. | ||
| """ | ||
| return super().__call__(values) | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}()" | ||
|
|
||
| def __str__(self) -> str: | ||
| return f"{self.__class__.__name__}" | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| from torch import Tensor | ||
|
|
||
| from ._scalarizer_base import Scalarizer | ||
|
|
||
|
|
||
| class Sum(Scalarizer): | ||
| """ | ||
| :class:`~torchjd.scalarization.Scalarizer` that returns the sum of the input tensor of values. | ||
| """ | ||
|
|
||
| def forward(self, values: Tensor, /) -> Tensor: | ||
| return values.sum() |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| import torch | ||
| from torch import Tensor | ||
| from utils.tensors import randperm_ | ||
|
|
||
| from torchjd.scalarization import Scalarizer | ||
|
|
||
|
|
||
| def assert_returns_scalar(scalarizer: Scalarizer, losses: Tensor) -> None: | ||
| out = scalarizer(losses) | ||
| assert out.dim() == 0 | ||
| assert out.isfinite() | ||
|
|
||
|
|
||
| def assert_grad_flow(scalarizer: Scalarizer, losses: Tensor) -> None: | ||
| leaf = losses.detach().requires_grad_() | ||
| out = scalarizer(leaf) | ||
| out.backward() | ||
| assert leaf.grad is not None | ||
| assert leaf.grad.isfinite().all() | ||
|
|
||
|
|
||
| def assert_permutation_invariant(scalarizer: Scalarizer, losses: Tensor) -> None: | ||
| out = scalarizer(losses) | ||
| flat = losses.flatten() | ||
| permuted = flat[randperm_(flat.numel())].reshape(losses.shape) | ||
| out_permuted = scalarizer(permuted) | ||
| torch.testing.assert_close(out, out_permuted) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| from torch import Tensor | ||
| from utils.tensors import randn_ | ||
|
|
||
| shapes: list[list[int]] = [[], [5], [3, 4], [2, 3, 4]] | ||
| all_inputs: list[Tensor] = [randn_(shape) for shape in shapes] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.