feat(scalarization): Add DWA #731
Merged
9 commits
adf8304 add DWA (ppraneth)
e568dbd add notes (ppraneth)
f317edd Merge branch 'main' into scalarization-6 (ppraneth)
b791e6d Merge branch 'main' into scalarization-6 (ppraneth)
2e7e69f Update src/torchjd/scalarization/_dwa.py (ppraneth)
c0527dd Update tests/unit/scalarization/test_dwa.py (ppraneth)
b9ae04b Update tests/unit/scalarization/test_dwa.py (ppraneth)
ad4df9f fix doc (ppraneth)
1fd7038 Minor improvement of the docs (ValerianRey)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| :hide-toc: | ||
|
|
||
| DWA | ||
| === | ||
|
|
||
| .. autoclass:: torchjd.scalarization.DWA | ||
| :members: __call__, step, reset |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -15,6 +15,7 @@ Abstract base class | ||
| :maxdepth: 1 | ||
|
|
||
| constant.rst | ||
| dwa.rst | ||
| geometric_mean.rst | ||
| imtl_l.rst | ||
| mean.rst | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| import torch | ||
| from torch import Tensor | ||
| from torch.nn.functional import softmax | ||
|
|
||
| from torchjd._mixins import Stateful | ||
|
|
||
| from ._scalarizer_base import Scalarizer | ||
|
|
||
|
|
||
| class DWA(Scalarizer, Stateful): | ||
| r""" | ||
| :class:`~torchjd.Stateful` | ||
| :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using Dynamic | ||
| Weight Average (DWA), proposed in `End-to-End Multi-Task Learning with Attention | ||
| <https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_End-To-End_Multi-Task_Learning_With_Attention_CVPR_2019_paper.pdf>`_. | ||
|
|
||
| DWA gives a larger weight to the values whose loss has been decreasing more slowly relative to | ||
| the others. At epoch :math:`t`, the current batch's values are combined as | ||
|
|
||
| .. math:: | ||
| \sum_k \lambda_k(t)\, \ell_k, \qquad | ||
| \lambda_k(t) = \frac{K \exp(w_k(t-1) / T)}{\sum_i \exp(w_i(t-1) / T)}, \qquad | ||
| w_k(t-1) = \frac{L_k(t-1)}{L_k(t-2)} | ||
|
|
||
| where: | ||
|
|
||
| - :math:`\ell_k` is the :math:`k`-th value being scalarized (typically the current batch's loss | ||
| for task :math:`k`); | ||
| - :math:`L_k(t)` is the :math:`k`-th value averaged over epoch :math:`t` (used only for the | ||
| weights); | ||
| - :math:`w_k(t-1)` is the relative descending rate: the ratio of average losses over the two | ||
| previous epochs; | ||
| - :math:`T` is the temperature; a larger :math:`T` makes the weights more uniform; | ||
| - :math:`K` is the number of values (e.g. the number of tasks); the factor :math:`K` keeps | ||
| :math:`\sum_k \lambda_k = K`. | ||
|
|
||
| The weights depend only on the average losses of the two previous epochs, so no gradient flows | ||
| through them. Each call returns the scalarization and adds the current batch's detached losses | ||
| to the current epoch's running sums. :meth:`step` must then be called once at the end of each | ||
| epoch to finalize that epoch's average loss and roll the history forward. During the first two | ||
| epochs, before two averages are available, every weight is 1. | ||
|
|
||
| :param temperature: The temperature :math:`T`. Must be strictly positive. Larger values make the | ||
| weights more uniform. The paper uses ``2.0``. | ||
|
|
||
| The following example shows how to train a model with DWA. The scalarizer is called on every | ||
| batch, and :meth:`step` is called once at the end of each epoch. | ||
|
|
||
| >>> import torch | ||
| >>> from torch.nn import Linear | ||
| >>> | ||
| >>> from torchjd.scalarization import DWA | ||
| >>> | ||
| >>> model = Linear(3, 2) | ||
| >>> scalarizer = DWA() | ||
| >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1) | ||
| >>> | ||
| >>> for epoch in range(3): | ||
| ... for _ in range(4): # Iterate over the batches of the epoch. | ||
| ... features = torch.randn(8, 3) | ||
| ... losses = model(features).pow(2).mean(dim=0) # One loss per output dimension. | ||
| ... loss = scalarizer(losses) | ||
| ... optimizer.zero_grad() | ||
| ... loss.backward() | ||
| ... optimizer.step() | ||
| ... scalarizer.step() # Roll the epoch history once, at the end of the epoch. | ||
|
|
||
| .. note:: | ||
| DWA weights each value by the ratio of its losses over consecutive epochs, which the paper | ||
| defines as a descending rate in the range :math:`(0, +\infty)`. The losses are therefore | ||
| expected to keep a consistent, nonzero sign across epochs (they need not be positive). | ||
| """ | ||
|
|
||
| def __init__(self, temperature: float = 2.0) -> None: | ||
| if temperature <= 0.0: | ||
| raise ValueError( | ||
| f"Parameter `temperature` should be strictly positive. Found `temperature = " | ||
| f"{temperature}`." | ||
| ) | ||
|
|
||
| super().__init__() | ||
| self.temperature = temperature | ||
| self._loss_sum: Tensor | None = None | ||
| self._n_batches: int = 0 | ||
| self._previous_averages: list[Tensor] = [] | ||
|
|
||
| def forward(self, values: Tensor, /) -> Tensor: | ||
| weights = self._compute_weights(values) | ||
|
|
||
| detached = values.detach() | ||
| if self._loss_sum is None: | ||
| self._loss_sum = detached.clone() | ||
| elif self._loss_sum.shape != detached.shape: | ||
| raise ValueError( | ||
| f"The shape of `values` changed from {tuple(self._loss_sum.shape)} to " | ||
| f"{tuple(detached.shape)} within an epoch. Call `reset()` before changing it." | ||
| ) | ||
| else: | ||
| self._loss_sum = self._loss_sum + detached | ||
| self._n_batches += 1 | ||
|
|
||
| return (weights * values).sum() | ||
|
|
||
| def step(self) -> None: | ||
| """ | ||
| Finalizes the current epoch's average loss and rolls the history forward, discarding the | ||
| average from two epochs ago. Should be called once at the end of each epoch. | ||
| """ | ||
|
|
||
| if self._loss_sum is None: | ||
| return | ||
| average = self._loss_sum / self._n_batches | ||
| self._previous_averages = [*self._previous_averages, average][-2:] | ||
| self._loss_sum = None | ||
| self._n_batches = 0 | ||
|
|
||
| def reset(self) -> None: | ||
| self._loss_sum = None | ||
| self._n_batches = 0 | ||
| self._previous_averages = [] | ||
|
|
||
| def _compute_weights(self, values: Tensor) -> Tensor: | ||
| if len(self._previous_averages) < 2: | ||
| return torch.ones_like(values) | ||
| older = self._previous_averages[0] | ||
| newer = self._previous_averages[1] | ||
| if older.shape != values.shape: | ||
| raise ValueError( | ||
| f"The shape of `values` changed from {tuple(older.shape)} to " | ||
| f"{tuple(values.shape)}. Call `reset()` before changing it." | ||
| ) | ||
| rates = (newer / older).flatten() | ||
| weights = softmax(rates / self.temperature, dim=0) | ||
| return values.numel() * weights.reshape(values.shape) | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}(temperature={self.temperature})" | ||
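
To check the weight formula from the docstring by hand, the computation done in `_compute_weights` can be approximated without torch. This is a minimal sketch and not the torchjd implementation: `dwa_weights` is a hypothetical helper, it works on plain Python lists of per-epoch averages, and it subtracts the maximum before exponentiating (a standard softmax stabilization that leaves the result unchanged).

```python
import math


def dwa_weights(older, newer, temperature=2.0):
    """Return K * softmax((newer / older) / T) for two lists of epoch averages.

    Hypothetical helper for illustration only; `older` and `newer` stand for
    the average losses of epochs t-2 and t-1.
    """
    if temperature <= 0.0:
        raise ValueError("temperature must be strictly positive")
    # Relative descending rates w_k = L_k(t-1) / L_k(t-2).
    rates = [n / o for n, o in zip(newer, older)]
    scaled = [r / temperature for r in rates]
    # Subtract the max so that exp() cannot overflow; softmax is shift-invariant.
    shift = max(scaled)
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    # The factor K makes the weights sum to the number of values.
    return [len(rates) * e / total for e in exps]


# Epoch averages [1, 1] then [1, 4]: rates are [1, 4].
weights = dwa_weights([1.0, 1.0], [1.0, 4.0])
# A very large temperature pulls every weight toward 1.
hot_weights = dwa_weights([1.0, 1.0], [1.0, 4.0], temperature=100.0)
```

Here the second value's average loss grew between the two epochs, so it receives the larger weight, and the two weights still sum to 2.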
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,153 @@ | ||
| import torch | ||
| from pytest import mark, raises | ||
| from torch import Tensor | ||
| from utils.tensors import ones_, tensor_ | ||
|
|
||
| from torchjd.scalarization import DWA | ||
|
|
||
| from ._asserts import assert_grad_flow, assert_returns_scalar | ||
| from ._inputs import all_inputs | ||
|
|
||
|
|
||
| def test_uniform_weights_for_first_two_epochs() -> None: | ||
| dwa = DWA(temperature=2.0) | ||
| # Epoch 1: no completed epoch yet, so weights are uniform (sum). | ||
| torch.testing.assert_close(dwa(tensor_([1.0, 3.0])), tensor_(4.0)) | ||
| dwa.step() | ||
| # Epoch 2: only one completed epoch, so weights are still uniform (sum). | ||
| torch.testing.assert_close(dwa(tensor_([2.0, 5.0])), tensor_(7.0)) | ||
| dwa.step() | ||
|
|
||
|
|
||
| def test_weights_from_previous_two_epochs() -> None: | ||
| dwa = DWA(temperature=2.0) | ||
| dwa(tensor_([1.0, 1.0])) | ||
| dwa.step() # Epoch 1 average = [1, 1]. | ||
| dwa(tensor_([1.0, 4.0])) | ||
| dwa.step() # Epoch 2 average = [1, 4]. | ||
| # Epoch 3: rates = [1, 4] / [1, 1] = [1, 4]. | ||
| losses = tensor_([3.0, 5.0]) | ||
| result = dwa(losses) | ||
| expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0) | ||
| torch.testing.assert_close(result, (expected_weights * losses).sum()) | ||
|
|
||
|
|
||
| def test_uses_per_epoch_average() -> None: | ||
| # The weights use the average loss over each epoch's batches, not just the last batch. | ||
| dwa = DWA(temperature=2.0) | ||
| dwa(tensor_([2.0, 2.0])) | ||
| dwa(tensor_([0.0, 0.0])) | ||
| dwa.step() # Epoch 1 average = [1, 1]. | ||
| dwa(tensor_([2.0, 6.0])) | ||
| dwa(tensor_([0.0, 2.0])) | ||
| dwa.step() # Epoch 2 average = [1, 4]. | ||
| losses = tensor_([3.0, 5.0]) | ||
| result = dwa(losses) | ||
| expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0) | ||
| torch.testing.assert_close(result, (expected_weights * losses).sum()) | ||
|
|
||
|
|
||
| def test_step_discards_oldest_epoch() -> None: | ||
| dwa = DWA(temperature=2.0) | ||
| dwa(tensor_([9.0, 9.0])) | ||
| dwa.step() # Epoch 1 average = [9, 9]; should be discarded after epoch 3. | ||
| dwa(tensor_([1.0, 1.0])) | ||
| dwa.step() # Epoch 2 average = [1, 1]. | ||
| dwa(tensor_([1.0, 4.0])) | ||
| dwa.step() # Epoch 3 average = [1, 4]. | ||
| # Epoch 4 uses only epochs 2 and 3: rates = [1, 4] / [1, 1] = [1, 4]. | ||
| losses = tensor_([3.0, 5.0]) | ||
| result = dwa(losses) | ||
| expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0) | ||
| torch.testing.assert_close(result, (expected_weights * losses).sum()) | ||
|
|
||
|
|
||
| def test_weights_sum_to_numel() -> None: | ||
| dwa = DWA() | ||
| dwa(tensor_([1.0, 2.0])) | ||
| dwa.step() | ||
| dwa(tensor_([2.0, 1.0])) | ||
| dwa.step() | ||
| # The weights sum to the number of elements, so weighting a vector of ones gives that count. | ||
| torch.testing.assert_close(dwa(ones_((2,))), tensor_(2.0)) | ||
|
|
||
|
|
||
| @mark.parametrize("values", all_inputs) | ||
| def test_expected_structure(values: Tensor) -> None: | ||
| assert_returns_scalar(DWA(), values) | ||
|
|
||
|
|
||
| @mark.parametrize("values", all_inputs) | ||
| def test_grad_flow(values: Tensor) -> None: | ||
| assert_grad_flow(DWA(), values) | ||
|
|
||
|
|
||
| def test_grad_flows_with_computed_weights() -> None: | ||
| # After two epochs the weights are computed from the (detached) loss history; gradients must | ||
| # still flow to the current values. | ||
| dwa = DWA(temperature=2.0) | ||
| dwa(tensor_([1.0, 1.0])) | ||
| dwa.step() | ||
| dwa(tensor_([1.0, 4.0])) | ||
| dwa.step() | ||
| assert_grad_flow(dwa, tensor_([3.0, 5.0])) | ||
|
|
||
|
|
||
| def test_reset() -> None: | ||
| dwa = DWA() | ||
| dwa(tensor_([1.0, 2.0])) | ||
| dwa.step() | ||
| dwa(tensor_([3.0, 4.0])) | ||
| dwa.reset() | ||
| assert dwa._previous_averages == [] | ||
| assert dwa._loss_sum is None | ||
| assert dwa._n_batches == 0 | ||
|
|
||
|
|
||
| def test_step_without_forward_is_noop() -> None: | ||
| dwa = DWA() | ||
| dwa.step() # No losses accumulated yet. | ||
| assert dwa._previous_averages == [] | ||
|
|
||
|
|
||
| def test_supports_consistently_negative_losses() -> None: | ||
| # DWA works on negative losses too, as long as each value keeps a consistent sign: the ratio of | ||
| # same-sign losses is positive, so the weights match those of the equivalent positive case. | ||
| dwa = DWA(temperature=2.0) | ||
| dwa(tensor_([-2.0, -2.0])) | ||
| dwa.step() # Epoch 1 average = [-2, -2]. | ||
| dwa(tensor_([-2.0, -8.0])) | ||
| dwa.step() # Epoch 2 average = [-2, -8]; rates = [-2, -8] / [-2, -2] = [1, 4]. | ||
| losses = tensor_([3.0, 5.0]) | ||
| result = dwa(losses) | ||
| expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0) | ||
| torch.testing.assert_close(result, (expected_weights * losses).sum()) | ||
|
|
||
|
|
||
| def test_raises_on_shape_change_within_epoch() -> None: | ||
| dwa = DWA() | ||
| dwa(tensor_([1.0, 2.0])) | ||
| with raises(ValueError): | ||
| dwa(tensor_([1.0, 2.0, 3.0])) | ||
|
|
||
|
|
||
| def test_raises_on_shape_change_between_epochs() -> None: | ||
| dwa = DWA() | ||
| dwa(tensor_([1.0, 2.0])) | ||
| dwa.step() | ||
| dwa(tensor_([2.0, 1.0])) | ||
| dwa.step() | ||
| with raises(ValueError): | ||
| dwa(tensor_([1.0, 2.0, 3.0])) | ||
|
|
||
|
|
||
| @mark.parametrize("temperature", [0.0, -1.0]) | ||
| def test_raises_on_non_positive_temperature(temperature: float) -> None: | ||
| with raises(ValueError): | ||
| DWA(temperature=temperature) | ||
|
|
||
|
|
||
| def test_representations() -> None: | ||
| assert repr(DWA()) == "DWA(temperature=2.0)" | ||
| assert repr(DWA(temperature=1.5)) == "DWA(temperature=1.5)" | ||
| assert str(DWA()) == "DWA" |
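
The epoch bookkeeping that `test_uses_per_epoch_average` and `test_step_discards_oldest_epoch` exercise can be sketched in isolation. The class below is a hypothetical stand-in written with plain lists, assuming the same three pieces of state as the diff (a running sum, a batch count, and at most two stored averages); it is not the torchjd code and it omits the shape checks.

```python
class EpochHistory:
    """Hypothetical, list-based model of the per-epoch state kept by DWA."""

    def __init__(self):
        self.loss_sum = None
        self.n_batches = 0
        self.previous_averages = []

    def accumulate(self, losses):
        # Add one batch of losses to the current epoch's running sums.
        if self.loss_sum is None:
            self.loss_sum = list(losses)
        else:
            self.loss_sum = [s + l for s, l in zip(self.loss_sum, losses)]
        self.n_batches += 1

    def step(self):
        # Nothing accumulated since the last step: leave the history untouched.
        if self.loss_sum is None:
            return
        average = [s / self.n_batches for s in self.loss_sum]
        # Keep only the two most recent epoch averages.
        self.previous_averages = (self.previous_averages + [average])[-2:]
        self.loss_sum = None
        self.n_batches = 0


history = EpochHistory()
history.accumulate([9.0, 9.0])
history.step()  # Epoch 1 average = [9, 9]; dropped after epoch 3.
history.accumulate([2.0, 2.0])
history.accumulate([0.0, 0.0])
history.step()  # Epoch 2 average = [1, 1].
history.accumulate([2.0, 6.0])
history.accumulate([0.0, 2.0])
history.step()  # Epoch 3 average = [1, 4].
```

After the third `step()`, only the averages of epochs 2 and 3 remain, which is the history the next epoch's weights would be computed from.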