From c409c4a41f6fc86e53fe5f02ba132a8d86b1cddd Mon Sep 17 00:00:00 2001 From: Khush Date: Wed, 10 Jun 2026 09:50:59 -0400 Subject: [PATCH 1/4] feat(aggregation): Add SDMGradWeighting --- CHANGELOG.md | 1 + NOTICES | 28 +++ docs/source/docs/aggregation/index.rst | 1 + docs/source/docs/aggregation/sdmgrad.rst | 7 + src/torchjd/aggregation/__init__.py | 2 + src/torchjd/aggregation/_sdmgrad.py | 266 +++++++++++++++++++++++ tests/unit/aggregation/test_sdmgrad.py | 254 ++++++++++++++++++++++ 7 files changed, 559 insertions(+) create mode 100644 docs/source/docs/aggregation/sdmgrad.rst create mode 100644 src/torchjd/aggregation/_sdmgrad.py create mode 100644 tests/unit/aggregation/test_sdmgrad.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e259ea0..4ac85d97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `SDMGradWeighting` from [Direction-oriented Multi-objective Learning: Simple and Provable Stochastic Algorithms](https://arxiv.org/pdf/2305.18409) (NeurIPS 2023). It is a stateful `Weighting` that solves for task weights via a simplex-projected inner loop on a cross-batch matrix `A = J_1 @ J_2.T` (computed from two independent mini-batches using `autojac.jac`), with a direction-oriented regularizer pulling the descent direction toward a preference direction. - Added `IMTL-L` (the loss-balancing variant of Impartial Multi-Task Learning) from [Towards Impartial Multi-Task Learning](https://openreview.net/pdf?id=IMPnRXEWpvr) (ICLR 2021), a stateful `Scalarizer` that learns a per-task scale `s_i` and combines the values as diff --git a/NOTICES b/NOTICES index 07c3e851..b7dbf617 100644 --- a/NOTICES +++ b/NOTICES @@ -140,3 +140,31 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +------------------------------------------------------------------------------- + +Project: SDMGrad +Source: https://github.com/OptMN-Lab/SDMGrad/blob/main/methods/weight_methods.py +Used in: src/torchjd/aggregation/_sdmgrad.py + +MIT License + +Copyright (c) 2023 ml-opt-lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 4a5060e6..867e8d8c 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -41,5 +41,6 @@ Abstract base classes nash_mtl.rst pcgrad.rst random.rst + sdmgrad.rst sum.rst trimmed_mean.rst diff --git a/docs/source/docs/aggregation/sdmgrad.rst b/docs/source/docs/aggregation/sdmgrad.rst new file mode 100644 index 00000000..272c953d --- /dev/null +++ b/docs/source/docs/aggregation/sdmgrad.rst @@ -0,0 +1,7 @@ +:hide-toc: + +SDMGrad +======= + +.. autoclass:: torchjd.aggregation.SDMGradWeighting + :members: __call__, reset diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 2b77ae32..5285e6bf 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -56,6 +56,7 @@ from ._nash_mtl import NashMTL from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting +from ._sdmgrad import SDMGradWeighting from ._sum import Sum, SumWeighting from ._trimmed_mean import TrimmedMean from ._upgrad import UPGrad, UPGradWeighting @@ -93,6 +94,7 @@ "PCGradWeighting", "Random", "RandomWeighting", + "SDMGradWeighting", "Sum", "SumWeighting", "TrimmedMean", diff --git a/src/torchjd/aggregation/_sdmgrad.py b/src/torchjd/aggregation/_sdmgrad.py new file mode 100644 index 00000000..44a5eb14 --- /dev/null +++ b/src/torchjd/aggregation/_sdmgrad.py @@ -0,0 +1,266 @@ +# Partly adapted from https://github.com/OptMN-Lab/SDMGrad — MIT License, Copyright (c) 2023 ml-opt-lab. +# See NOTICES for the full license text. +from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor + +from torchjd._mixins import Stateful +from torchjd.aggregation._mixins import _NonDifferentiable +from torchjd.linalg import Matrix + +from ._weighting_bases import _MatrixWeighting + + +class SDMGradWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): + r""" + :class:`~torchjd.Stateful` + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Direction-oriented + Multi-objective Learning: Simple and Provable Stochastic Algorithms + `_ (NeurIPS 2023). + + .. warning:: + The input matrix must be :math:`A = J_1 J_2^\top`, computed from two **independent** + mini-batches via :func:`torchjd.autojac.jac`. It is **not** a Gramian and is not symmetric + or positive semi-definite in general. See the usage examples below. + + :param lr: Learning rate of the inner SGD that solves for the task weights. Must be positive. + :param momentum: Momentum of the inner SGD. Must be in :math:`[0, 1)`. + :param n_iter: Number of inner SGD iterations performed at each call. Must be positive. + :param lamda: Non-negative coefficient controlling how strongly the descent direction is pulled + toward the preference direction. Must be non-negative. + :param pref_vector: The preference vector :math:`\tilde w` defining the target direction. If not + provided, defaults to the uniform vector :math:`[1/m, \ldots, 1/m]` (i.e. the average + gradient). + + .. note:: + The inner simplex-projected solve is adapted from the `official implementation + `_. Note that the + official class default ``lamda=0.6`` is overridden to ``0.3`` in their own experiments, which + is the value used here (and in `LibMTL + `_). + + .. admonition:: Example (two batches per step) + + The following example reproduces SDMGrad using two independent mini-batches per step, reusing + the second batch for the parameter update. + + .. testcode:: + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import SDMGradWeighting + from torchjd.autojac import jac + + # Generate data (8 batches of 16 examples of dim 5) for the sake of the example. + inputs = torch.randn(8, 16, 5) + targets = torch.randn(8, 16) + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = SDMGradWeighting(lamda=0.3) + params = list(model.parameters()) + + # Consume two consecutive (independent) batches per step. + for i in range(len(inputs) // 2): + input_1, input_2 = inputs[2 * i], inputs[2 * i + 1] + target_1, target_2 = targets[2 * i], targets[2 * i + 1] + + losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + jacs_1 = jac(losses_1, params) + J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) + + # retain_graph=True so losses_2's graph survives for the backward step below. + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + jacs_2 = jac(losses_2, params, retain_graph=True) + J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) + + A = J_1 @ J_2.T + weights = weighting(A) + + losses_2.backward(weights) + optimizer.step() + optimizer.zero_grad() + + .. admonition:: Example (three batches per step) + + The following example reproduces SDMGrad using three independent mini-batches per step, + keeping the weight update and the parameter update on separate draws. + + .. testcode:: + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import SDMGradWeighting + from torchjd.autojac import jac + + # Generate data (9 batches of 16 examples of dim 5) for the sake of the example. + inputs = torch.randn(9, 16, 5) + targets = torch.randn(9, 16) + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = SDMGradWeighting(lamda=0.3) + params = list(model.parameters()) + + # Consume three consecutive (independent) batches per step. + for i in range(len(inputs) // 3): + input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2] + target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2] + + losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + jacs_1 = jac(losses_1, params) + J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) + + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + jacs_2 = jac(losses_2, params) + J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) + + A = J_1 @ J_2.T + weights = weighting(A) + + losses_3 = criterion(model(input_3).squeeze(dim=1), target_3) + losses_3.backward(weights) + optimizer.step() + optimizer.zero_grad() + """ + + def __init__( + self, + lr: float = 10.0, + momentum: float = 0.5, + n_iter: int = 20, + lamda: float = 0.3, + pref_vector: Tensor | None = None, + ) -> None: + super().__init__() + self.lr = lr + self.momentum = momentum + self.n_iter = n_iter + self.lamda = lamda + self.pref_vector = pref_vector + self._w: Tensor | None = None + self._state_key: tuple[int, torch.dtype, torch.device] | None = None + + @property + def lr(self) -> float: + return self._lr + + @lr.setter + def lr(self, value: float) -> None: + if value <= 0.0: + raise ValueError(f"Attribute `lr` must be positive. Found lr={value!r}.") + self._lr = value + + @property + def momentum(self) -> float: + return self._momentum + + @momentum.setter + def momentum(self, value: float) -> None: + if not (0.0 <= value < 1.0): + raise ValueError(f"Attribute `momentum` must be in [0, 1). Found momentum={value!r}.") + self._momentum = value + + @property + def n_iter(self) -> int: + return self._n_iter + + @n_iter.setter + def n_iter(self, value: int) -> None: + if value < 1: + raise ValueError( + f"Attribute `n_iter` must be a positive integer. Found n_iter={value!r}." + ) + self._n_iter = value + + @property + def lamda(self) -> float: + return self._lamda + + @lamda.setter + def lamda(self, value: float) -> None: + if value < 0.0: + raise ValueError(f"Attribute `lamda` must be non-negative. Found lamda={value!r}.") + self._lamda = value + + @property + def pref_vector(self) -> Tensor | None: + return self._pref_vector + + @pref_vector.setter + def pref_vector(self, value: Tensor | None) -> None: + if value is not None and value.ndim != 1: + raise ValueError( + "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = " + f"{value.ndim}`." + ) + self._pref_vector = value + + def reset(self) -> None: + """Clears the stored task weights so the next forward starts from uniform.""" + + self._w = None + self._state_key = None + + def forward(self, matrix: Matrix, /) -> Tensor: + self._ensure_state(matrix) + w = cast(Tensor, self._w) + w_tilde = self._resolve_w_tilde(matrix) + + velocity: Tensor | None = None + for _ in range(self._n_iter): + grad = matrix @ (w + self._lamda * w_tilde) + velocity = grad if velocity is None else self._momentum * velocity + grad + w = self._projection2simplex(w - self._lr * velocity) + + self._w = w + return (w + self._lamda * w_tilde) / (1.0 + self._lamda) + + def _resolve_w_tilde(self, matrix: Matrix) -> Tensor: + m = matrix.shape[0] + if self._pref_vector is None: + return matrix.new_full((m,), 1.0 / m) + if self._pref_vector.shape[0] != m: + raise ValueError( + "The length of `pref_vector` must match the number of rows of the input matrix. " + f"Found len(pref_vector)={self._pref_vector.shape[0]} and matrix.shape[0]={m}." + ) + return self._pref_vector.to(dtype=matrix.dtype, device=matrix.device) + + def _ensure_state(self, matrix: Matrix) -> None: + key = (matrix.shape[0], matrix.dtype, matrix.device) + if self._state_key == key and self._w is not None: + return + self._w = matrix.new_full((matrix.shape[0],), 1.0 / matrix.shape[0]) + self._state_key = key + + @staticmethod + def _projection2simplex(y: Tensor) -> Tensor: + """Euclidean projection of ``y`` onto the probability simplex.""" + + m = len(y) + sorted_y = torch.sort(y, descending=True)[0] + tmpsum = y.new_zeros(()) + tmax_f = (torch.sum(y) - 1.0) / m + for i in range(m - 1): + tmpsum = tmpsum + sorted_y[i] + tmax = (tmpsum - 1.0) / (i + 1.0) + if tmax > sorted_y[i + 1]: + tmax_f = tmax + break + return torch.max(y - tmax_f, y.new_zeros(m)) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(lr={self.lr!r}, momentum={self.momentum!r}, " + f"n_iter={self.n_iter!r}, lamda={self.lamda!r}, pref_vector={self.pref_vector!r})" + ) diff --git a/tests/unit/aggregation/test_sdmgrad.py b/tests/unit/aggregation/test_sdmgrad.py new file mode 100644 index 00000000..90a25e46 --- /dev/null +++ b/tests/unit/aggregation/test_sdmgrad.py @@ -0,0 +1,254 @@ +import torch +from pytest import raises +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import randn_, tensor_ + +from torchjd.aggregation._sdmgrad import SDMGradWeighting + + +def _project_to_simplex(y: Tensor) -> Tensor: + """Reference Euclidean projection onto the probability simplex.""" + + m = len(y) + sorted_y = torch.sort(y, descending=True)[0] + tmpsum = y.new_zeros(()) + tmax_f = (torch.sum(y) - 1.0) / m + for i in range(m - 1): + tmpsum = tmpsum + sorted_y[i] + tmax = (tmpsum - 1.0) / (i + 1.0) + if tmax > sorted_y[i + 1]: + tmax_f = tmax + break + return torch.max(y - tmax_f, y.new_zeros(m)) + + +def _sdmgrad_reference( + A: Tensor, + w: Tensor, + w_tilde: Tensor, + lr: float, + momentum: float, + n_iter: int, + lamda: float, +) -> tuple[Tensor, Tensor]: + """Reference inner solve. Returns the updated state ``w`` and the returned (normalized) weights.""" + + velocity: Tensor | None = None + for _ in range(n_iter): + grad = A @ (w + lamda * w_tilde) + velocity = grad if velocity is None else momentum * velocity + grad + w = _project_to_simplex(w - lr * velocity) + return w, (w + lamda * w_tilde) / (1.0 + lamda) + + +def test_representations() -> None: + W = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=20, lamda=0.3) + assert ( + repr(W) == "SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=20, lamda=0.3, pref_vector=None)" + ) + + W_pref = SDMGradWeighting(pref_vector=tensor_([0.25, 0.75])) + assert "pref_vector=tensor(" in repr(W_pref) + + +def test_reset_restores_first_step_behavior() -> None: + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + A = J1 @ J2.T + W = SDMGradWeighting() + first = W(A) + W(A) + W.reset() + assert_close(first, W(A)) + + +def test_lr_setter_accepts_valid() -> None: + W = SDMGradWeighting() + W.lr = 0.5 + assert W.lr == 0.5 + + +def test_lr_setter_rejects_non_positive() -> None: + W = SDMGradWeighting() + with raises(ValueError, match="lr"): + W.lr = 0.0 + with raises(ValueError, match="lr"): + W.lr = -1.0 + + +def test_momentum_setter_accepts_valid() -> None: + W = SDMGradWeighting() + W.momentum = 0.0 + assert W.momentum == 0.0 + W.momentum = 0.9 + assert W.momentum == 0.9 + + +def test_momentum_setter_rejects_out_of_range() -> None: + W = SDMGradWeighting() + with raises(ValueError, match="momentum"): + W.momentum = -0.1 + with raises(ValueError, match="momentum"): + W.momentum = 1.0 + + +def test_n_iter_setter_rejects_non_positive() -> None: + W = SDMGradWeighting() + with raises(ValueError, match="n_iter"): + W.n_iter = 0 + + +def test_lamda_setter_accepts_valid() -> None: + W = SDMGradWeighting() + W.lamda = 0.0 + assert W.lamda == 0.0 + W.lamda = 0.6 + assert W.lamda == 0.6 + + +def test_lamda_setter_rejects_negative() -> None: + W = SDMGradWeighting() + with raises(ValueError, match="lamda"): + W.lamda = -0.1 + + +def test_pref_vector_setter_rejects_non_1d() -> None: + W = SDMGradWeighting() + with raises(ValueError, match="pref_vector"): + W.pref_vector = tensor_([[0.5, 0.5], [0.5, 0.5]]) + + +def test_output_lies_on_simplex() -> None: + """The simplex projection and (1+lamda) normalization keep the weights on the simplex.""" + + J1 = randn_((4, 10)) + J2 = randn_((4, 10)) + A = J1 @ J2.T + W = SDMGradWeighting(lamda=0.3) + weights = W(A) + assert weights.shape == (4,) + assert (weights >= 0).all() + assert_close(weights.sum(), tensor_(1.0)) + + +def test_update_recurrence() -> None: + """Verify one full inner solve by hand.""" + + lr, momentum, n_iter, lamda = 10.0, 0.5, 5, 0.3 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + A = J1 @ J2.T + m = J1.shape[0] + + W = SDMGradWeighting(lr=lr, momentum=momentum, n_iter=n_iter, lamda=lamda) + w0 = tensor_([1.0 / m] * m) + w_tilde = tensor_([1.0 / m] * m) + _, expected = _sdmgrad_reference(A, w0, w_tilde, lr, momentum, n_iter, lamda) + + assert_close(W(A), expected) + + +def test_two_consecutive_steps() -> None: + """Verify warm-started carry-over across two consecutive calls.""" + + lr, momentum, n_iter, lamda = 10.0, 0.5, 5, 0.3 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + J3 = randn_((3, 8)) + J4 = randn_((3, 8)) + A1 = J1 @ J2.T + A2 = J3 @ J4.T + m = J1.shape[0] + + W = SDMGradWeighting(lr=lr, momentum=momentum, n_iter=n_iter, lamda=lamda) + w_tilde = tensor_([1.0 / m] * m) + + w0 = tensor_([1.0 / m] * m) + w1, out1 = _sdmgrad_reference(A1, w0, w_tilde, lr, momentum, n_iter, lamda) + _, out2 = _sdmgrad_reference(A2, w1, w_tilde, lr, momentum, n_iter, lamda) + + assert_close(W(A1), out1) + assert_close(W(A2), out2) + + +def test_custom_pref_vector() -> None: + """A custom preference vector is used as the target direction and changes the output.""" + + lr, momentum, n_iter, lamda = 10.0, 0.5, 5, 0.3 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + A = J1 @ J2.T + m = J1.shape[0] + pref = tensor_([0.1, 0.2, 0.7]) + + W = SDMGradWeighting(lr=lr, momentum=momentum, n_iter=n_iter, lamda=lamda, pref_vector=pref) + w0 = tensor_([1.0 / m] * m) + _, expected = _sdmgrad_reference(A, w0, pref, lr, momentum, n_iter, lamda) + assert_close(W(A), expected) + + # The custom preference vector should change the output compared to the uniform default. + W_default = SDMGradWeighting(lr=lr, momentum=momentum, n_iter=n_iter, lamda=lamda) + assert not torch.allclose(W(A), W_default(A)) + + +def test_pref_vector_wrong_length_raises() -> None: + W = SDMGradWeighting(pref_vector=tensor_([0.5, 0.5])) + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + A = J1 @ J2.T + with raises(ValueError, match="pref_vector"): + W(A) + + +def test_changing_m_auto_resets() -> None: + """When the number of objectives changes, the warm-started state is re-initialised to uniform.""" + + W = SDMGradWeighting() + W(randn_((3, 8)) @ randn_((3, 8)).T) + fresh = SDMGradWeighting() + J1 = randn_((2, 8)) + J2 = randn_((2, 8)) + A = J1 @ J2.T + assert_close(W(A), fresh(A)) + + +def test_non_differentiable() -> None: + """The _NonDifferentiable mixin must prevent autograd graph construction.""" + + A = randn_((3, 8)) @ randn_((3, 8)).T + A.requires_grad_(True) + W = SDMGradWeighting() + weights = W(A) + assert not weights.requires_grad + + +def test_non_symmetric_input() -> None: + """SDMGradWeighting must accept and correctly process a non-symmetric cross-batch matrix.""" + + lr, momentum, n_iter, lamda = 10.0, 0.5, 5, 0.3 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + A = J1 @ J2.T # not symmetric, not PSD in general + m = J1.shape[0] + + W = SDMGradWeighting(lr=lr, momentum=momentum, n_iter=n_iter, lamda=lamda) + w0 = tensor_([1.0 / m] * m) + w_tilde = tensor_([1.0 / m] * m) + _, expected = _sdmgrad_reference(A, w0, w_tilde, lr, momentum, n_iter, lamda) + + assert_close(W(A), expected) + assert (W(A) >= 0).all() + + +def test_projection2simplex_known_values() -> None: + """The simplex projection matches hand-computed Euclidean projections.""" + + assert_close( + SDMGradWeighting._projection2simplex(tensor_([0.5, 0.1, 0.1])), + tensor_([0.6, 0.2, 0.2]), + ) + assert_close( + SDMGradWeighting._projection2simplex(tensor_([1.0, 0.0, -0.5])), + tensor_([1.0, 0.0, 0.0]), + ) From f3c3d482c784b06ca046f57b3963d97f6f0ecd26 Mon Sep 17 00:00:00 2001 From: Khush Date: Wed, 10 Jun 2026 13:29:17 -0400 Subject: [PATCH 2/4] refactor(aggregation): address PR review comments on SDMGradWeighting --- src/torchjd/aggregation/_modo.py | 19 +- src/torchjd/aggregation/_sdmgrad.py | 103 +++-------- src/torchjd/aggregation/_utils/simplex.py | 18 ++ tests/unit/aggregation/_utils/test_simplex.py | 19 ++ tests/unit/aggregation/test_modo.py | 15 -- tests/unit/aggregation/test_sdmgrad.py | 175 +++++------------- 6 files changed, 106 insertions(+), 243 deletions(-) create mode 100644 src/torchjd/aggregation/_utils/simplex.py create mode 100644 tests/unit/aggregation/_utils/test_simplex.py diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py index 4856430a..a7b6fc0e 100644 --- a/src/torchjd/aggregation/_modo.py +++ b/src/torchjd/aggregation/_modo.py @@ -11,6 +11,7 @@ from torchjd.aggregation._mixins import _NonDifferentiable from torchjd.linalg import Matrix +from ._utils.simplex import _projection2simplex from ._weighting_bases import _MatrixWeighting @@ -166,27 +167,11 @@ def forward(self, matrix: Matrix, /) -> Tensor: lambd = cast(Tensor, self._lambda) grad = matrix @ lambd + self._rho * lambd - lambd = self._projection2simplex(lambd - self._gamma * grad) + lambd = _projection2simplex(lambd - self._gamma * grad) self._lambda = lambd return lambd - @staticmethod - def _projection2simplex(y: Tensor) -> Tensor: - """Euclidean projection of ``y`` onto the probability simplex.""" - - m = len(y) - sorted_y = torch.sort(y, descending=True)[0] - tmpsum = y.new_zeros(()) - tmax_f = (torch.sum(y) - 1.0) / m - for i in range(m - 1): - tmpsum = tmpsum + sorted_y[i] - tmax = (tmpsum - 1.0) / (i + 1.0) - if tmax > sorted_y[i + 1]: - tmax_f = tmax - break - return torch.max(y - tmax_f, y.new_zeros(m)) - def _ensure_state(self, matrix: Matrix) -> None: key = (matrix.shape[0], matrix.dtype, matrix.device) if self._state_key == key and self._lambda is not None: diff --git a/src/torchjd/aggregation/_sdmgrad.py b/src/torchjd/aggregation/_sdmgrad.py index 44a5eb14..649c14fb 100644 --- a/src/torchjd/aggregation/_sdmgrad.py +++ b/src/torchjd/aggregation/_sdmgrad.py @@ -11,6 +11,7 @@ from torchjd.aggregation._mixins import _NonDifferentiable from torchjd.linalg import Matrix +from ._utils.simplex import _projection2simplex from ._weighting_bases import _MatrixWeighting @@ -29,68 +30,21 @@ class SDMGradWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): :param lr: Learning rate of the inner SGD that solves for the task weights. Must be positive. :param momentum: Momentum of the inner SGD. Must be in :math:`[0, 1)`. :param n_iter: Number of inner SGD iterations performed at each call. Must be positive. - :param lamda: Non-negative coefficient controlling how strongly the descent direction is pulled + :param lambda_: Non-negative coefficient controlling how strongly the descent direction is pulled toward the preference direction. Must be non-negative. :param pref_vector: The preference vector :math:`\tilde w` defining the target direction. If not - provided, defaults to the uniform vector :math:`[1/m, \ldots, 1/m]` (i.e. the average - gradient). + provided, defaults to the uniform vector :math:`[1/m, \ldots, 1/m]` (i.e. the target diection is the average gradient). .. note:: - The inner simplex-projected solve is adapted from the `official implementation + The inner simplex-projected solver is adapted from the `official implementation `_. Note that the - official class default ``lamda=0.6`` is overridden to ``0.3`` in their own experiments, which - is the value used here (and in `LibMTL + official class default for this coefficient is ``0.6``, overridden to ``0.3`` in their own + experiments, which is the value used here (and in `LibMTL `_). - .. admonition:: Example (two batches per step) - - The following example reproduces SDMGrad using two independent mini-batches per step, reusing - the second batch for the parameter update. - - .. testcode:: - - import torch - from torch.nn import Linear, MSELoss, ReLU, Sequential - from torch.optim import SGD - - from torchjd.aggregation import SDMGradWeighting - from torchjd.autojac import jac - - # Generate data (8 batches of 16 examples of dim 5) for the sake of the example. - inputs = torch.randn(8, 16, 5) - targets = torch.randn(8, 16) - - model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) - optimizer = SGD(model.parameters()) - criterion = MSELoss(reduction="none") - weighting = SDMGradWeighting(lamda=0.3) - params = list(model.parameters()) - - # Consume two consecutive (independent) batches per step. - for i in range(len(inputs) // 2): - input_1, input_2 = inputs[2 * i], inputs[2 * i + 1] - target_1, target_2 = targets[2 * i], targets[2 * i + 1] - - losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) - jacs_1 = jac(losses_1, params) - J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) - - # retain_graph=True so losses_2's graph survives for the backward step below. - losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) - jacs_2 = jac(losses_2, params, retain_graph=True) - J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) - - A = J_1 @ J_2.T - weights = weighting(A) - - losses_2.backward(weights) - optimizer.step() - optimizer.zero_grad() - .. admonition:: Example (three batches per step) - The following example reproduces SDMGrad using three independent mini-batches per step, - keeping the weight update and the parameter update on separate draws. + The following example shows how to train with the SDMGrad algorithm. .. testcode:: @@ -108,11 +62,12 @@ class SDMGradWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) optimizer = SGD(model.parameters()) criterion = MSELoss(reduction="none") - weighting = SDMGradWeighting(lamda=0.3) + weighting = SDMGradWeighting(lambda_=0.3) params = list(model.parameters()) # Consume three consecutive (independent) batches per step. for i in range(len(inputs) // 3): + # Batches corresponding to ξ, ξ' and ζ in the paper's algorithm. input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2] target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2] @@ -138,14 +93,14 @@ def __init__( lr: float = 10.0, momentum: float = 0.5, n_iter: int = 20, - lamda: float = 0.3, + lambda_: float = 0.3, pref_vector: Tensor | None = None, ) -> None: super().__init__() self.lr = lr self.momentum = momentum self.n_iter = n_iter - self.lamda = lamda + self.lambda_ = lambda_ self.pref_vector = pref_vector self._w: Tensor | None = None self._state_key: tuple[int, torch.dtype, torch.device] | None = None @@ -183,14 +138,14 @@ def n_iter(self, value: int) -> None: self._n_iter = value @property - def lamda(self) -> float: - return self._lamda + def lambda_(self) -> float: + return self._lambda - @lamda.setter - def lamda(self, value: float) -> None: + @lambda_.setter + def lambda_(self, value: float) -> None: if value < 0.0: - raise ValueError(f"Attribute `lamda` must be non-negative. Found lamda={value!r}.") - self._lamda = value + raise ValueError(f"Attribute `lambda_` must be non-negative. Found lambda_={value!r}.") + self._lambda = value @property def pref_vector(self) -> Tensor | None: @@ -218,12 +173,12 @@ def forward(self, matrix: Matrix, /) -> Tensor: velocity: Tensor | None = None for _ in range(self._n_iter): - grad = matrix @ (w + self._lamda * w_tilde) + grad = matrix @ (w + self._lambda * w_tilde) velocity = grad if velocity is None else self._momentum * velocity + grad - w = self._projection2simplex(w - self._lr * velocity) + w = _projection2simplex(w - self._lr * velocity) self._w = w - return (w + self._lamda * w_tilde) / (1.0 + self._lamda) + return (w + self._lambda * w_tilde) / (1.0 + self._lambda) def _resolve_w_tilde(self, matrix: Matrix) -> Tensor: m = matrix.shape[0] @@ -243,24 +198,8 @@ def _ensure_state(self, matrix: Matrix) -> None: self._w = matrix.new_full((matrix.shape[0],), 1.0 / matrix.shape[0]) self._state_key = key - @staticmethod - def _projection2simplex(y: Tensor) -> Tensor: - """Euclidean projection of ``y`` onto the probability simplex.""" - - m = len(y) - sorted_y = torch.sort(y, descending=True)[0] - tmpsum = y.new_zeros(()) - tmax_f = (torch.sum(y) - 1.0) / m - for i in range(m - 1): - tmpsum = tmpsum + sorted_y[i] - tmax = (tmpsum - 1.0) / (i + 1.0) - if tmax > sorted_y[i + 1]: - tmax_f = tmax - break - return torch.max(y - tmax_f, y.new_zeros(m)) - def __repr__(self) -> str: return ( f"{self.__class__.__name__}(lr={self.lr!r}, momentum={self.momentum!r}, " - f"n_iter={self.n_iter!r}, lamda={self.lamda!r}, pref_vector={self.pref_vector!r})" + f"n_iter={self.n_iter!r}, lambda_={self.lambda_!r}, pref_vector={self.pref_vector!r})" ) diff --git a/src/torchjd/aggregation/_utils/simplex.py b/src/torchjd/aggregation/_utils/simplex.py new file mode 100644 index 00000000..b146d307 --- /dev/null +++ b/src/torchjd/aggregation/_utils/simplex.py @@ -0,0 +1,18 @@ +import torch +from torch import Tensor + + +def _projection2simplex(y: Tensor) -> Tensor: + """Euclidean projection of ``y`` onto the probability simplex.""" + + m = len(y) + sorted_y = torch.sort(y, descending=True)[0] + tmpsum = y.new_zeros(()) + tmax_f = (torch.sum(y) - 1.0) / m + for i in range(m - 1): + tmpsum = tmpsum + sorted_y[i] + tmax = (tmpsum - 1.0) / (i + 1.0) + if tmax > sorted_y[i + 1]: + tmax_f = tmax + break + return torch.max(y - tmax_f, y.new_zeros(m)) diff --git a/tests/unit/aggregation/_utils/test_simplex.py b/tests/unit/aggregation/_utils/test_simplex.py new file mode 100644 index 00000000..7f3907f5 --- /dev/null +++ b/tests/unit/aggregation/_utils/test_simplex.py @@ -0,0 +1,19 @@ +from torch.testing import assert_close +from utils.tensors import tensor_ + +from torchjd.aggregation._utils.simplex import _projection2simplex + + +def test_projection2simplex_known_values() -> None: + """The simplex projection matches hand-computed Euclidean projections.""" + + # Already-positive input: the deficit (1 - sum) is spread equally, no clamping. + assert_close( + _projection2simplex(tensor_([0.5, 0.1, 0.1])), + tensor_([0.6, 0.2, 0.2]), + ) + # Input with a negative entry: it gets clamped to zero. + assert_close( + _projection2simplex(tensor_([1.0, 0.0, -0.5])), + tensor_([1.0, 0.0, 0.0]), + ) diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py index f169ea86..cd3943d1 100644 --- a/tests/unit/aggregation/test_modo.py +++ b/tests/unit/aggregation/test_modo.py @@ -170,18 +170,3 @@ def test_non_symmetric_input() -> None: assert_close(W(G), expected) assert W(G).shape == (m,) assert (W(G) >= 0).all() - - -def test_projection2simplex_known_values() -> None: - """The simplex projection matches hand-computed Euclidean projections.""" - - # Already-positive input: the deficit (1 - sum) is spread equally, no clamping. - assert_close( - MoDoWeighting._projection2simplex(tensor_([0.5, 0.1, 0.1])), - tensor_([0.6, 0.2, 0.2]), - ) - # Input with a negative entry: it gets clamped to zero. - assert_close( - MoDoWeighting._projection2simplex(tensor_([1.0, 0.0, -0.5])), - tensor_([1.0, 0.0, 0.0]), - ) diff --git a/tests/unit/aggregation/test_sdmgrad.py b/tests/unit/aggregation/test_sdmgrad.py index 90a25e46..e8c73546 100644 --- a/tests/unit/aggregation/test_sdmgrad.py +++ b/tests/unit/aggregation/test_sdmgrad.py @@ -1,56 +1,18 @@ import torch from pytest import raises -from torch import Tensor from torch.testing import assert_close from utils.tensors import randn_, tensor_ from torchjd.aggregation._sdmgrad import SDMGradWeighting -def _project_to_simplex(y: Tensor) -> Tensor: - """Reference Euclidean projection onto the probability simplex.""" - - m = len(y) - sorted_y = torch.sort(y, descending=True)[0] - tmpsum = y.new_zeros(()) - tmax_f = (torch.sum(y) - 1.0) / m - for i in range(m - 1): - tmpsum = tmpsum + sorted_y[i] - tmax = (tmpsum - 1.0) / (i + 1.0) - if tmax > sorted_y[i + 1]: - tmax_f = tmax - break - return torch.max(y - tmax_f, y.new_zeros(m)) - - -def _sdmgrad_reference( - A: Tensor, - w: Tensor, - w_tilde: Tensor, - lr: float, - momentum: float, - n_iter: int, - lamda: float, -) -> tuple[Tensor, Tensor]: - """Reference inner solve. Returns the updated state ``w`` and the returned (normalized) weights.""" - - velocity: Tensor | None = None - for _ in range(n_iter): - grad = A @ (w + lamda * w_tilde) - velocity = grad if velocity is None else momentum * velocity + grad - w = _project_to_simplex(w - lr * velocity) - return w, (w + lamda * w_tilde) / (1.0 + lamda) - - def test_representations() -> None: - W = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=20, lamda=0.3) + W = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=20, lambda_=0.3) assert ( - repr(W) == "SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=20, lamda=0.3, pref_vector=None)" + repr(W) + == "SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=20, lambda_=0.3, pref_vector=None)" ) - W_pref = SDMGradWeighting(pref_vector=tensor_([0.25, 0.75])) - assert "pref_vector=tensor(" in repr(W_pref) - def test_reset_restores_first_step_behavior() -> None: J1 = randn_((3, 8)) @@ -99,18 +61,18 @@ def test_n_iter_setter_rejects_non_positive() -> None: W.n_iter = 0 -def test_lamda_setter_accepts_valid() -> None: +def test_lambda_setter_accepts_valid() -> None: W = SDMGradWeighting() - W.lamda = 0.0 - assert W.lamda == 0.0 - W.lamda = 0.6 - assert W.lamda == 0.6 + W.lambda_ = 0.0 + assert W.lambda_ == 0.0 + W.lambda_ = 0.6 + assert W.lambda_ == 0.6 -def test_lamda_setter_rejects_negative() -> None: +def test_lambda_setter_rejects_negative() -> None: W = SDMGradWeighting() - with raises(ValueError, match="lamda"): - W.lamda = -0.1 + with raises(ValueError, match="lambda_"): + W.lambda_ = -0.1 def test_pref_vector_setter_rejects_non_1d() -> None: @@ -120,12 +82,12 @@ def test_pref_vector_setter_rejects_non_1d() -> None: def test_output_lies_on_simplex() -> None: - """The simplex projection and (1+lamda) normalization keep the weights on the simplex.""" + """The simplex projection and (1 + lambda_) normalization keep the weights on the simplex.""" J1 = randn_((4, 10)) J2 = randn_((4, 10)) A = J1 @ J2.T - W = SDMGradWeighting(lamda=0.3) + W = SDMGradWeighting(lambda_=0.3) weights = W(A) assert weights.shape == (4,) assert (weights >= 0).all() @@ -133,63 +95,49 @@ def test_output_lies_on_simplex() -> None: def test_update_recurrence() -> None: - """Verify one full inner solve by hand.""" - - lr, momentum, n_iter, lamda = 10.0, 0.5, 5, 0.3 - J1 = randn_((3, 8)) - J2 = randn_((3, 8)) - A = J1 @ J2.T - m = J1.shape[0] - - W = SDMGradWeighting(lr=lr, momentum=momentum, n_iter=n_iter, lamda=lamda) - w0 = tensor_([1.0 / m] * m) - w_tilde = tensor_([1.0 / m] * m) - _, expected = _sdmgrad_reference(A, w0, w_tilde, lr, momentum, n_iter, lamda) + """One inner-solve step matches the manually-computed expected output. - assert_close(W(A), expected) + With A = diag(2, 1, 0), n_iter=1, lr=10, lambda_=0.3, starting from uniform w=[1/3,1/3,1/3]: + grad = A @ 1.3*[1/3,1/3,1/3] = [13/15, 13/30, 0] + w - lr*grad = [-25/3, -4, 1/3] -> projected to [0, 0, 1] + return ([0,0,1] + 0.3*[1/3,...]) / 1.3 = [1/13, 1/13, 11/13] + """ + A = torch.diag(tensor_([2.0, 1.0, 0.0])) + W = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3) + assert_close(W(A), tensor_([1 / 13, 1 / 13, 11 / 13])) def test_two_consecutive_steps() -> None: - """Verify warm-started carry-over across two consecutive calls.""" + """Warm-started carry-over across two consecutive calls matches manually-computed values. - lr, momentum, n_iter, lamda = 10.0, 0.5, 5, 0.3 - J1 = randn_((3, 8)) - J2 = randn_((3, 8)) - J3 = randn_((3, 8)) - J4 = randn_((3, 8)) - A1 = J1 @ J2.T - A2 = J3 @ J4.T - m = J1.shape[0] - - W = SDMGradWeighting(lr=lr, momentum=momentum, n_iter=n_iter, lamda=lamda) - w_tilde = tensor_([1.0 / m] * m) - - w0 = tensor_([1.0 / m] * m) - w1, out1 = _sdmgrad_reference(A1, w0, w_tilde, lr, momentum, n_iter, lamda) - _, out2 = _sdmgrad_reference(A2, w1, w_tilde, lr, momentum, n_iter, lamda) - - assert_close(W(A1), out1) - assert_close(W(A2), out2) + Step 1: A1=diag(2,1,0) -> state w=[0,0,1], return [1/13, 1/13, 11/13] (see test_update_recurrence) + Step 2: A2=eye(3), warm start w=[0,0,1]: + grad = [0.1, 0.1, 1.1]; w - lr*grad = [-1,-1,-10] -> projected to [0.5, 0.5, 0] + return ([0.5,0.5,0] + 0.3*[1/3,...]) / 1.3 = [6/13, 6/13, 1/13] + """ + A1 = torch.diag(tensor_([2.0, 1.0, 0.0])) + A2 = torch.eye(3) + W = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3) + assert_close(W(A1), tensor_([1 / 13, 1 / 13, 11 / 13])) + assert_close(W(A2), tensor_([6 / 13, 6 / 13, 1 / 13])) def test_custom_pref_vector() -> None: - """A custom preference vector is used as the target direction and changes the output.""" + """A custom preference vector changes the output relative to the uniform default. - lr, momentum, n_iter, lamda = 10.0, 0.5, 5, 0.3 - J1 = randn_((3, 8)) - J2 = randn_((3, 8)) - A = J1 @ J2.T - m = J1.shape[0] - pref = tensor_([0.1, 0.2, 0.7]) - - W = SDMGradWeighting(lr=lr, momentum=momentum, n_iter=n_iter, lamda=lamda, pref_vector=pref) - w0 = tensor_([1.0 / m] * m) - _, expected = _sdmgrad_reference(A, w0, pref, lr, momentum, n_iter, lamda) - assert_close(W(A), expected) + With A=diag(2,1,0), pref=[0,0,1], n_iter=1, lr=10, lambda_=0.3: + grad = A @ ([1/3,1/3,1/3] + 0.3*[0,0,1]) = [2/3, 1/3, 0] + w - lr*grad = [-19/3, -3, 1/3] -> projected to [0, 0, 1] + return ([0,0,1] + 0.3*[0,0,1]) / 1.3 = [0, 0, 1] + This differs from the uniform-pref result [1/13, 1/13, 11/13]. + """ + A = torch.diag(tensor_([2.0, 1.0, 0.0])) + pref = tensor_([0.0, 0.0, 1.0]) + W_pref = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3, pref_vector=pref) + assert_close(W_pref(A), tensor_([0.0, 0.0, 1.0])) - # The custom preference vector should change the output compared to the uniform default. - W_default = SDMGradWeighting(lr=lr, momentum=momentum, n_iter=n_iter, lamda=lamda) - assert not torch.allclose(W(A), W_default(A)) + W_default = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3) + assert not torch.allclose(W_default(A), tensor_([0.0, 0.0, 1.0])) def test_pref_vector_wrong_length_raises() -> None: @@ -221,34 +169,3 @@ def test_non_differentiable() -> None: W = SDMGradWeighting() weights = W(A) assert not weights.requires_grad - - -def test_non_symmetric_input() -> None: - """SDMGradWeighting must accept and correctly process a non-symmetric cross-batch matrix.""" - - lr, momentum, n_iter, lamda = 10.0, 0.5, 5, 0.3 - J1 = randn_((3, 8)) - J2 = randn_((3, 8)) - A = J1 @ J2.T # not symmetric, not PSD in general - m = J1.shape[0] - - W = SDMGradWeighting(lr=lr, momentum=momentum, n_iter=n_iter, lamda=lamda) - w0 = tensor_([1.0 / m] * m) - w_tilde = tensor_([1.0 / m] * m) - _, expected = _sdmgrad_reference(A, w0, w_tilde, lr, momentum, n_iter, lamda) - - assert_close(W(A), expected) - assert (W(A) >= 0).all() - - -def test_projection2simplex_known_values() -> None: - """The simplex projection matches hand-computed Euclidean projections.""" - - assert_close( - SDMGradWeighting._projection2simplex(tensor_([0.5, 0.1, 0.1])), - tensor_([0.6, 0.2, 0.2]), - ) - assert_close( - SDMGradWeighting._projection2simplex(tensor_([1.0, 0.0, -0.5])), - tensor_([1.0, 0.0, 0.0]), - ) From db86b53c9f00db5b60ae097bb8ae549ed85bcb81 Mon Sep 17 00:00:00 2001 From: Khush Date: Wed, 10 Jun 2026 13:33:27 -0400 Subject: [PATCH 3/4] fix(aggregation): use eye_ helper to respect DTYPE in test_two_consecutive_steps --- tests/unit/aggregation/test_sdmgrad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/aggregation/test_sdmgrad.py b/tests/unit/aggregation/test_sdmgrad.py index e8c73546..b8616c36 100644 --- a/tests/unit/aggregation/test_sdmgrad.py +++ b/tests/unit/aggregation/test_sdmgrad.py @@ -1,7 +1,7 @@ import torch from pytest import raises from torch.testing import assert_close -from utils.tensors import randn_, tensor_ +from utils.tensors import eye_, randn_, tensor_ from torchjd.aggregation._sdmgrad import SDMGradWeighting @@ -116,7 +116,7 @@ def test_two_consecutive_steps() -> None: return ([0.5,0.5,0] + 0.3*[1/3,...]) / 1.3 = [6/13, 6/13, 1/13] """ A1 = torch.diag(tensor_([2.0, 1.0, 0.0])) - A2 = torch.eye(3) + A2 = eye_(3) W = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3) assert_close(W(A1), tensor_([1 / 13, 1 / 13, 11 / 13])) assert_close(W(A2), tensor_([6 / 13, 6 / 13, 1 / 13])) From 6f1c9d25bbc3bc9ffcb6bc1eabd546fbefe5d088 Mon Sep 17 00:00:00 2001 From: Khush Date: Wed, 10 Jun 2026 17:26:31 -0400 Subject: [PATCH 4/4] feat(aggregation): add scale normalization to SDMGradWeighting --- src/torchjd/aggregation/_sdmgrad.py | 11 ++++++++++- tests/unit/aggregation/test_sdmgrad.py | 23 ++++++++++++----------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/torchjd/aggregation/_sdmgrad.py b/src/torchjd/aggregation/_sdmgrad.py index 649c14fb..e4eb5ab4 100644 --- a/src/torchjd/aggregation/_sdmgrad.py +++ b/src/torchjd/aggregation/_sdmgrad.py @@ -42,6 +42,11 @@ class SDMGradWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): experiments, which is the value used here (and in `LibMTL `_). + Before the inner solve, the input matrix is scale-normalized by the mean of the square + roots of its non-negative diagonal entries (following both the official implementation and + LibMTL). This makes the inner SGD learning rate scale-invariant with respect to gradient + magnitude. The normalization is briefly described in section 6.1 of the paper. + .. admonition:: Example (three batches per step) The following example shows how to train with the SDMGrad algorithm. @@ -171,9 +176,13 @@ def forward(self, matrix: Matrix, /) -> Tensor: w = cast(Tensor, self._w) w_tilde = self._resolve_w_tilde(matrix) + diag = torch.diag(matrix).clamp(min=0.0) + scale = diag.sqrt().mean() + a = matrix / (scale.pow(2) + 1e-8) + velocity: Tensor | None = None for _ in range(self._n_iter): - grad = matrix @ (w + self._lambda * w_tilde) + grad = a @ (w + self._lambda * w_tilde) velocity = grad if velocity is None else self._momentum * velocity + grad w = _projection2simplex(w - self._lr * velocity) diff --git a/tests/unit/aggregation/test_sdmgrad.py b/tests/unit/aggregation/test_sdmgrad.py index b8616c36..e846b455 100644 --- a/tests/unit/aggregation/test_sdmgrad.py +++ b/tests/unit/aggregation/test_sdmgrad.py @@ -97,12 +97,13 @@ def test_output_lies_on_simplex() -> None: def test_update_recurrence() -> None: """One inner-solve step matches the manually-computed expected output. - With A = diag(2, 1, 0), n_iter=1, lr=10, lambda_=0.3, starting from uniform w=[1/3,1/3,1/3]: - grad = A @ 1.3*[1/3,1/3,1/3] = [13/15, 13/30, 0] - w - lr*grad = [-25/3, -4, 1/3] -> projected to [0, 0, 1] + With A = diag(1, 0.25, 0), n_iter=1, lr=10, lambda_=0.3, starting from uniform w=[1/3,1/3,1/3]: + scale = mean([1, 0.5, 0]) = 0.5 => A_norm = diag(4, 1, 0) + grad = A_norm @ 1.3*[1/3,1/3,1/3] = [52/30, 13/30, 0] + w - lr*grad = [-17, -4, 1/3] -> projected to [0, 0, 1] return ([0,0,1] + 0.3*[1/3,...]) / 1.3 = [1/13, 1/13, 11/13] """ - A = torch.diag(tensor_([2.0, 1.0, 0.0])) + A = torch.diag(tensor_([1.0, 0.25, 0.0])) W = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3) assert_close(W(A), tensor_([1 / 13, 1 / 13, 11 / 13])) @@ -110,12 +111,12 @@ def test_update_recurrence() -> None: def test_two_consecutive_steps() -> None: """Warm-started carry-over across two consecutive calls matches manually-computed values. - Step 1: A1=diag(2,1,0) -> state w=[0,0,1], return [1/13, 1/13, 11/13] (see test_update_recurrence) - Step 2: A2=eye(3), warm start w=[0,0,1]: + Step 1: A1=diag(1,0.25,0) -> A_norm=diag(4,1,0), state w=[0,0,1], return [1/13,1/13,11/13] + Step 2: A2=eye(3), scale=1 so A_norm=eye(3), warm start w=[0,0,1]: grad = [0.1, 0.1, 1.1]; w - lr*grad = [-1,-1,-10] -> projected to [0.5, 0.5, 0] return ([0.5,0.5,0] + 0.3*[1/3,...]) / 1.3 = [6/13, 6/13, 1/13] """ - A1 = torch.diag(tensor_([2.0, 1.0, 0.0])) + A1 = torch.diag(tensor_([1.0, 0.25, 0.0])) A2 = eye_(3) W = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3) assert_close(W(A1), tensor_([1 / 13, 1 / 13, 11 / 13])) @@ -125,13 +126,13 @@ def test_two_consecutive_steps() -> None: def test_custom_pref_vector() -> None: """A custom preference vector changes the output relative to the uniform default. - With A=diag(2,1,0), pref=[0,0,1], n_iter=1, lr=10, lambda_=0.3: - grad = A @ ([1/3,1/3,1/3] + 0.3*[0,0,1]) = [2/3, 1/3, 0] - w - lr*grad = [-19/3, -3, 1/3] -> projected to [0, 0, 1] + With A=diag(1,0.25,0), A_norm=diag(4,1,0), pref=[0,0,1], n_iter=1, lr=10, lambda_=0.3: + grad = A_norm @ ([1/3,1/3,1/3] + 0.3*[0,0,1]) = [4/3, 1/3, 0] + w - lr*grad = [-13, -3, 1/3] -> projected to [0, 0, 1] return ([0,0,1] + 0.3*[0,0,1]) / 1.3 = [0, 0, 1] This differs from the uniform-pref result [1/13, 1/13, 11/13]. """ - A = torch.diag(tensor_([2.0, 1.0, 0.0])) + A = torch.diag(tensor_([1.0, 0.25, 0.0])) pref = tensor_([0.0, 0.0, 1.0]) W_pref = SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=1, lambda_=0.3, pref_vector=pref) assert_close(W_pref(A), tensor_([0.0, 0.0, 1.0]))