diff --git a/README.md b/README.md
index 22a4bd25..8eec62bc 100644
--- a/README.md
+++ b/README.md
@@ -215,58 +215,7 @@ losses with Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggr
 ```
 
 You can even go one step further by considering the multiple tasks and each element of the batch
-independently. We call that Instance-Wise Multitask Learning (IWMTL).
-
-```python
-import torch
-from torch.nn import Linear, MSELoss, ReLU, Sequential
-from torch.optim import SGD
-
-from torchjd.aggregation import Flattening, UPGradWeighting
-from torchjd.autogram import Engine
-
-shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
-task1_module = Linear(3, 1)
-task2_module = Linear(3, 1)
-params = [
-    *shared_module.parameters(),
-    *task1_module.parameters(),
-    *task2_module.parameters(),
-]
-
-optimizer = SGD(params, lr=0.1)
-mse = MSELoss(reduction="none")
-weighting = Flattening(UPGradWeighting())
-engine = Engine(shared_module, batch_dim=0)
-
-inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
-task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task
-task2_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the second task
-
-for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
-    features = shared_module(input)  # shape: [16, 3]
-    out1 = task1_module(features).squeeze(1)  # shape: [16]
-    out2 = task2_module(features).squeeze(1)  # shape: [16]
-
-    # Compute the matrix of losses: one loss per element of the batch and per task
-    losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]
-
-    # Compute the gramian (inner products between pairs of gradients of the losses)
-    gramian = engine.compute_gramian(losses)  # shape: [16, 2, 2, 16]
-
-    # Obtain the weights that lead to no conflict between reweighted gradients
-    weights = weighting(gramian)  # shape: [16, 2]
-
-    # Do the standard backward pass, but weighted using the obtained weights
-    losses.backward(weights)
-    optimizer.step()
-    optimizer.zero_grad()
-```
-
-> [!NOTE]
-> Here, because the losses are a matrix instead of a simple vector, we compute a *generalized
-> Gramian* and we extract weights from it using a
-> [GeneralizedWeighting](https://torchjd.org/stable/docs/aggregation/#torchjd.aggregation.GeneralizedWeighting).
+independently (Instance-Wise Multitask Learning). See [this example](https://torchjd.org/stable/examples/iwmtl/) for more details.
 
 More usage examples can be found [here](https://torchjd.org/stable/examples/).
 