diff --git a/CHANGELOG.md b/CHANGELOG.md index 389a403e..f2daafd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `SDMGradWeighting` from [Direction-oriented Multi-objective Learning: Simple and Provable Stochastic Algorithms](https://arxiv.org/pdf/2305.18409) (NeurIPS 2023). It is a stateful `Weighting` that solves for task weights via a simplex-projected inner loop on a cross-batch matrix `A = J_1 @ J_2.T` (computed from two independent mini-batches using `autojac.jac`), with a direction-oriented regularizer pulling the descent direction toward a preference direction. - Added `IMTL-L` (the loss-balancing variant of Impartial Multi-Task Learning) from [Towards Impartial Multi-Task Learning](https://openreview.net/pdf?id=IMPnRXEWpvr) (ICLR 2021), a stateful `Scalarizer` that learns a per-task scale `s_i` and combines the values as diff --git a/NOTICES b/NOTICES index 07c3e851..b7dbf617 100644 --- a/NOTICES +++ b/NOTICES @@ -140,3 +140,31 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +------------------------------------------------------------------------------- + +Project: SDMGrad +Source: https://github.com/OptMN-Lab/SDMGrad/blob/main/methods/weight_methods.py +Used in: src/torchjd/aggregation/_sdmgrad.py + +MIT License + +Copyright (c) 2023 ml-opt-lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 4a5060e6..867e8d8c 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -41,5 +41,6 @@ Abstract base classes nash_mtl.rst pcgrad.rst random.rst + sdmgrad.rst sum.rst trimmed_mean.rst diff --git a/docs/source/docs/aggregation/sdmgrad.rst b/docs/source/docs/aggregation/sdmgrad.rst new file mode 100644 index 00000000..272c953d --- /dev/null +++ b/docs/source/docs/aggregation/sdmgrad.rst @@ -0,0 +1,7 @@ +:hide-toc: + +SDMGrad +======= + +.. 
autoclass:: torchjd.aggregation.SDMGradWeighting + :members: __call__, reset diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 2b77ae32..5285e6bf 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -56,6 +56,7 @@ from ._nash_mtl import NashMTL from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting +from ._sdmgrad import SDMGradWeighting from ._sum import Sum, SumWeighting from ._trimmed_mean import TrimmedMean from ._upgrad import UPGrad, UPGradWeighting @@ -93,6 +94,7 @@ "PCGradWeighting", "Random", "RandomWeighting", + "SDMGradWeighting", "Sum", "SumWeighting", "TrimmedMean", diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py index 4856430a..a7b6fc0e 100644 --- a/src/torchjd/aggregation/_modo.py +++ b/src/torchjd/aggregation/_modo.py @@ -11,6 +11,7 @@ from torchjd.aggregation._mixins import _NonDifferentiable from torchjd.linalg import Matrix +from ._utils.simplex import _projection2simplex from ._weighting_bases import _MatrixWeighting @@ -166,27 +167,11 @@ def forward(self, matrix: Matrix, /) -> Tensor: lambd = cast(Tensor, self._lambda) grad = matrix @ lambd + self._rho * lambd - lambd = self._projection2simplex(lambd - self._gamma * grad) + lambd = _projection2simplex(lambd - self._gamma * grad) self._lambda = lambd return lambd - @staticmethod - def _projection2simplex(y: Tensor) -> Tensor: - """Euclidean projection of ``y`` onto the probability simplex.""" - - m = len(y) - sorted_y = torch.sort(y, descending=True)[0] - tmpsum = y.new_zeros(()) - tmax_f = (torch.sum(y) - 1.0) / m - for i in range(m - 1): - tmpsum = tmpsum + sorted_y[i] - tmax = (tmpsum - 1.0) / (i + 1.0) - if tmax > sorted_y[i + 1]: - tmax_f = tmax - break - return torch.max(y - tmax_f, y.new_zeros(m)) - def _ensure_state(self, matrix: Matrix) -> None: key = (matrix.shape[0], matrix.dtype, matrix.device) if self._state_key == key and 
self._lambda is not None: diff --git a/src/torchjd/aggregation/_sdmgrad.py b/src/torchjd/aggregation/_sdmgrad.py new file mode 100644 index 00000000..e4eb5ab4 --- /dev/null +++ b/src/torchjd/aggregation/_sdmgrad.py @@ -0,0 +1,214 @@ +# Partly adapted from https://github.com/OptMN-Lab/SDMGrad — MIT License, Copyright (c) 2023 ml-opt-lab. +# See NOTICES for the full license text. +from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor + +from torchjd._mixins import Stateful +from torchjd.aggregation._mixins import _NonDifferentiable +from torchjd.linalg import Matrix + +from ._utils.simplex import _projection2simplex +from ._weighting_bases import _MatrixWeighting + + +class SDMGradWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): + r""" + :class:`~torchjd.Stateful` + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Direction-oriented + Multi-objective Learning: Simple and Provable Stochastic Algorithms + <https://arxiv.org/pdf/2305.18409>`_ (NeurIPS 2023). + + .. warning:: + The input matrix must be :math:`A = J_1 J_2^\top`, computed from two **independent** + mini-batches via :func:`torchjd.autojac.jac`. It is **not** a Gramian and is not symmetric + or positive semi-definite in general. See the usage example below. + + :param lr: Learning rate of the inner SGD that solves for the task weights. Must be positive. + :param momentum: Momentum of the inner SGD. Must be in :math:`[0, 1)`. + :param n_iter: Number of inner SGD iterations performed at each call. Must be positive. + :param lambda_: Coefficient controlling how strongly the descent direction is pulled + toward the preference direction. Must be non-negative. + :param pref_vector: The preference vector :math:`\tilde w` defining the target direction. If not + provided, defaults to the uniform vector :math:`[1/m, \ldots, 1/m]` (i.e. the target direction is the average gradient). + + .. 
note:: + The inner simplex-projected solver is adapted from the `official implementation + <https://github.com/OptMN-Lab/SDMGrad>`_. Note that the + official class default for ``lambda_`` is ``0.6``, overridden to ``0.3`` in their own + experiments, which is the value used here (and in `LibMTL + <https://github.com/median-research-group/LibMTL>`_). + + Before the inner solve, the input matrix is scale-normalized by the mean of the square + roots of its diagonal entries clamped at zero (following both the official implementation and + LibMTL). This makes the inner SGD learning rate scale-invariant with respect to gradient + magnitude. The normalization is briefly described in section 6.1 of the paper. + + .. admonition:: Example (three batches per step) + + The following example shows how to train with the SDMGrad algorithm. + + .. testcode:: + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import SDMGradWeighting + from torchjd.autojac import jac + + # Generate data (9 batches of 16 examples of dim 5) for the sake of the example. + inputs = torch.randn(9, 16, 5) + targets = torch.randn(9, 16) + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = SDMGradWeighting(lambda_=0.3) + params = list(model.parameters()) + + # Consume three consecutive (independent) batches per step. + for i in range(len(inputs) // 3): + # Batches corresponding to ξ, ξ' and ζ in the paper's algorithm. 
+ input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2] + target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2] + + losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + jacs_1 = jac(losses_1, params) + J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) + + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + jacs_2 = jac(losses_2, params) + J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) + + A = J_1 @ J_2.T + weights = weighting(A) + + losses_3 = criterion(model(input_3).squeeze(dim=1), target_3) + losses_3.backward(weights) + optimizer.step() + optimizer.zero_grad() + """ + + def __init__( + self, + lr: float = 10.0, + momentum: float = 0.5, + n_iter: int = 20, + lambda_: float = 0.3, + pref_vector: Tensor | None = None, + ) -> None: + super().__init__() + self.lr = lr + self.momentum = momentum + self.n_iter = n_iter + self.lambda_ = lambda_ + self.pref_vector = pref_vector + self._w: Tensor | None = None + self._state_key: tuple[int, torch.dtype, torch.device] | None = None + + @property + def lr(self) -> float: + return self._lr + + @lr.setter + def lr(self, value: float) -> None: + if value <= 0.0: + raise ValueError(f"Attribute `lr` must be positive. Found lr={value!r}.") + self._lr = value + + @property + def momentum(self) -> float: + return self._momentum + + @momentum.setter + def momentum(self, value: float) -> None: + if not (0.0 <= value < 1.0): + raise ValueError(f"Attribute `momentum` must be in [0, 1). Found momentum={value!r}.") + self._momentum = value + + @property + def n_iter(self) -> int: + return self._n_iter + + @n_iter.setter + def n_iter(self, value: int) -> None: + if value < 1: + raise ValueError( + f"Attribute `n_iter` must be a positive integer. Found n_iter={value!r}." 
+ ) + self._n_iter = value + + @property + def lambda_(self) -> float: + return self._lambda + + @lambda_.setter + def lambda_(self, value: float) -> None: + if value < 0.0: + raise ValueError(f"Attribute `lambda_` must be non-negative. Found lambda_={value!r}.") + self._lambda = value + + @property + def pref_vector(self) -> Tensor | None: + return self._pref_vector + + @pref_vector.setter + def pref_vector(self, value: Tensor | None) -> None: + if value is not None and value.ndim != 1: + raise ValueError( + "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = " + f"{value.ndim}`." + ) + self._pref_vector = value + + def reset(self) -> None: + """Clears the stored task weights so the next forward starts from uniform.""" + + self._w = None + self._state_key = None + + def forward(self, matrix: Matrix, /) -> Tensor: + self._ensure_state(matrix) + w = cast(Tensor, self._w) + w_tilde = self._resolve_w_tilde(matrix) + + diag = torch.diag(matrix).clamp(min=0.0) + scale = diag.sqrt().mean() + a = matrix / (scale.pow(2) + 1e-8) + + velocity: Tensor | None = None + for _ in range(self._n_iter): + grad = a @ (w + self._lambda * w_tilde) + velocity = grad if velocity is None else self._momentum * velocity + grad + w = _projection2simplex(w - self._lr * velocity) + + self._w = w + return (w + self._lambda * w_tilde) / (1.0 + self._lambda) + + def _resolve_w_tilde(self, matrix: Matrix) -> Tensor: + m = matrix.shape[0] + if self._pref_vector is None: + return matrix.new_full((m,), 1.0 / m) + if self._pref_vector.shape[0] != m: + raise ValueError( + "The length of `pref_vector` must match the number of rows of the input matrix. " + f"Found len(pref_vector)={self._pref_vector.shape[0]} and matrix.shape[0]={m}." 
+ ) + return self._pref_vector.to(dtype=matrix.dtype, device=matrix.device) + + def _ensure_state(self, matrix: Matrix) -> None: + key = (matrix.shape[0], matrix.dtype, matrix.device) + if self._state_key == key and self._w is not None: + return + self._w = matrix.new_full((matrix.shape[0],), 1.0 / matrix.shape[0]) + self._state_key = key + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(lr={self.lr!r}, momentum={self.momentum!r}, " + f"n_iter={self.n_iter!r}, lambda_={self.lambda_!r}, pref_vector={self.pref_vector!r})" + ) diff --git a/src/torchjd/aggregation/_utils/simplex.py b/src/torchjd/aggregation/_utils/simplex.py new file mode 100644 index 00000000..b146d307 --- /dev/null +++ b/src/torchjd/aggregation/_utils/simplex.py @@ -0,0 +1,18 @@ +import torch +from torch import Tensor + + +def _projection2simplex(y: Tensor) -> Tensor: + """Euclidean projection of ``y`` onto the probability simplex.""" + + m = len(y) + sorted_y = torch.sort(y, descending=True)[0] + tmpsum = y.new_zeros(()) + tmax_f = (torch.sum(y) - 1.0) / m + for i in range(m - 1): + tmpsum = tmpsum + sorted_y[i] + tmax = (tmpsum - 1.0) / (i + 1.0) + if tmax > sorted_y[i + 1]: + tmax_f = tmax + break + return torch.max(y - tmax_f, y.new_zeros(m)) diff --git a/tests/unit/aggregation/_utils/test_simplex.py b/tests/unit/aggregation/_utils/test_simplex.py new file mode 100644 index 00000000..7f3907f5 --- /dev/null +++ b/tests/unit/aggregation/_utils/test_simplex.py @@ -0,0 +1,19 @@ +from torch.testing import assert_close +from utils.tensors import tensor_ + +from torchjd.aggregation._utils.simplex import _projection2simplex + + +def test_projection2simplex_known_values() -> None: + """The simplex projection matches hand-computed Euclidean projections.""" + + # Already-positive input: the deficit (1 - sum) is spread equally, no clamping. 
+ assert_close( + _projection2simplex(tensor_([0.5, 0.1, 0.1])), + tensor_([0.6, 0.2, 0.2]), + ) + # Input with a negative entry: it gets clamped to zero. + assert_close( + _projection2simplex(tensor_([1.0, 0.0, -0.5])), + tensor_([1.0, 0.0, 0.0]), + ) diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py index f169ea86..cd3943d1 100644 --- a/tests/unit/aggregation/test_modo.py +++ b/tests/unit/aggregation/test_modo.py @@ -170,18 +170,3 @@ def test_non_symmetric_input() -> None: assert_close(W(G), expected) assert W(G).shape == (m,) assert (W(G) >= 0).all() - - -def test_projection2simplex_known_values() -> None: - """The simplex projection matches hand-computed Euclidean projections.""" - - # Already-positive input: the deficit (1 - sum) is spread equally, no clamping. - assert_close( - MoDoWeighting._projection2simplex(tensor_([0.5, 0.1, 0.1])), - tensor_([0.6, 0.2, 0.2]), - ) - # Input with a negative entry: it gets clamped to zero. - assert_close( - MoDoWeighting._projection2simplex(tensor_([1.0, 0.0, -0.5])), - tensor_([1.0, 0.0, 0.0]), - ) diff --git a/tests/unit/aggregation/test_sdmgrad.py b/tests/unit/aggregation/test_sdmgrad.py new file mode 100644 index 00000000..e846b455 --- /dev/null +++ b/tests/unit/aggregation/test_sdmgrad.py @@ -0,0 +1,172 @@ +import torch +from pytest import raises +from torch.testing import assert_close +from utils.tensors import eye_, randn_, tensor_ + +from torchjd.aggregation._sdmgrad import SDMGradWeighting + + +def test_representations() -> None: + W = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=20, lambda_=0.3) + assert ( + repr(W) + == "SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=20, lambda_=0.3, pref_vector=None)" + ) + + +def test_reset_restores_first_step_behavior() -> None: + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + A = J1 @ J2.T + W = SDMGradWeighting() + first = W(A) + W(A) + W.reset() + assert_close(first, W(A)) + + +def test_lr_setter_accepts_valid() -> None: + W = 
SDMGradWeighting() + W.lr = 0.5 + assert W.lr == 0.5 + + +def test_lr_setter_rejects_non_positive() -> None: + W = SDMGradWeighting() + with raises(ValueError, match="lr"): + W.lr = 0.0 + with raises(ValueError, match="lr"): + W.lr = -1.0 + + +def test_momentum_setter_accepts_valid() -> None: + W = SDMGradWeighting() + W.momentum = 0.0 + assert W.momentum == 0.0 + W.momentum = 0.9 + assert W.momentum == 0.9 + + +def test_momentum_setter_rejects_out_of_range() -> None: + W = SDMGradWeighting() + with raises(ValueError, match="momentum"): + W.momentum = -0.1 + with raises(ValueError, match="momentum"): + W.momentum = 1.0 + + +def test_n_iter_setter_rejects_non_positive() -> None: + W = SDMGradWeighting() + with raises(ValueError, match="n_iter"): + W.n_iter = 0 + + +def test_lambda_setter_accepts_valid() -> None: + W = SDMGradWeighting() + W.lambda_ = 0.0 + assert W.lambda_ == 0.0 + W.lambda_ = 0.6 + assert W.lambda_ == 0.6 + + +def test_lambda_setter_rejects_negative() -> None: + W = SDMGradWeighting() + with raises(ValueError, match="lambda_"): + W.lambda_ = -0.1 + + +def test_pref_vector_setter_rejects_non_1d() -> None: + W = SDMGradWeighting() + with raises(ValueError, match="pref_vector"): + W.pref_vector = tensor_([[0.5, 0.5], [0.5, 0.5]]) + + +def test_output_lies_on_simplex() -> None: + """The simplex projection and (1 + lambda_) normalization keep the weights on the simplex.""" + + J1 = randn_((4, 10)) + J2 = randn_((4, 10)) + A = J1 @ J2.T + W = SDMGradWeighting(lambda_=0.3) + weights = W(A) + assert weights.shape == (4,) + assert (weights >= 0).all() + assert_close(weights.sum(), tensor_(1.0)) + + +def test_update_recurrence() -> None: + """One inner-solve step matches the manually-computed expected output. 
+ + With A = diag(1, 0.25, 0), n_iter=1, lr=10, lambda_=0.3, starting from uniform w=[1/3,1/3,1/3]: + scale = mean([1, 0.5, 0]) = 0.5 => A_norm = diag(4, 1, 0) + grad = A_norm @ 1.3*[1/3,1/3,1/3] = [52/30, 13/30, 0] + w - lr*grad = [-17, -4, 1/3] -> projected to [0, 0, 1] + return ([0,0,1] + 0.3*[1/3,...]) / 1.3 = [1/13, 1/13, 11/13] + """ + A = torch.diag(tensor_([1.0, 0.25, 0.0])) + W = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3) + assert_close(W(A), tensor_([1 / 13, 1 / 13, 11 / 13])) + + +def test_two_consecutive_steps() -> None: + """Warm-started carry-over across two consecutive calls matches manually-computed values. + + Step 1: A1=diag(1,0.25,0) -> A_norm=diag(4,1,0), state w=[0,0,1], return [1/13,1/13,11/13] + Step 2: A2=eye(3), scale=1 so A_norm=eye(3), warm start w=[0,0,1]: + grad = [0.1, 0.1, 1.1]; w - lr*grad = [-1,-1,-10] -> projected to [0.5, 0.5, 0] + return ([0.5,0.5,0] + 0.3*[1/3,...]) / 1.3 = [6/13, 6/13, 1/13] + """ + A1 = torch.diag(tensor_([1.0, 0.25, 0.0])) + A2 = eye_(3) + W = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3) + assert_close(W(A1), tensor_([1 / 13, 1 / 13, 11 / 13])) + assert_close(W(A2), tensor_([6 / 13, 6 / 13, 1 / 13])) + + +def test_custom_pref_vector() -> None: + """A custom preference vector changes the output relative to the uniform default. + + With A=diag(1,0.25,0), A_norm=diag(4,1,0), pref=[0,0,1], n_iter=1, lr=10, lambda_=0.3: + grad = A_norm @ ([1/3,1/3,1/3] + 0.3*[0,0,1]) = [4/3, 1/3, 0] + w - lr*grad = [-13, -3, 1/3] -> projected to [0, 0, 1] + return ([0,0,1] + 0.3*[0,0,1]) / 1.3 = [0, 0, 1] + This differs from the uniform-pref result [1/13, 1/13, 11/13]. 
+ """ + A = torch.diag(tensor_([1.0, 0.25, 0.0])) + pref = tensor_([0.0, 0.0, 1.0]) + W_pref = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3, pref_vector=pref) + assert_close(W_pref(A), tensor_([0.0, 0.0, 1.0])) + + W_default = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3) + assert not torch.allclose(W_default(A), tensor_([0.0, 0.0, 1.0])) + + +def test_pref_vector_wrong_length_raises() -> None: + W = SDMGradWeighting(pref_vector=tensor_([0.5, 0.5])) + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + A = J1 @ J2.T + with raises(ValueError, match="pref_vector"): + W(A) + + +def test_changing_m_auto_resets() -> None: + """When the number of objectives changes, the warm-started state is re-initialised to uniform.""" + + W = SDMGradWeighting() + W(randn_((3, 8)) @ randn_((3, 8)).T) + fresh = SDMGradWeighting() + J1 = randn_((2, 8)) + J2 = randn_((2, 8)) + A = J1 @ J2.T + assert_close(W(A), fresh(A)) + + +def test_non_differentiable() -> None: + """The _NonDifferentiable mixin must prevent autograd graph construction.""" + + A = randn_((3, 8)) @ randn_((3, 8)).T + A.requires_grad_(True) + W = SDMGradWeighting() + weights = W(A) + assert not weights.requires_grad