class DWA(Scalarizer, Stateful):
    r"""
    :class:`~torchjd.Stateful`
    :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using Dynamic
    Weight Average (DWA), proposed in `End-to-End Multi-Task Learning with Attention
    <https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_End-To-End_Multi-Task_Learning_With_Attention_CVPR_2019_paper.pdf>`_.

    DWA weights each value by how quickly its loss has been decreasing relative to the others. At
    epoch :math:`t`, the current batch's values are combined as

    .. math::
        \sum_k \lambda_k(t)\, \ell_k, \qquad
        \lambda_k(t) = \frac{K \exp(w_k(t-1) / T)}{\sum_i \exp(w_i(t-1) / T)}, \qquad
        w_k(t-1) = \frac{L_k(t-1)}{L_k(t-2)}

    where:

    - :math:`\ell_k` is the :math:`k`-th value being scalarized (typically the current batch's loss
      for task :math:`k`);
    - :math:`L_k(t)` is the :math:`k`-th value averaged over epoch :math:`t` (used only for the
      weights);
    - :math:`w_k(t-1)` is the relative descending rate: the ratio of average losses over the two
      previous epochs;
    - :math:`T` is the temperature; a larger :math:`T` makes the weights more uniform;
    - :math:`K` is the number of values (e.g. the number of tasks); the factor :math:`K` keeps
      :math:`\sum_k \lambda_k = K`.

    The weights use only the two previous epochs' average losses, so they need no gradient. At each
    call, the scalarization is returned and the current batch's (detached) losses are added to the
    current epoch's running loss sums. :meth:`step` must then be called once at the end of each
    epoch to finalize that epoch's average loss and roll the history forward. During the first two
    epochs, before two averages are available, every weight is ``1`` (so the result is the plain
    sum of the values).

    :param temperature: The temperature :math:`T`. Must be strictly positive. Larger values make the
        weights more uniform. The paper uses ``2.0``.

    The following example shows how to train a model with DWA. The scalarizer is called on every
    batch, and :meth:`step` is called once at the end of each epoch.

        >>> import torch
        >>> from torch.nn import Linear
        >>>
        >>> from torchjd.scalarization import DWA
        >>>
        >>> model = Linear(3, 2)
        >>> scalarizer = DWA()
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        >>>
        >>> for epoch in range(3):
        ...     for _ in range(4):  # Iterate over the batches of the epoch.
        ...         features = torch.randn(8, 3)
        ...         losses = model(features).pow(2).mean(dim=0)  # One loss per output dimension.
        ...         loss = scalarizer(losses)
        ...         optimizer.zero_grad()
        ...         loss.backward()
        ...         optimizer.step()
        ...     scalarizer.step()  # Roll the epoch history once, at the end of the epoch.

    .. note::
        DWA weights each value by the ratio of its losses over consecutive epochs, which the paper
        defines as a descending rate in the range :math:`(0, +\infty)`. The losses are therefore
        expected to keep a consistent, nonzero sign across epochs (they need not be positive). An
        epoch average that is exactly zero makes the corresponding ratio infinite or undefined.
    """

    def __init__(self, temperature: float = 2.0) -> None:
        # `not (temperature > 0.0)` rather than `temperature <= 0.0`: every comparison with NaN is
        # False, so the latter would silently accept a NaN temperature and produce NaN weights.
        if not (temperature > 0.0):
            raise ValueError(
                f"Parameter `temperature` should be strictly positive. Found `temperature = "
                f"{temperature}`."
            )

        super().__init__()
        self.temperature = temperature
        # Running sum of the (detached) values seen during the current epoch; None until the first
        # call of the epoch.
        self._loss_sum: Tensor | None = None
        # Number of calls accumulated into `_loss_sum` during the current epoch.
        self._n_batches: int = 0
        # Per-epoch averages of the at most two most recently finished epochs, oldest first.
        self._previous_averages: list[Tensor] = []

    def forward(self, values: Tensor, /) -> Tensor:
        """
        Scalarizes ``values`` with the current DWA weights and accumulates a detached copy of
        ``values`` into the running sum of the current epoch.

        :param values: The values to combine (typically one loss per task).
        :returns: The scalar ``(weights * values).sum()``.
        :raises ValueError: If the shape of ``values`` differs from the shape seen earlier in this
            epoch or from the shape of a stored epoch average.
        """

        # Computed first so that a shape mismatch with the history raises before any state changes.
        weights = self._compute_weights(values)

        # The history only feeds the weights, so it must not keep the autograd graph alive.
        detached = values.detach()
        if self._loss_sum is None:
            self._loss_sum = detached.clone()
        elif self._loss_sum.shape != detached.shape:
            raise ValueError(
                f"The shape of `values` changed from {tuple(self._loss_sum.shape)} to "
                f"{tuple(detached.shape)} within an epoch. Call `reset()` before changing it."
            )
        else:
            self._loss_sum = self._loss_sum + detached
        self._n_batches += 1

        return (weights * values).sum()

    def step(self) -> None:
        """
        Finalizes the current epoch's average loss and rolls the history forward, discarding the
        average from two epochs ago. Should be called once at the end of each epoch. Does nothing
        if no values were accumulated since the last call to :meth:`step` or :meth:`reset`.
        """

        if self._loss_sum is None:
            return
        average = self._loss_sum / self._n_batches
        # Keep only the two most recent epoch averages, oldest first.
        self._previous_averages = [*self._previous_averages, average][-2:]
        self._loss_sum = None
        self._n_batches = 0

    def reset(self) -> None:
        """
        Clears the running sum of the current epoch and the stored epoch averages, so that the next
        calls use uniform weights again and may use a different shape of ``values``.
        """

        self._loss_sum = None
        self._n_batches = 0
        self._previous_averages = []

    def _compute_weights(self, values: Tensor) -> Tensor:
        """
        Returns the weights to apply to ``values``, with the same shape as ``values``.

        :raises ValueError: If any stored epoch average has a shape different from ``values``.
        """

        # Validate against every stored average, not just the oldest one: otherwise a shape change
        # between the first and second epoch would go unnoticed here and later either broadcast
        # silently or fail inside the division with an unrelated error.
        for previous in self._previous_averages:
            if previous.shape != values.shape:
                raise ValueError(
                    f"The shape of `values` changed from {tuple(previous.shape)} to "
                    f"{tuple(values.shape)}. Call `reset()` before changing it."
                )

        if len(self._previous_averages) < 2:
            return torch.ones_like(values)

        older, newer = self._previous_averages
        # Softmax over all elements at once, whatever the shape of `values`.
        rates = (newer / older).flatten()
        weights = softmax(rates / self.temperature, dim=0)
        # Rescale so that the weights sum to the number of values rather than to 1.
        return values.numel() * weights.reshape(values.shape)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(temperature={self.temperature})"
def _run_epochs(scalarizer: DWA, *epochs: list[list[float]]) -> None:
    """Feeds each epoch's batches to ``scalarizer`` and closes every epoch with ``step()``."""
    for batches in epochs:
        for batch in batches:
            scalarizer(tensor_(batch))
        scalarizer.step()


def _expected_scalarization(rates: list[float], losses: Tensor, temperature: float = 2.0) -> Tensor:
    """Reference DWA output for ``losses`` given the relative descending ``rates``."""
    expected_weights = len(rates) * torch.softmax(tensor_(rates) / temperature, dim=0)
    return (expected_weights * losses).sum()


def test_weights_from_previous_two_epochs() -> None:
    scalarizer = DWA(temperature=2.0)
    # Epoch averages are [1, 1] then [1, 4], so epoch 3 sees rates [1, 4] / [1, 1] = [1, 4].
    _run_epochs(scalarizer, [[1.0, 1.0]], [[1.0, 4.0]])
    current = tensor_([3.0, 5.0])
    torch.testing.assert_close(scalarizer(current), _expected_scalarization([1.0, 4.0], current))


def test_uses_per_epoch_average() -> None:
    # The history must hold the mean over an epoch's batches, not merely the last batch seen.
    scalarizer = DWA(temperature=2.0)
    _run_epochs(
        scalarizer,
        [[2.0, 2.0], [0.0, 0.0]],  # Epoch 1 average = [1, 1].
        [[2.0, 6.0], [0.0, 2.0]],  # Epoch 2 average = [1, 4].
    )
    current = tensor_([3.0, 5.0])
    torch.testing.assert_close(scalarizer(current), _expected_scalarization([1.0, 4.0], current))


def test_step_discards_oldest_epoch() -> None:
    scalarizer = DWA(temperature=2.0)
    _run_epochs(
        scalarizer,
        [[9.0, 9.0]],  # Epoch 1 average = [9, 9]; must be forgotten once epoch 3 is finalized.
        [[1.0, 1.0]],  # Epoch 2 average = [1, 1].
        [[1.0, 4.0]],  # Epoch 3 average = [1, 4].
    )
    # Epoch 4 relies on epochs 2 and 3 only: rates = [1, 4] / [1, 1] = [1, 4].
    current = tensor_([3.0, 5.0])
    torch.testing.assert_close(scalarizer(current), _expected_scalarization([1.0, 4.0], current))


def test_weights_sum_to_numel() -> None:
    scalarizer = DWA()
    _run_epochs(scalarizer, [[1.0, 2.0]], [[2.0, 1.0]])
    # Weighting a vector of ones returns the sum of the weights, i.e. the number of elements.
    total_weight = scalarizer(ones_((2,)))
    torch.testing.assert_close(total_weight, tensor_(2.0))


@mark.parametrize("values", all_inputs)
def test_expected_structure(values: Tensor) -> None:
    scalarizer = DWA()
    assert_returns_scalar(scalarizer, values)


@mark.parametrize("values", all_inputs)
def test_grad_flow(values: Tensor) -> None:
    scalarizer = DWA()
    assert_grad_flow(scalarizer, values)


def test_grad_flows_with_computed_weights() -> None:
    # Once two epochs are recorded, the weights come from the detached loss history; the current
    # values must nevertheless still receive gradients.
    scalarizer = DWA(temperature=2.0)
    _run_epochs(scalarizer, [[1.0, 1.0]], [[1.0, 4.0]])
    assert_grad_flow(scalarizer, tensor_([3.0, 5.0]))


def test_reset() -> None:
    scalarizer = DWA()
    _run_epochs(scalarizer, [[1.0, 2.0]])
    scalarizer(tensor_([3.0, 4.0]))  # Leave an unfinished epoch behind before resetting.
    scalarizer.reset()
    assert scalarizer._previous_averages == []
    assert scalarizer._loss_sum is None
    assert scalarizer._n_batches == 0


def test_step_without_forward_is_noop() -> None:
    scalarizer = DWA()
    scalarizer.step()  # Nothing has been accumulated so far.
    assert scalarizer._previous_averages == []


def test_supports_consistently_negative_losses() -> None:
    # Negative losses are fine provided each value keeps its sign: same-sign losses have a positive
    # ratio, so the weights equal those of the corresponding positive case.
    scalarizer = DWA(temperature=2.0)
    _run_epochs(
        scalarizer,
        [[-2.0, -2.0]],  # Epoch 1 average = [-2, -2].
        [[-2.0, -8.0]],  # Epoch 2 average = [-2, -8]; rates = [-2, -8] / [-2, -2] = [1, 4].
    )
    current = tensor_([3.0, 5.0])
    torch.testing.assert_close(scalarizer(current), _expected_scalarization([1.0, 4.0], current))


def test_raises_on_shape_change_within_epoch() -> None:
    scalarizer = DWA()
    scalarizer(tensor_([1.0, 2.0]))
    with raises(ValueError):
        scalarizer(tensor_([1.0, 2.0, 3.0]))


def test_raises_on_shape_change_between_epochs() -> None:
    scalarizer = DWA()
    _run_epochs(scalarizer, [[1.0, 2.0]], [[2.0, 1.0]])
    with raises(ValueError):
        scalarizer(tensor_([1.0, 2.0, 3.0]))


@mark.parametrize("temperature", [0.0, -1.0])
def test_raises_on_non_positive_temperature(temperature: float) -> None:
    with raises(ValueError):
        DWA(temperature=temperature)


def test_representations() -> None:
    default_scalarizer = DWA()
    assert repr(default_scalarizer) == "DWA(temperature=2.0)"
    assert repr(DWA(temperature=1.5)) == "DWA(temperature=1.5)"
    assert str(default_scalarizer) == "DWA"