Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `DWA` (Dynamic Weight Average) from [End-to-End Multi-Task Learning with
Attention](https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_End-To-End_Multi-Task_Learning_With_Attention_CVPR_2019_paper.pdf)
(CVPR 2019), a stateful `Scalarizer` that weights each value by the relative rate at which its
loss decreased over the two previous epochs. It has no learnable parameters; call its `step()`
method once per epoch to roll the loss history.
- Added `SDMGradWeighting` from [Direction-oriented Multi-objective Learning: Simple and Provable Stochastic Algorithms](https://arxiv.org/pdf/2305.18409) (NeurIPS 2023). It is a stateful `Weighting` that solves for task weights via a simplex-projected inner loop on a cross-batch matrix `A = J_1 @ J_2.T` (computed from two independent mini-batches using `autojac.jac`), with a direction-oriented regularizer pulling the descent direction toward a preference direction.
- Added `IMTL-L` (the loss-balancing variant of Impartial Multi-Task Learning) from [Towards
Impartial Multi-Task Learning](https://openreview.net/pdf?id=IMPnRXEWpvr) (ICLR 2021), a stateful
Expand Down
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/dwa.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

DWA
===

.. autoclass:: torchjd.scalarization.DWA
:members: __call__, step, reset
1 change: 1 addition & 0 deletions docs/source/docs/scalarization/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Abstract base class
:maxdepth: 1

constant.rst
dwa.rst
geometric_mean.rst
imtl_l.rst
mean.rst
Expand Down
2 changes: 2 additions & 0 deletions src/torchjd/scalarization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""

from ._constant import Constant
from ._dwa import DWA
from ._geometric_mean import GeometricMean
from ._imtl_l import IMTLL
from ._mean import Mean
Expand All @@ -31,6 +32,7 @@

__all__ = [
"Constant",
"DWA",
"GeometricMean",
"IMTLL",
"Mean",
Expand Down
137 changes: 137 additions & 0 deletions src/torchjd/scalarization/_dwa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import torch
from torch import Tensor
from torch.nn.functional import softmax

from torchjd._mixins import Stateful

from ._scalarizer_base import Scalarizer


class DWA(Scalarizer, Stateful):
r"""
:class:`~torchjd.Stateful`
:class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using Dynamic
Weight Average (DWA), proposed in `End-to-End Multi-Task Learning with Attention
<https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_End-To-End_Multi-Task_Learning_With_Attention_CVPR_2019_paper.pdf>`_.

DWA weights each value by how quickly its loss has been decreasing relative to the others. At
epoch :math:`t`, the current batch's values are combined as

.. math::
\sum_k \lambda_k(t)\, \ell_k, \qquad
\lambda_k(t) = \frac{K \exp(w_k(t-1) / T)}{\sum_i \exp(w_i(t-1) / T)}, \qquad
w_k(t-1) = \frac{L_k(t-1)}{L_k(t-2)}
Comment thread
ValerianRey marked this conversation as resolved.

where:

- :math:`\ell_k` is the :math:`k`-th value being scalarized (typically the current batch's loss
for task k);
- :math:`L_k(t)` is the :math:`k`-th value averaged over epoch :math:`t` (used only for the
weights);
- :math:`w_k(t-1)` is the relative descending rate: the ratio of average losses over the two
previous epochs;
- :math:`T` is the temperature; a larger :math:`T` makes the weights more uniform;
- :math:`K` is the number of values (e.g. the number of tasks); the factor :math:`K` keeps
:math:`\sum_k \lambda_k = K`.

The weights use only the two previous epochs' average losses, so they need no gradient. At each
call, the scalarization is returned and the current batch's losses are summed to the current
epoch's loss sums. :meth:`step` must then be called once at the end of each epoch to finalize
that epoch's average loss and roll the history forward. During the first two epochs, before two
averages are available, the weights are uniform.

:param temperature: The temperature :math:`T`. Must be strictly positive. Larger values make the
weights more uniform. The paper uses ``2.0``.

The following example shows how to train a model with DWA. The scalarizer is called on every
batch, and :meth:`step` is called once at the end of each epoch.

>>> import torch
>>> from torch.nn import Linear
>>>
>>> from torchjd.scalarization import DWA
>>>
>>> model = Linear(3, 2)
>>> scalarizer = DWA()
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
>>>
>>> for epoch in range(3):
... for _ in range(4): # Iterate over the batches of the epoch.
... features = torch.randn(8, 3)
... losses = model(features).pow(2).mean(dim=0) # One loss per output dimension.
... loss = scalarizer(losses)
... optimizer.zero_grad()
... loss.backward()
... optimizer.step()
... scalarizer.step() # Roll the epoch history once, at the end of the epoch.

.. note::
DWA weights each value by the ratio of its losses over consecutive epochs, which the paper
defines as a descending rate in the range :math:`(0, +\infty)`. The losses are therefore
expected to keep a consistent, nonzero sign across epochs (they need not be positive).
"""

def __init__(self, temperature: float = 2.0) -> None:
if temperature <= 0.0:
raise ValueError(
f"Parameter `temperature` should be strictly positive. Found `temperature = "
f"{temperature}`."
)

super().__init__()
self.temperature = temperature
self._loss_sum: Tensor | None = None
self._n_batches: int = 0
self._previous_averages: list[Tensor] = []

def forward(self, values: Tensor, /) -> Tensor:
weights = self._compute_weights(values)

detached = values.detach()
if self._loss_sum is None:
self._loss_sum = detached.clone()
elif self._loss_sum.shape != detached.shape:
raise ValueError(
f"The shape of `values` changed from {tuple(self._loss_sum.shape)} to "
f"{tuple(detached.shape)} within an epoch. Call `reset()` before changing it."
)
else:
self._loss_sum = self._loss_sum + detached
self._n_batches += 1

return (weights * values).sum()

def step(self) -> None:
"""
Finalizes the current epoch's average loss and rolls the history forward, discarding the
average from two epochs ago. Should be called once at the end of each epoch.
"""

if self._loss_sum is None:
return
average = self._loss_sum / self._n_batches
self._previous_averages = [*self._previous_averages, average][-2:]
self._loss_sum = None
self._n_batches = 0

def reset(self) -> None:
self._loss_sum = None
self._n_batches = 0
self._previous_averages = []

def _compute_weights(self, values: Tensor) -> Tensor:
if len(self._previous_averages) < 2:
return torch.ones_like(values)
older = self._previous_averages[0]
newer = self._previous_averages[1]
if older.shape != values.shape:
raise ValueError(
f"The shape of `values` changed from {tuple(older.shape)} to "
f"{tuple(values.shape)}. Call `reset()` before changing it."
)
rates = (newer / older).flatten()
weights = softmax(rates / self.temperature, dim=0)
return values.numel() * weights.reshape(values.shape)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(temperature={self.temperature})"
153 changes: 153 additions & 0 deletions tests/unit/scalarization/test_dwa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch
from pytest import mark, raises
from torch import Tensor
from utils.tensors import ones_, tensor_

from torchjd.scalarization import DWA

from ._asserts import assert_grad_flow, assert_returns_scalar
from ._inputs import all_inputs


def test_uniform_weights_for_first_two_epochs() -> None:
dwa = DWA(temperature=2.0)
# Epoch 1: no completed epoch yet, so weights are uniform (sum).
torch.testing.assert_close(dwa(tensor_([1.0, 3.0])), tensor_(4.0))
dwa.step()
# Epoch 2: only one completed epoch, so weights are still uniform (sum).
torch.testing.assert_close(dwa(tensor_([2.0, 5.0])), tensor_(7.0))
dwa.step()


def test_weights_from_previous_two_epochs() -> None:
dwa = DWA(temperature=2.0)
dwa(tensor_([1.0, 1.0]))
dwa.step() # Epoch 1 average = [1, 1].
dwa(tensor_([1.0, 4.0]))
dwa.step() # Epoch 2 average = [1, 4].
# Epoch 3: rates = [1, 4] / [1, 1] = [1, 4].
losses = tensor_([3.0, 5.0])
result = dwa(losses)
expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0)
torch.testing.assert_close(result, (expected_weights * losses).sum())


def test_uses_per_epoch_average() -> None:
# The weights use the average loss over each epoch's batches, not just the last batch.
dwa = DWA(temperature=2.0)
dwa(tensor_([2.0, 2.0]))
dwa(tensor_([0.0, 0.0]))
dwa.step() # Epoch 1 average = [1, 1].
dwa(tensor_([2.0, 6.0]))
dwa(tensor_([0.0, 2.0]))
dwa.step() # Epoch 2 average = [1, 4].
losses = tensor_([3.0, 5.0])
result = dwa(losses)
expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0)
torch.testing.assert_close(result, (expected_weights * losses).sum())


def test_step_discards_oldest_epoch() -> None:
dwa = DWA(temperature=2.0)
dwa(tensor_([9.0, 9.0]))
dwa.step() # Epoch 1 average = [9, 9]; should be discarded after epoch 3.
dwa(tensor_([1.0, 1.0]))
dwa.step() # Epoch 2 average = [1, 1].
dwa(tensor_([1.0, 4.0]))
dwa.step() # Epoch 3 average = [1, 4].
# Epoch 4 uses only epochs 2 and 3: rates = [1, 4] / [1, 1] = [1, 4].
losses = tensor_([3.0, 5.0])
result = dwa(losses)
expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0)
torch.testing.assert_close(result, (expected_weights * losses).sum())


def test_weights_sum_to_numel() -> None:
dwa = DWA()
dwa(tensor_([1.0, 2.0]))
dwa.step()
dwa(tensor_([2.0, 1.0]))
dwa.step()
# The weights sum to the number of elements, so weighting a vector of ones gives that count.
torch.testing.assert_close(dwa(ones_((2,))), tensor_(2.0))


@mark.parametrize("values", all_inputs)
def test_expected_structure(values: Tensor) -> None:
assert_returns_scalar(DWA(), values)


@mark.parametrize("values", all_inputs)
def test_grad_flow(values: Tensor) -> None:
assert_grad_flow(DWA(), values)


def test_grad_flows_with_computed_weights() -> None:
# After two epochs the weights are computed from the (detached) loss history; gradients must
# still flow to the current values.
dwa = DWA(temperature=2.0)
dwa(tensor_([1.0, 1.0]))
dwa.step()
dwa(tensor_([1.0, 4.0]))
dwa.step()
assert_grad_flow(dwa, tensor_([3.0, 5.0]))


def test_reset() -> None:
dwa = DWA()
dwa(tensor_([1.0, 2.0]))
dwa.step()
dwa(tensor_([3.0, 4.0]))
dwa.reset()
assert dwa._previous_averages == []
assert dwa._loss_sum is None
assert dwa._n_batches == 0


def test_step_without_forward_is_noop() -> None:
dwa = DWA()
dwa.step() # No losses accumulated yet.
assert dwa._previous_averages == []


def test_supports_consistently_negative_losses() -> None:
# DWA works on negative losses too, as long as each value keeps a consistent sign: the ratio of
# same-sign losses is positive, so the weights match those of the equivalent positive case.
dwa = DWA(temperature=2.0)
dwa(tensor_([-2.0, -2.0]))
dwa.step() # Epoch 1 average = [-2, -2].
dwa(tensor_([-2.0, -8.0]))
dwa.step() # Epoch 2 average = [-2, -8]; rates = [-2, -8] / [-2, -2] = [1, 4].
losses = tensor_([3.0, 5.0])
result = dwa(losses)
expected_weights = 2.0 * torch.softmax(tensor_([1.0, 4.0]) / 2.0, dim=0)
torch.testing.assert_close(result, (expected_weights * losses).sum())


def test_raises_on_shape_change_within_epoch() -> None:
dwa = DWA()
dwa(tensor_([1.0, 2.0]))
with raises(ValueError):
dwa(tensor_([1.0, 2.0, 3.0]))


def test_raises_on_shape_change_between_epochs() -> None:
dwa = DWA()
dwa(tensor_([1.0, 2.0]))
dwa.step()
dwa(tensor_([2.0, 1.0]))
dwa.step()
with raises(ValueError):
dwa(tensor_([1.0, 2.0, 3.0]))


@mark.parametrize("temperature", [0.0, -1.0])
def test_raises_on_non_positive_temperature(temperature: float) -> None:
with raises(ValueError):
DWA(temperature=temperature)


def test_representations() -> None:
assert repr(DWA()) == "DWA(temperature=2.0)"
assert repr(DWA(temperature=1.5)) == "DWA(temperature=1.5)"
assert str(DWA()) == "DWA"
Loading