Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 1 addition & 52 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,58 +215,7 @@ losses with Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggr
```

You can even go one step further by considering the multiple tasks and each element of the batch
independently. We call that Instance-Wise Multitask Learning (IWMTL).

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import Flattening, UPGradWeighting
from torchjd.autogram import Engine

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
*shared_module.parameters(),
*task1_module.parameters(),
*task2_module.parameters(),
]

optimizer = SGD(params, lr=0.1)
mse = MSELoss(reduction="none")
weighting = Flattening(UPGradWeighting())
engine = Engine(shared_module, batch_dim=0)

inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
features = shared_module(input) # shape: [16, 3]
out1 = task1_module(features).squeeze(1) # shape: [16]
out2 = task2_module(features).squeeze(1) # shape: [16]

# Compute the matrix of losses: one loss per element of the batch and per task
losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1) # shape: [16, 2]

# Compute the gramian (inner products between pairs of gradients of the losses)
gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16]

# Obtain the weights that lead to no conflict between reweighted gradients
weights = weighting(gramian) # shape: [16, 2]

# Do the standard backward pass, but weighted using the obtained weights
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()
```

> [!NOTE]
> Here, because the losses are a matrix instead of a simple vector, we compute a *generalized
> Gramian* and we extract weights from it using a
> [GeneralizedWeighting](https://torchjd.org/stable/docs/aggregation/#torchjd.aggregation.GeneralizedWeighting).
independently (Instance-Wise Multitask Learning). See [this example](https://torchjd.org/stable/examples/iwmtl/) for more details.

More usage examples can be found [here](https://torchjd.org/stable/examples/).

Expand Down
Loading