diff --git a/CHANGELOG.md b/CHANGELOG.md index 40ab0a42..2e259ea0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `IMTL-L` (the loss-balancing variant of Impartial Multi-Task Learning) from [Towards + Impartial Multi-Task Learning](https://openreview.net/pdf?id=IMPnRXEWpvr) (ICLR 2021), a stateful + `Scalarizer` that learns a per-task scale `s_i` and combines the values as + `Σ (exp(s_i) · L_i − s_i)`. - Added `UW` (Uncertainty Weighting) from [Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics](https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf), diff --git a/docs/source/docs/scalarization/imtl_l.rst b/docs/source/docs/scalarization/imtl_l.rst new file mode 100644 index 00000000..c95dca27 --- /dev/null +++ b/docs/source/docs/scalarization/imtl_l.rst @@ -0,0 +1,7 @@ +:hide-toc: + +IMTL-L +====== + +.. 
autoclass:: torchjd.scalarization.IMTLL + :members: __call__, reset diff --git a/docs/source/docs/scalarization/index.rst b/docs/source/docs/scalarization/index.rst index e1d358af..f2d42be7 100644 --- a/docs/source/docs/scalarization/index.rst +++ b/docs/source/docs/scalarization/index.rst @@ -16,6 +16,7 @@ Abstract base class constant.rst geometric_mean.rst + imtl_l.rst mean.rst random.rst stch.rst diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/scalarization/__init__.py index a70d07bf..6b37d281 100644 --- a/src/torchjd/scalarization/__init__.py +++ b/src/torchjd/scalarization/__init__.py @@ -21,6 +21,7 @@ from ._constant import Constant from ._geometric_mean import GeometricMean +from ._imtl_l import IMTLL from ._mean import Mean from ._random import Random from ._scalarizer_base import Scalarizer @@ -31,6 +32,7 @@ __all__ = [ "Constant", "GeometricMean", + "IMTLL", "Mean", "Random", "Scalarizer", diff --git a/src/torchjd/scalarization/_imtl_l.py b/src/torchjd/scalarization/_imtl_l.py new file mode 100644 index 00000000..91445f77 --- /dev/null +++ b/src/torchjd/scalarization/_imtl_l.py @@ -0,0 +1,91 @@ +from collections.abc import Sequence + +import torch +from torch import Tensor, nn + +from torchjd._mixins import Stateful + +from ._scalarizer_base import Scalarizer + + +class IMTLL(Scalarizer, Stateful): + r""" + :class:`~torchjd.Stateful` + :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using learned + per-task scales. ``IMTL-L`` is the loss-balancing variant of Impartial + Multi-Task Learning, proposed in `Towards Impartial Multi-Task Learning + <https://openreview.net/pdf?id=IMPnRXEWpvr>`_. + + Each value :math:`L_i` is assigned a learnable scale parameter :math:`s_i`, and the values are + combined as + + .. math:: + \sum_i \left( e^{s_i} L_i - s_i \right) + + where: + + - :math:`L_i` is the :math:`i`-th value (typically the loss of task :math:`i`); + - :math:`s_i` is the learnable scale parameter of task :math:`i`. 
+ + The factor :math:`e^{s_i}` rescales each loss so that the scaled losses stay at a comparable + magnitude across tasks, while the :math:`- s_i` term is a regularizer that prevents the trivial + solution :math:`s_i \to -\infty`. The :math:`s_i` are stored as an ``nn.Parameter``, so the + parameters of this scalarizer must be passed to the optimizer to be learned jointly with the + model. + + Although it is derived without any distribution assumption (unlike + :class:`~torchjd.scalarization.UW`, which is derived from Gaussian/Laplace likelihoods), IMTL-L + is in fact almost equivalent to :class:`~torchjd.scalarization.UW`: this scalarization equals + :math:`2\,\mathrm{UW}` evaluated at the negated parameter, so the two differ only by a constant + factor of two and the sign convention of the learned parameter, and share the same per-task + weighting and the same optima. + + The complementary gradient-balancing variant (IMTL-G) is provided as the + :class:`~torchjd.aggregation.IMTLG` aggregator. + + :param shape: The shape of the values to scalarize, used to create one scale per value. An + ``int`` ``n`` is interpreted as the shape ``(n,)``. + + The following example shows how to train a model with Impartial Multi-Task Learning (loss + balance), as described in the paper. + + >>> import torch + >>> from torch.nn import Linear + >>> + >>> from torchjd.scalarization import IMTLL + >>> + >>> model = Linear(3, 2) + >>> scalarizer = IMTLL(2) # Move to the right device with e.g. IMTLL(2).to(device="cuda") + >>> optimizer = torch.optim.SGD([*model.parameters(), *scalarizer.parameters()], lr=0.1) + >>> + >>> features = torch.randn(8, 3) + >>> # Compute some dummy losses just for the sake of the example + >>> losses = model(features).pow(2).mean(dim=0) # One loss per output dimension. + >>> loss = scalarizer(losses) + >>> loss.backward() + >>> optimizer.step() + + .. 
note:: + The scales are initialized to ``0``, so at the start of training the scalarization reduces to + the plain sum of the values (since :math:`e^0 = 1`). Following the paper, IMTL-L is designed + to balance positive losses. + """ + + def __init__(self, shape: int | Sequence[int]) -> None: + super().__init__() + self.log_scale = nn.Parameter(torch.zeros(shape)) + + def forward(self, values: Tensor, /) -> Tensor: + if values.shape != self.log_scale.shape: + raise ValueError( + f"Parameter `values` should have shape {tuple(self.log_scale.shape)} (matching the " + f"shape of the scales). Found `values.shape = {tuple(values.shape)}`.", + ) + return (torch.exp(self.log_scale) * values - self.log_scale).sum() + + def reset(self) -> None: + with torch.no_grad(): + self.log_scale.zero_() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(shape={tuple(self.log_scale.shape)})" diff --git a/tests/unit/scalarization/test_imtl_l.py b/tests/unit/scalarization/test_imtl_l.py new file mode 100644 index 00000000..cdb6790d --- /dev/null +++ b/tests/unit/scalarization/test_imtl_l.py @@ -0,0 +1,113 @@ +from contextlib import nullcontext as does_not_raise + +import torch +from pytest import mark, raises +from settings import DEVICE, DTYPE +from torch import Tensor +from utils.contexts import ExceptionContext +from utils.tensors import ones_, tensor_, zeros_ + +from torchjd.scalarization import IMTLL, UW + +from ._asserts import assert_grad_flow, assert_returns_scalar +from ._inputs import all_inputs + + +def _imtl_l(shape: int | tuple[int, ...]) -> IMTLL: + """Builds an `IMTLL` whose scales live on the test device and dtype.""" + return IMTLL(shape).to(device=DEVICE, dtype=DTYPE) + + +def test_value() -> None: + # With scales initialized to 0, exp(0)=1 and -0=0, so the result is sum(values). 
+ values = tensor_([1.0, 2.0, 4.0]) + torch.testing.assert_close(_imtl_l((3,))(values), tensor_(7.0)) + + +def test_int_shape_matches_tuple_shape() -> None: + values = tensor_([1.0, 2.0, 4.0]) + assert IMTLL(3).log_scale.shape == (3,) + torch.testing.assert_close(_imtl_l(3)(values), _imtl_l((3,))(values)) + + +@mark.parametrize("values", all_inputs) +def test_expected_structure(values: Tensor) -> None: + assert_returns_scalar(_imtl_l(tuple(values.shape)), values) + + +@mark.parametrize("values", all_inputs) +def test_grad_flow(values: Tensor) -> None: + assert_grad_flow(_imtl_l(tuple(values.shape)), values) + + +@mark.parametrize("values", all_inputs) +def test_grad_flows_to_log_scale(values: Tensor) -> None: + scalarizer = _imtl_l(tuple(values.shape)) + scalarizer(values).backward() + assert scalarizer.log_scale.grad is not None + assert scalarizer.log_scale.grad.isfinite().all() + + +@mark.parametrize( + ["param_shape", "values_shape", "expectation"], + [ + ((5,), (5,), does_not_raise()), + ((3, 4), (3, 4), does_not_raise()), + ((), (), does_not_raise()), + ((5,), (4,), raises(ValueError)), + ((5,), (5, 1), raises(ValueError)), + ((3, 4), (4, 3), raises(ValueError)), + ], +) +def test_shape_check( + param_shape: tuple[int, ...], + values_shape: tuple[int, ...], + expectation: ExceptionContext, +) -> None: + scalarizer = _imtl_l(param_shape) + values = ones_(values_shape) + with expectation: + _ = scalarizer(values) + + +def test_reset_restores_initial_log_scale() -> None: + scalarizer = _imtl_l((3,)) + with torch.no_grad(): + scalarizer.log_scale.add_(1.0) + scalarizer.reset() + torch.testing.assert_close(scalarizer.log_scale.detach(), zeros_((3,))) + + +def test_does_not_raise_on_negative_input() -> None: + # IMTL-L is designed for positive losses but does not enforce a positivity precondition. 
+ values = tensor_([-1.0, -2.0, 3.0]) + assert_returns_scalar(_imtl_l((3,)), values) + + +def test_is_trainable() -> None: + scalarizer = _imtl_l((2,)) + optimizer = torch.optim.SGD(scalarizer.parameters(), lr=0.1) + values = tensor_([2.0, 5.0]) + optimizer.zero_grad() + scalarizer(values).backward() + optimizer.step() + assert not torch.equal(scalarizer.log_scale.detach(), zeros_((2,))) + + +def test_equivalent_to_uw_up_to_factor_and_sign() -> None: + # Locks the documented relationship: IMTL-L(s) == 2 * UW(-s), i.e. the two scalarizations are + # equal up to a constant factor of 2 and the sign of the learned parameter. + values = tensor_([0.5, 2.0, 4.0]) + imtl_l = _imtl_l((3,)) + uw = UW((3,)).to(device=DEVICE, dtype=DTYPE) + with torch.no_grad(): + s = tensor_([0.3, -0.7, 1.2]) + imtl_l.log_scale.copy_(s) + uw.log_var.copy_(-s) + torch.testing.assert_close(imtl_l(values), 2.0 * uw(values)) + + +def test_representations() -> None: + assert repr(IMTLL(3)) == "IMTLL(shape=(3,))" + assert repr(IMTLL((2, 3))) == "IMTLL(shape=(2, 3))" + assert str(IMTLL(3)) == "IMTLL"