Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `IMTL-L` (the loss-balancing variant of Impartial Multi-Task Learning) from [Towards
Impartial Multi-Task Learning](https://openreview.net/pdf?id=IMPnRXEWpvr) (ICLR 2021), a stateful
`Scalarizer` that learns a per-task scale `s_i` and combines the values as
`Σ (exp(s_i) · L_i − s_i)`.
- Added `UW` (Uncertainty Weighting) from [Multi-Task Learning Using Uncertainty to Weigh Losses
for Scene Geometry and
Semantics](https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf),
Expand Down
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/imtl_l.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

IMTL-L
======

.. autoclass:: torchjd.scalarization.IMTLL
    :members: __call__, reset
1 change: 1 addition & 0 deletions docs/source/docs/scalarization/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Abstract base class

constant.rst
geometric_mean.rst
imtl_l.rst
mean.rst
random.rst
stch.rst
Expand Down
2 changes: 2 additions & 0 deletions src/torchjd/scalarization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from ._constant import Constant
from ._geometric_mean import GeometricMean
from ._imtl_l import IMTLL
from ._mean import Mean
from ._random import Random
from ._scalarizer_base import Scalarizer
Expand All @@ -31,6 +32,7 @@
__all__ = [
"Constant",
"GeometricMean",
"IMTLL",
"Mean",
"Random",
"Scalarizer",
Expand Down
91 changes: 91 additions & 0 deletions src/torchjd/scalarization/_imtl_l.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from collections.abc import Sequence

import torch
from torch import Tensor, nn

from torchjd._mixins import Stateful

from ._scalarizer_base import Scalarizer


class IMTLL(Scalarizer, Stateful):
    r"""
    :class:`~torchjd.Stateful` :class:`~torchjd.scalarization.Scalarizer` implementing ``IMTL-L``,
    the loss-balancing variant of Impartial Multi-Task Learning introduced in `Towards Impartial
    Multi-Task Learning <https://openreview.net/pdf?id=IMPnRXEWpvr>`_. It reduces the input tensor
    of values to a single scalar by weighting every value with its own learned scale.

    One learnable scale parameter :math:`s_i` is attached to each value :math:`L_i`, and the
    scalarization is

    .. math::
        \sum_i \left( e^{s_i} L_i - s_i \right)

    where:

    - :math:`L_i` is the :math:`i`-th value (typically the loss of task :math:`i`);
    - :math:`s_i` is the learnable scale parameter of task :math:`i`.

    Multiplying by :math:`e^{s_i}` rescales each loss so that the scaled losses remain of
    comparable magnitude across tasks. The :math:`- s_i` term acts as a regularizer: without it,
    the trivial solution :math:`s_i \to -\infty` would be reached. Because the :math:`s_i` live in
    an ``nn.Parameter``, the parameters of this scalarizer have to be handed to the optimizer so
    that they are learned together with the model.

    IMTL-L is derived without any distribution assumption, whereas
    :class:`~torchjd.scalarization.UW` comes from Gaussian/Laplace likelihoods. The two are
    nevertheless almost the same: this scalarization is equal to :math:`2\,\mathrm{UW}` evaluated
    at the negated parameter. They therefore only differ by a constant factor of two and by the
    sign convention of the learned parameter, and they share the same per-task weighting and the
    same optima.

    The complementary gradient-balancing variant (IMTL-G) is available as the
    :class:`~torchjd.aggregation.IMTLG` aggregator.

    :param shape: The shape of the values to scalarize. One scale is created per value. An ``int``
        ``n`` is interpreted as the shape ``(n,)``.

    The example below trains a model with Impartial Multi-Task Learning (loss balance), as
    described in the paper.

        >>> import torch
        >>> from torch.nn import Linear
        >>>
        >>> from torchjd.scalarization import IMTLL
        >>>
        >>> model = Linear(3, 2)
        >>> scalarizer = IMTLL(2)  # Move to the right device with e.g. IMTLL(2).to(device="cuda")
        >>> optimizer = torch.optim.SGD([*model.parameters(), *scalarizer.parameters()], lr=0.1)
        >>>
        >>> features = torch.randn(8, 3)
        >>> # Compute some dummy losses just for the sake of the example
        >>> losses = model(features).pow(2).mean(dim=0)  # One loss per output dimension.
        >>> loss = scalarizer(losses)
        >>> loss.backward()
        >>> optimizer.step()

    .. note::
        All scales start at ``0``. Since :math:`e^0 = 1`, the scalarization is therefore the plain
        sum of the values at the beginning of training. Following the paper, IMTL-L is designed to
        balance positive losses.
    """

    def __init__(self, shape: int | Sequence[int]) -> None:
        super().__init__()
        # One scale per value, all starting at zero.
        initial_scales = torch.zeros(shape)
        self.log_scale = nn.Parameter(initial_scales)

    def forward(self, values: Tensor, /) -> Tensor:
        expected_shape = tuple(self.log_scale.shape)
        actual_shape = tuple(values.shape)
        # The shapes must match exactly: broadcasting would silently pair values and scales wrongly.
        if actual_shape != expected_shape:
            raise ValueError(
                f"Parameter `values` should have shape {expected_shape} (matching the "
                f"shape of the scales). Found `values.shape = {actual_shape}`.",
            )
        weighted_values = self.log_scale.exp() * values
        return (weighted_values - self.log_scale).sum()

    def reset(self) -> None:
        # Zeroes the parameter in place without recording the operation in autograd.
        nn.init.zeros_(self.log_scale)

    def __repr__(self) -> str:
        shape = tuple(self.log_scale.shape)
        return f"{type(self).__name__}(shape={shape})"
113 changes: 113 additions & 0 deletions tests/unit/scalarization/test_imtl_l.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from contextlib import nullcontext as does_not_raise

import torch
from pytest import mark, raises
from settings import DEVICE, DTYPE
from torch import Tensor
from utils.contexts import ExceptionContext
from utils.tensors import ones_, tensor_, zeros_

from torchjd.scalarization import IMTLL, UW

from ._asserts import assert_grad_flow, assert_returns_scalar
from ._inputs import all_inputs


def _imtl_l(shape: int | tuple[int, ...]) -> IMTLL:
    """Creates an `IMTLL` for `shape` and moves it to the test device and dtype."""
    scalarizer = IMTLL(shape)
    return scalarizer.to(device=DEVICE, dtype=DTYPE)


def test_value() -> None:
    """Checks the output of a freshly built `IMTLL` on a known input."""
    # Every scale starts at 0, so each term is exp(0) * v - 0 = v and the output is sum(values).
    scalarizer = _imtl_l((3,))
    result = scalarizer(tensor_([1.0, 2.0, 4.0]))
    expected = tensor_(7.0)
    torch.testing.assert_close(result, expected)


def test_int_shape_matches_tuple_shape() -> None:
    """Checks that the `int` shape `3` behaves exactly like the tuple shape `(3,)`."""
    values = tensor_([1.0, 2.0, 4.0])
    assert IMTLL(3).log_scale.shape == (3,)
    from_int_shape = _imtl_l(3)(values)
    from_tuple_shape = _imtl_l((3,))(values)
    torch.testing.assert_close(from_int_shape, from_tuple_shape)


@mark.parametrize("values", all_inputs)
def test_expected_structure(values: Tensor) -> None:
    """Runs `assert_returns_scalar` on an `IMTLL` whose scales have the same shape as `values`."""
    scalarizer = _imtl_l(tuple(values.shape))
    assert_returns_scalar(scalarizer, values)


@mark.parametrize("values", all_inputs)
def test_grad_flow(values: Tensor) -> None:
    """Runs `assert_grad_flow` on an `IMTLL` whose scales have the same shape as `values`."""
    scalarizer = _imtl_l(tuple(values.shape))
    assert_grad_flow(scalarizer, values)


@mark.parametrize("values", all_inputs)
def test_grad_flows_to_log_scale(values: Tensor) -> None:
    """Checks that back-propagating the output gives `log_scale` a finite gradient."""
    scalarizer = _imtl_l(tuple(values.shape))
    output = scalarizer(values)
    output.backward()
    log_scale_grad = scalarizer.log_scale.grad
    assert log_scale_grad is not None
    assert log_scale_grad.isfinite().all()


@mark.parametrize(
    ["param_shape", "values_shape", "expectation"],
    [
        # Matching shapes are accepted, including the 0-dimensional case.
        ((5,), (5,), does_not_raise()),
        ((3, 4), (3, 4), does_not_raise()),
        ((), (), does_not_raise()),
        # Any mismatch is rejected, even when the number of elements is the same.
        ((5,), (4,), raises(ValueError)),
        ((5,), (5, 1), raises(ValueError)),
        ((3, 4), (4, 3), raises(ValueError)),
    ],
)
def test_shape_check(
    param_shape: tuple[int, ...],
    values_shape: tuple[int, ...],
    expectation: ExceptionContext,
) -> None:
    """Checks that calling `IMTLL` raises `ValueError` iff the shapes of values and scales differ."""
    values = ones_(values_shape)
    scalarizer = _imtl_l(param_shape)
    with expectation:
        _ = scalarizer(values)


def test_reset_restores_initial_log_scale() -> None:
    """Checks that `reset` brings a modified `log_scale` back to all zeros."""
    shape = (3,)
    scalarizer = _imtl_l(shape)
    # Move the scales away from their initial value so that the reset has something to undo.
    with torch.no_grad():
        scalarizer.log_scale.add_(1.0)
    scalarizer.reset()
    restored = scalarizer.log_scale.detach()
    torch.testing.assert_close(restored, zeros_(shape))


def test_does_not_raise_on_negative_input() -> None:
    """Checks that values containing negative entries are accepted."""
    # IMTL-L targets positive losses, but positivity is not enforced as a precondition.
    scalarizer = _imtl_l((3,))
    negative_and_positive = tensor_([-1.0, -2.0, 3.0])
    assert_returns_scalar(scalarizer, negative_and_positive)


def test_is_trainable() -> None:
    """Checks that one SGD step on the scalarized value moves `log_scale` away from zero."""
    scalarizer = _imtl_l((2,))
    optimizer = torch.optim.SGD(scalarizer.parameters(), lr=0.1)
    optimizer.zero_grad()
    output = scalarizer(tensor_([2.0, 5.0]))
    output.backward()
    optimizer.step()
    initial_log_scale = zeros_((2,))
    assert not torch.equal(scalarizer.log_scale.detach(), initial_log_scale)


def test_equivalent_to_uw_up_to_factor_and_sign() -> None:
    """Checks the documented identity `IMTL-L(s) == 2 * UW(-s)`."""
    # This pins the relationship stated in the documentation: both scalarizations agree up to a
    # constant factor of 2 and the sign of the learned parameter.
    shape = (3,)
    imtl_l = _imtl_l(shape)
    uw = UW(shape).to(device=DEVICE, dtype=DTYPE)
    with torch.no_grad():
        s = tensor_([0.3, -0.7, 1.2])
        imtl_l.log_scale.copy_(s)
        uw.log_var.copy_(-s)
    values = tensor_([0.5, 2.0, 4.0])
    imtl_l_output = imtl_l(values)
    doubled_uw_output = 2.0 * uw(values)
    torch.testing.assert_close(imtl_l_output, doubled_uw_output)


def test_representations() -> None:
    """Checks the `repr` and `str` outputs of `IMTLL`."""
    expected_reprs = {
        3: "IMTLL(shape=(3,))",
        (2, 3): "IMTLL(shape=(2, 3))",
    }
    for shape, expected in expected_reprs.items():
        assert repr(IMTLL(shape)) == expected
    assert str(IMTLL(3)) == "IMTLL"
Loading