From adf83049172534a70512fb42d016a21de5871166 Mon Sep 17 00:00:00 2001 From: ppraneth Date: Thu, 11 Jun 2026 09:28:15 +0530 Subject: [PATCH 1/7] add DWA Signed-off-by: ppraneth --- CHANGELOG.md | 5 + docs/source/docs/scalarization/dwa.rst | 7 ++ docs/source/docs/scalarization/index.rst | 1 + src/torchjd/scalarization/__init__.py | 2 + src/torchjd/scalarization/_dwa.py | 135 ++++++++++++++++++++ tests/unit/scalarization/test_dwa.py | 154 +++++++++++++++++++++++ 6 files changed, 304 insertions(+) create mode 100644 docs/source/docs/scalarization/dwa.rst create mode 100644 src/torchjd/scalarization/_dwa.py create mode 100644 tests/unit/scalarization/test_dwa.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e259ea0..f4dc0aea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `DWA` (Dynamic Weight Average) from [End-to-End Multi-Task Learning with + Attention](https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_End-To-End_Multi-Task_Learning_With_Attention_CVPR_2019_paper.pdf) + (CVPR 2019), a stateful `Scalarizer` that weights each value by the relative rate at which its + loss decreased over the two previous epochs. It has no learnable parameters; call its `step()` + method once per epoch to roll the loss history. - Added `IMTL-L` (the loss-balancing variant of Impartial Multi-Task Learning) from [Towards Impartial Multi-Task Learning](https://openreview.net/pdf?id=IMPnRXEWpvr) (ICLR 2021), a stateful `Scalarizer` that learns a per-task scale `s_i` and combines the values as diff --git a/docs/source/docs/scalarization/dwa.rst b/docs/source/docs/scalarization/dwa.rst new file mode 100644 index 00000000..49965ac6 --- /dev/null +++ b/docs/source/docs/scalarization/dwa.rst @@ -0,0 +1,7 @@ +:hide-toc: + +DWA +=== + +.. autoclass:: torchjd.scalarization.DWA + :members: __call__, step, reset diff --git a/docs/source/docs/scalarization/index.rst b/docs/source/docs/scalarization/index.rst index f2d42be7..fff5d797 100644 --- a/docs/source/docs/scalarization/index.rst +++ b/docs/source/docs/scalarization/index.rst @@ -15,6 +15,7 @@ Abstract base class :maxdepth: 1 constant.rst + dwa.rst geometric_mean.rst imtl_l.rst mean.rst diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/scalarization/__init__.py index 6b37d281..6f596edd 100644 --- a/src/torchjd/scalarization/__init__.py +++ b/src/torchjd/scalarization/__init__.py @@ -20,6 +20,7 @@ """ from ._constant import Constant +from ._dwa import DWA from ._geometric_mean import GeometricMean from ._imtl_l import IMTLL from ._mean import Mean @@ -31,6 +32,7 @@ __all__ = [ "Constant", + "DWA", "GeometricMean", "IMTLL", "Mean", diff --git a/src/torchjd/scalarization/_dwa.py b/src/torchjd/scalarization/_dwa.py new file mode 100644 index 00000000..3c990e86 --- /dev/null +++ b/src/torchjd/scalarization/_dwa.py @@ -0,0 +1,135 @@ +import torch +from torch import Tensor +from torch.nn.functional import softmax + +from torchjd._mixins import Stateful + +from ._scalarizer_base import Scalarizer + + +class DWA(Scalarizer, Stateful): + r""" + :class:`~torchjd.Stateful` + :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using Dynamic + Weight Average (DWA), proposed in `End-to-End Multi-Task Learning with Attention + `_. + + DWA weights each value by how quickly its loss has been decreasing relative to the others. At + epoch :math:`t`, the values are combined as + + .. math:: + \sum_k \lambda_k(t)\, L_k(t), \qquad + \lambda_k(t) = \frac{K \exp(w_k(t-1) / T)}{\sum_i \exp(w_i(t-1) / T)}, \qquad + w_k(t-1) = \frac{L_k(t-1)}{L_k(t-2)} + + where: + + - :math:`L_k(t)` is the :math:`k`-th value (typically the loss of task :math:`k`) at epoch + :math:`t`, averaged over the batches of that epoch; + - :math:`w_k(t-1)` is the relative descending rate of task :math:`k` (the ratio of its average + losses over the two previous epochs); + - :math:`T` is the temperature (a larger :math:`T` makes the weights more uniform); + - :math:`K` is the number of values, and the factor :math:`K` keeps :math:`\sum_k \lambda_k = K`. + + The weights depend only on the average losses of the two previous epochs, so they are computed + from past values and require no gradient. At each call, the current epoch's losses are + accumulated, and :meth:`step` must be called once at the end of each epoch to finalize that + epoch's average loss and roll the history forward. During the first two epochs (before two + averages are available) the weights are uniform. + + Unlike :class:`~torchjd.scalarization.UW` and :class:`~torchjd.scalarization.IMTLL`, DWA has no + learnable parameters; its state is a non-trainable buffer cleared by :meth:`reset`. + + :param temperature: The temperature :math:`T`. Must be strictly positive. Larger values make the + weights more uniform. The paper uses ``2.0``. + + The following example shows how to train a model with DWA. The scalarizer is called on every + batch, and :meth:`step` is called once at the end of each epoch. + + >>> import torch + >>> from torch.nn import Linear + >>> + >>> from torchjd.scalarization import DWA + >>> + >>> model = Linear(3, 2) + >>> scalarizer = DWA() + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + >>> + >>> for epoch in range(3): + ... for _ in range(4): # Iterate over the batches of the epoch. + ... features = torch.randn(8, 3) + ... losses = model(features).pow(2).mean(dim=0) # One loss per output dimension. + ... loss = scalarizer(losses) + ... optimizer.zero_grad() + ... loss.backward() + ... optimizer.step() + ... scalarizer.step() # Roll the epoch history once, at the end of the epoch. + + .. note:: + DWA is designed to balance positive losses (it divides the losses of consecutive epochs). + """ + + def __init__(self, temperature: float = 2.0) -> None: + if temperature <= 0.0: + raise ValueError( + f"Parameter `temperature` should be strictly positive. Found `temperature = " + f"{temperature}`." + ) + + super().__init__() + self.temperature = temperature + self._loss_sum: Tensor | None = None + self._n_batches: int = 0 + self._previous_averages: list[Tensor] = [] + + def forward(self, values: Tensor, /) -> Tensor: + weights = self._compute_weights(values) + + detached = values.detach() + if self._loss_sum is None: + self._loss_sum = detached.clone() + elif self._loss_sum.shape != detached.shape: + raise ValueError( + f"The shape of `values` changed from {tuple(self._loss_sum.shape)} to " + f"{tuple(detached.shape)} within an epoch. Call `reset()` before changing it." + ) + else: + self._loss_sum = self._loss_sum + detached + self._n_batches += 1 + + return (weights * values).sum() + + def step(self) -> None: + """ + Finalizes the current epoch's average loss and rolls the history forward, discarding the + average from two epochs ago. Should be called once at the end of each epoch. + """ + + if self._loss_sum is None: + return + average = self._loss_sum / self._n_batches + self._previous_averages = [*self._previous_averages, average][-2:] + self._loss_sum = None + self._n_batches = 0 + + def reset(self) -> None: + self._loss_sum = None + self._n_batches = 0 + self._previous_averages = [] + + def _compute_weights(self, values: Tensor) -> Tensor: + if len(self._previous_averages) < 2: + return torch.ones_like(values) + older = self._previous_averages[0] + newer = self._previous_averages[1] + if older.shape != values.shape: + raise ValueError( + f"The shape of `values` changed from {tuple(older.shape)} to " + f"{tuple(values.shape)}. Call `reset()` before changing it." + ) + rates = (newer / older).flatten() + weights = softmax(rates / self.temperature, dim=0) + return values.numel() * weights.reshape(values.shape) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(temperature={self.temperature})" diff --git a/tests/unit/scalarization/test_dwa.py b/tests/unit/scalarization/test_dwa.py new file mode 100644 index 00000000..93875cf7 --- /dev/null +++ b/tests/unit/scalarization/test_dwa.py @@ -0,0 +1,154 @@ +import torch +from pytest import mark, raises +from torch import Tensor +from utils.tensors import ones_, tensor_ + +from torchjd.scalarization import DWA + +from ._asserts import assert_grad_flow, assert_returns_scalar +from ._inputs import all_inputs + + +def test_value_bootstrap() -> None: + # Before two epochs have completed, DWA uses uniform weights (so it acts like a sum). + values = tensor_([1.0, 2.0, 4.0]) + torch.testing.assert_close(DWA()(values), tensor_(7.0)) + + +def test_uniform_weights_for_first_two_epochs() -> None: + dwa = DWA(temperature=2.0) + # Epoch 1: no completed epoch yet, so weights are uniform (sum). + torch.testing.assert_close(dwa(tensor_([1.0, 3.0])), tensor_(4.0)) + dwa.step() + # Epoch 2: only one completed epoch, so weights are still uniform (sum). + torch.testing.assert_close(dwa(tensor_([2.0, 5.0])), tensor_(7.0)) + dwa.step() + + +def test_weights_from_previous_two_epochs() -> None: + dwa = DWA(temperature=2.0) + dwa(tensor_([1.0, 1.0])) + dwa.step() # Epoch 1 average = [1, 1]. + dwa(tensor_([1.0, 4.0])) + dwa.step() # Epoch 2 average = [1, 4]. + # Epoch 3: rates = [1, 4] / [1, 1] = [1, 4]. + losses = tensor_([3.0, 5.0]) + result = dwa(losses) + expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0) + torch.testing.assert_close(result, (expected_weights * losses).sum()) + + +def test_uses_per_epoch_average() -> None: + # The weights use the average loss over each epoch's batches, not just the last batch. + dwa = DWA(temperature=2.0) + dwa(tensor_([2.0, 2.0])) + dwa(tensor_([0.0, 0.0])) + dwa.step() # Epoch 1 average = [1, 1]. + dwa(tensor_([2.0, 6.0])) + dwa(tensor_([0.0, 2.0])) + dwa.step() # Epoch 2 average = [1, 4]. + losses = tensor_([3.0, 5.0]) + result = dwa(losses) + expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0) + torch.testing.assert_close(result, (expected_weights * losses).sum()) + + +def test_step_discards_oldest_epoch() -> None: + dwa = DWA(temperature=2.0) + dwa(tensor_([9.0, 9.0])) + dwa.step() # Epoch 1 average = [9, 9]; should be discarded after epoch 3. + dwa(tensor_([1.0, 1.0])) + dwa.step() # Epoch 2 average = [1, 1]. + dwa(tensor_([1.0, 4.0])) + dwa.step() # Epoch 3 average = [1, 4]. + # Epoch 4 uses only epochs 2 and 3: rates = [1, 4] / [1, 1] = [1, 4]. + losses = tensor_([3.0, 5.0]) + result = dwa(losses) + expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0) + torch.testing.assert_close(result, (expected_weights * losses).sum()) + + +def test_weights_sum_to_numel() -> None: + dwa = DWA() + dwa(tensor_([1.0, 2.0])) + dwa.step() + dwa(tensor_([2.0, 1.0])) + dwa.step() + # The weights sum to the number of elements, so weighting a vector of ones gives that count. + torch.testing.assert_close(dwa(ones_((2,))), tensor_(2.0)) + + +@mark.parametrize("values", all_inputs) +def test_expected_structure(values: Tensor) -> None: + assert_returns_scalar(DWA(), values) + + +@mark.parametrize("values", all_inputs) +def test_grad_flow(values: Tensor) -> None: + assert_grad_flow(DWA(), values) + + +def test_grad_flows_with_computed_weights() -> None: + # After two epochs the weights are computed from the (detached) loss history; gradients must + # still flow to the current values. + dwa = DWA(temperature=2.0) + dwa(tensor_([1.0, 1.0])) + dwa.step() + dwa(tensor_([1.0, 4.0])) + dwa.step() + assert_grad_flow(dwa, tensor_([3.0, 5.0])) + + +def test_reset() -> None: + dwa = DWA() + dwa(tensor_([1.0, 2.0])) + dwa.step() + dwa(tensor_([3.0, 4.0])) + dwa.reset() + assert dwa._previous_averages == [] + assert dwa._loss_sum is None + assert dwa._n_batches == 0 + + +def test_step_without_forward_is_noop() -> None: + dwa = DWA() + dwa.step() # No losses accumulated yet. + assert dwa._previous_averages == [] + + +def test_has_no_learnable_parameters() -> None: + assert list(DWA().parameters()) == [] + + +def test_does_not_raise_on_negative_input() -> None: + # DWA is designed for positive losses but does not enforce a positivity precondition. + assert_returns_scalar(DWA(), tensor_([-1.0, -2.0, 3.0])) + + +def test_raises_on_shape_change_within_epoch() -> None: + dwa = DWA() + dwa(tensor_([1.0, 2.0])) + with raises(ValueError): + dwa(tensor_([1.0, 2.0, 3.0])) + + +def test_raises_on_shape_change_between_epochs() -> None: + dwa = DWA() + dwa(tensor_([1.0, 2.0])) + dwa.step() + dwa(tensor_([2.0, 1.0])) + dwa.step() + with raises(ValueError): + dwa(tensor_([1.0, 2.0, 3.0])) + + +@mark.parametrize("temperature", [0.0, -1.0]) +def test_raises_on_non_positive_temperature(temperature: float) -> None: + with raises(ValueError): + DWA(temperature=temperature) + + +def test_representations() -> None: + assert repr(DWA()) == "DWA(temperature=2.0)" + assert repr(DWA(temperature=1.5)) == "DWA(temperature=1.5)" + assert str(DWA()) == "DWA" From e568dbdb12322e0e086c9365cac2c3ba549b3e04 Mon Sep 17 00:00:00 2001 From: ppraneth Date: Thu, 11 Jun 2026 09:48:15 +0530 Subject: [PATCH 2/7] add notes Signed-off-by: ppraneth --- src/torchjd/scalarization/_dwa.py | 4 +++- tests/unit/scalarization/test_dwa.py | 15 ++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/torchjd/scalarization/_dwa.py b/src/torchjd/scalarization/_dwa.py index 3c990e86..a8de99d7 100644 --- a/src/torchjd/scalarization/_dwa.py +++ b/src/torchjd/scalarization/_dwa.py @@ -66,7 +66,9 @@ class DWA(Scalarizer, Stateful): ... scalarizer.step() # Roll the epoch history once, at the end of the epoch. .. note:: - DWA is designed to balance positive losses (it divides the losses of consecutive epochs). + DWA weights each value by the ratio of its losses over consecutive epochs, which the paper + defines as a descending rate in the range :math:`(0, +\infty)`. The losses are therefore + expected to keep a consistent, nonzero sign across epochs (they need not be positive). """ def __init__(self, temperature: float = 2.0) -> None: diff --git a/tests/unit/scalarization/test_dwa.py b/tests/unit/scalarization/test_dwa.py index 93875cf7..446a18fb 100644 --- a/tests/unit/scalarization/test_dwa.py +++ b/tests/unit/scalarization/test_dwa.py @@ -120,9 +120,18 @@ def test_has_no_learnable_parameters() -> None: assert list(DWA().parameters()) == [] -def test_does_not_raise_on_negative_input() -> None: - # DWA is designed for positive losses but does not enforce a positivity precondition. - assert_returns_scalar(DWA(), tensor_([-1.0, -2.0, 3.0])) +def test_supports_consistently_negative_losses() -> None: + # DWA works on negative losses too, as long as each value keeps a consistent sign: the ratio of + # same-sign losses is positive, so the weights match those of the equivalent positive case. + dwa = DWA(temperature=2.0) + dwa(tensor_([-2.0, -2.0])) + dwa.step() # Epoch 1 average = [-2, -2]. + dwa(tensor_([-2.0, -8.0])) + dwa.step() # Epoch 2 average = [-2, -8]; rates = [-2, -8] / [-2, -2] = [1, 4]. + losses = tensor_([3.0, 5.0]) + result = dwa(losses) + expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0) + torch.testing.assert_close(result, (expected_weights * losses).sum()) def test_raises_on_shape_change_within_epoch() -> None: From 2e7e69f56212dd141ffb5cf6d3d584b9b812f1f8 Mon Sep 17 00:00:00 2001 From: Praneth Paruchuri Date: Thu, 11 Jun 2026 15:46:09 +0530 Subject: [PATCH 3/7] Update src/torchjd/scalarization/_dwa.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com> --- src/torchjd/scalarization/_dwa.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/torchjd/scalarization/_dwa.py b/src/torchjd/scalarization/_dwa.py index a8de99d7..0f26a1b7 100644 --- a/src/torchjd/scalarization/_dwa.py +++ b/src/torchjd/scalarization/_dwa.py @@ -37,9 +37,6 @@ class DWA(Scalarizer, Stateful): epoch's average loss and roll the history forward. During the first two epochs (before two averages are available) the weights are uniform. - Unlike :class:`~torchjd.scalarization.UW` and :class:`~torchjd.scalarization.IMTLL`, DWA has no - learnable parameters; its state is a non-trainable buffer cleared by :meth:`reset`. - :param temperature: The temperature :math:`T`. Must be strictly positive. Larger values make the weights more uniform. The paper uses ``2.0``. From c0527dde0510d58434cf83f057d78a12b62e344a Mon Sep 17 00:00:00 2001 From: Praneth Paruchuri Date: Thu, 11 Jun 2026 15:46:23 +0530 Subject: [PATCH 4/7] Update tests/unit/scalarization/test_dwa.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com> --- tests/unit/scalarization/test_dwa.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/unit/scalarization/test_dwa.py b/tests/unit/scalarization/test_dwa.py index 446a18fb..ed0ea300 100644 --- a/tests/unit/scalarization/test_dwa.py +++ b/tests/unit/scalarization/test_dwa.py @@ -116,10 +116,6 @@ def test_step_without_forward_is_noop() -> None: assert dwa._previous_averages == [] -def test_has_no_learnable_parameters() -> None: - assert list(DWA().parameters()) == [] - - def test_supports_consistently_negative_losses() -> None: # DWA works on negative losses too, as long as each value keeps a consistent sign: the ratio of # same-sign losses is positive, so the weights match those of the equivalent positive case. From b9ae04b4ef7b8f8902c0b7237542f326afc88466 Mon Sep 17 00:00:00 2001 From: Praneth Paruchuri Date: Thu, 11 Jun 2026 15:46:39 +0530 Subject: [PATCH 5/7] Update tests/unit/scalarization/test_dwa.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com> --- tests/unit/scalarization/test_dwa.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/unit/scalarization/test_dwa.py b/tests/unit/scalarization/test_dwa.py index ed0ea300..d38fe0c1 100644 --- a/tests/unit/scalarization/test_dwa.py +++ b/tests/unit/scalarization/test_dwa.py @@ -9,12 +9,6 @@ from ._inputs import all_inputs -def test_value_bootstrap() -> None: - # Before two epochs have completed, DWA uses uniform weights (so it acts like a sum). - values = tensor_([1.0, 2.0, 4.0]) - torch.testing.assert_close(DWA()(values), tensor_(7.0)) - - def test_uniform_weights_for_first_two_epochs() -> None: dwa = DWA(temperature=2.0) # Epoch 1: no completed epoch yet, so weights are uniform (sum). From ad4df9fb191d3ef23fb0b86c43b1044b39999c6c Mon Sep 17 00:00:00 2001 From: ppraneth Date: Thu, 11 Jun 2026 16:00:23 +0530 Subject: [PATCH 6/7] fix doc Signed-off-by: ppraneth --- src/torchjd/scalarization/_dwa.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/torchjd/scalarization/_dwa.py b/src/torchjd/scalarization/_dwa.py index 0f26a1b7..bcc15c23 100644 --- a/src/torchjd/scalarization/_dwa.py +++ b/src/torchjd/scalarization/_dwa.py @@ -15,27 +15,28 @@ class DWA(Scalarizer, Stateful): `_. DWA weights each value by how quickly its loss has been decreasing relative to the others. At - epoch :math:`t`, the values are combined as + epoch :math:`t`, the current batch's values are combined as .. math:: - \sum_k \lambda_k(t)\, L_k(t), \qquad + \sum_k \lambda_k(t)\, \ell_k, \qquad \lambda_k(t) = \frac{K \exp(w_k(t-1) / T)}{\sum_i \exp(w_i(t-1) / T)}, \qquad w_k(t-1) = \frac{L_k(t-1)}{L_k(t-2)} where: - - :math:`L_k(t)` is the :math:`k`-th value (typically the loss of task :math:`k`) at epoch - :math:`t`, averaged over the batches of that epoch; - - :math:`w_k(t-1)` is the relative descending rate of task :math:`k` (the ratio of its average - losses over the two previous epochs); - - :math:`T` is the temperature (a larger :math:`T` makes the weights more uniform); - - :math:`K` is the number of values, and the factor :math:`K` keeps :math:`\sum_k \lambda_k = K`. - - The weights depend only on the average losses of the two previous epochs, so they are computed - from past values and require no gradient. At each call, the current epoch's losses are - accumulated, and :meth:`step` must be called once at the end of each epoch to finalize that - epoch's average loss and roll the history forward. During the first two epochs (before two - averages are available) the weights are uniform. + - :math:`\ell_k` is the :math:`k`-th value being scalarized (typically the current batch's loss); + - :math:`L_k(t)` is the :math:`k`-th value averaged over epoch :math:`t` (used only for the + weights); + - :math:`w_k(t-1)` is the relative descending rate: the ratio of average losses over the two + previous epochs; + - :math:`T` is the temperature; a larger :math:`T` makes the weights more uniform; + - :math:`K` is the number of values; the factor :math:`K` keeps :math:`\sum_k \lambda_k = K`. + + The weights use only the two previous epochs' average losses, so they need no gradient. At each + call, the scalarization is returned and the current batch's losses are summed to the current + epoch's loss sums. :meth:`step` must then be called once at the end of each epoch to finalize + that epoch's average loss and roll the history forward. During the first two epochs, before two + averages are available, the weights are uniform. :param temperature: The temperature :math:`T`. Must be strictly positive. Larger values make the weights more uniform. The paper uses ``2.0``. From 1fd703817d398282835ec5d369d84e674ce86e1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 11 Jun 2026 14:21:19 +0200 Subject: [PATCH 7/7] Minor improvement of the docs --- src/torchjd/scalarization/_dwa.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchjd/scalarization/_dwa.py b/src/torchjd/scalarization/_dwa.py index bcc15c23..d6d3559f 100644 --- a/src/torchjd/scalarization/_dwa.py +++ b/src/torchjd/scalarization/_dwa.py @@ -24,13 +24,15 @@ class DWA(Scalarizer, Stateful): where: - - :math:`\ell_k` is the :math:`k`-th value being scalarized (typically the current batch's loss); + - :math:`\ell_k` is the :math:`k`-th value being scalarized (typically the current batch's loss + for task k); - :math:`L_k(t)` is the :math:`k`-th value averaged over epoch :math:`t` (used only for the weights); - :math:`w_k(t-1)` is the relative descending rate: the ratio of average losses over the two previous epochs; - :math:`T` is the temperature; a larger :math:`T` makes the weights more uniform; - - :math:`K` is the number of values; the factor :math:`K` keeps :math:`\sum_k \lambda_k = K`. + - :math:`K` is the number of values (e.g. the number of tasks); the factor :math:`K` keeps + :math:`\sum_k \lambda_k = K`. The weights use only the two previous epochs' average losses, so they need no gradient. At each call, the scalarization is returned and the current batch's losses are summed to the current