feat(aggregation): Add MoDoWeighting#717
Conversation
PierreQuinton
left a comment
There was a problem hiding this comment.
Looking good, it still needs few changes but once this is merge, I think this makes #676 easier to merge.
|
|
||
| with torch.no_grad(): | ||
| grad = gramian @ lambd + self._rho * lambd | ||
| lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) |
There was a problem hiding this comment.
So in the end, this is a softmax. @rkhosrowshahi I think this means that moco is essentially just a composition with this weighting, where essentially you give yy_t to it, and then multiply yy_t by the obtained weights. Is that correct? If yes, I think we should change #676 accordingly.
| from ._weighting_bases import _GramianWeighting | ||
|
|
||
|
|
||
| class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable): |
There was a problem hiding this comment.
/opencode:Plan is the inheritance order correct here?
There was a problem hiding this comment.
it matched _gradvac.py, and also the warning in docstring for _NonDifferntiable states "Placing this mixin before the primary base will cause it to shadow the primary class's call signature in generated documentation."
So yes, I believe it is
PierreQuinton
left a comment
There was a problem hiding this comment.
For me this is ready, let's wait for @ValerianRey 's review s still.
Adds
MoDoWeightingfrom Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance (JMLR 2024).It's a stateful
Weighting[PSDMatrix]implementing the λ-update from Algorithm 2:Per the discussion with @PierreQuinton and @ValerianRey on Discord, this follows the official LibMTL implementation which uses
softmaxrather than the paper's hard simplex projection.Designed to be composed with
autogram.Enginein a two-batch training loop so that MoDo's double-sampling property is preserved (Gramian comes from batch 1; backward uses batch 2).Test plan
tests/unit/aggregation/test_modo.py(12 functions, 72 cases — structural, reset, parameter validation, softmax boundary cases, recurrence verification)ty checkpasses on_modo.py-W --keep-going -nEOF
)"