diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bd9921e..764eed3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `STCH` from [Smooth Tchebycheff Scalarization for Multi-Objective + Optimization](https://openreview.net/pdf?id=m4dO5L6eCp), a `Scalarizer` that combines the input + tensor of values into a smooth approximation of their (weighted, shifted) maximum. - Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a simplex-projected gradient step on a cross-batch matrix `G = J_1 @ J_2.T`, computed from two independent mini-batches using `autojac.jac`. - Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task diff --git a/docs/source/docs/scalarization/index.rst b/docs/source/docs/scalarization/index.rst index 8fd87dc8..90a93fad 100644 --- a/docs/source/docs/scalarization/index.rst +++ b/docs/source/docs/scalarization/index.rst @@ -19,4 +19,5 @@ Abstract base class geometric_mean.rst mean.rst random.rst + stch.rst sum.rst diff --git a/docs/source/docs/scalarization/stch.rst b/docs/source/docs/scalarization/stch.rst new file mode 100644 index 00000000..731fe31b --- /dev/null +++ b/docs/source/docs/scalarization/stch.rst @@ -0,0 +1,7 @@ +:hide-toc: + +STCH +==== + +.. autoclass:: torchjd.scalarization.STCH + :members: __call__ diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/scalarization/__init__.py index 337d38ca..98f6c79d 100644 --- a/src/torchjd/scalarization/__init__.py +++ b/src/torchjd/scalarization/__init__.py @@ -24,6 +24,7 @@ from ._mean import Mean from ._random import Random from ._scalarizer_base import Scalarizer +from ._stch import STCH from ._sum import Sum -__all__ = ["Constant", "GeometricMean", "Mean", "Random", "Scalarizer", "Sum"] +__all__ = ["Constant", "GeometricMean", "Mean", "Random", "Scalarizer", "STCH", "Sum"] diff --git a/src/torchjd/scalarization/_stch.py b/src/torchjd/scalarization/_stch.py new file mode 100644 index 00000000..38a5dcea --- /dev/null +++ b/src/torchjd/scalarization/_stch.py @@ -0,0 +1,88 @@ +import torch +from torch import Tensor + +from ._scalarizer_base import Scalarizer + + +class STCH(Scalarizer): + r""" + :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using smooth + Tchebycheff scalarization, as defined in `Smooth Tchebycheff Scalarization for Multi-Objective + Optimization `_. + + It returns + + .. math:: + \mu \log \sum_{i=1}^m \exp\left(\frac{\lambda_i (f_i - z_i^*)}{\mu}\right), + + a smooth approximation of the (non-differentiable) weighted maximum + :math:`\max_i \lambda_i (f_i - z_i^*)` that becomes tighter as ``mu`` decreases. + + Following the paper's notation: + + - :math:`f_i` is the :math:`i`-th input value (the :math:`i`-th objective), + - :math:`m` is the number of objectives (the number of elements of the input), + - :math:`\lambda_i` is its preference weight (the ``weights`` parameter), + - :math:`z_i^*` is the :math:`i`-th component of the ideal point (the ``reference`` parameter), + - :math:`\mu` is the smoothing parameter (the ``mu`` parameter). + + :param mu: The smoothing parameter :math:`\mu`. Must be strictly positive. Smaller values make + the scalarization closer to the maximum. The paper evaluates :math:`\mu \in \{0.01, 0.1, + 0.5, 1\}` and reports that a small :math:`\mu` works reasonably well, while no single value + is best across all problems. + :param weights: The preference vector :math:`\lambda` applied to the values (in the paper, on + the probability simplex). If ``None``, a uniform preference summing to one is used. If + provided, it must have the same shape as the values passed at call time. + :param reference: The ideal point :math:`z^*` subtracted from the values. If ``None``, no shift + is applied. If provided, it must have the same shape as the values passed at call time. + """ + + def __init__( + self, + mu: float, + weights: Tensor | None = None, + reference: Tensor | None = None, + ) -> None: + if mu <= 0.0: + raise ValueError(f"Parameter `mu` should be strictly positive. Found `mu = {mu}`.") + + super().__init__() + self.mu = mu + self.weights = weights + self.reference = reference + + def forward(self, values: Tensor, /) -> Tensor: + if self.weights is not None and self.weights.shape != values.shape: + raise ValueError( + f"Parameter `weights` should have the same shape as `values`. Found " + f"`weights.shape = {tuple(self.weights.shape)}` and `values.shape = " + f"{tuple(values.shape)}`." + ) + if self.reference is not None and self.reference.shape != values.shape: + raise ValueError( + f"Parameter `reference` should have the same shape as `values`. Found " + f"`reference.shape = {tuple(self.reference.shape)}` and `values.shape = " + f"{tuple(values.shape)}`." + ) + + if self.weights is None: + weights = torch.full_like(values, 1.0 / values.numel()) + else: + weights = self.weights + + shifted = values if self.reference is None else values - self.reference + + # Center the weighted values before dividing by mu (Appendix B.1 of the paper). This keeps + # the largest exponent at 0 so the `/ mu` step never overflows for large values and small + # mu. Adding `max_y` back makes it value-preserving: the result and its gradient are + # mathematically identical to `mu * logsumexp(weights * shifted / mu)`. + y = weights * shifted + max_y = y.max() + exponents = (y - max_y) / self.mu + return self.mu * torch.logsumexp(exponents.flatten(), dim=-1) + max_y + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(mu={self.mu}, weights={self.weights!r}, " + f"reference={self.reference!r})" + ) diff --git a/tests/unit/scalarization/test_stch.py b/tests/unit/scalarization/test_stch.py new file mode 100644 index 00000000..659a6a3b --- /dev/null +++ b/tests/unit/scalarization/test_stch.py @@ -0,0 +1,82 @@ +import torch +from pytest import mark, raises +from torch import Tensor +from utils.tensors import tensor_ + +from torchjd.scalarization import STCH + +from ._asserts import ( + assert_grad_flow, + assert_permutation_invariant, + assert_returns_scalar, +) +from ._inputs import all_inputs + + +def test_value_default() -> None: + # Uniform weights, no reference: mu * logsumexp([0, 0]) = log(2). + out = STCH(mu=1.0)(tensor_([0.0, 0.0])) + torch.testing.assert_close(out, torch.log(tensor_(2.0))) + + +def test_value_with_weights() -> None: + # weights = [1, 1] on values [1, 1]: mu * logsumexp([1, 1]) = 1 + log(2). + out = STCH(mu=1.0, weights=tensor_([1.0, 1.0]))(tensor_([1.0, 1.0])) + torch.testing.assert_close(out, 1.0 + torch.log(tensor_(2.0))) + + +def test_value_with_reference() -> None: + # reference shifts values to [0, 0], so the result collapses back to log(2). + out = STCH(mu=1.0, weights=tensor_([1.0, 1.0]), reference=tensor_([1.0, 1.0]))( + tensor_([1.0, 1.0]) + ) + torch.testing.assert_close(out, torch.log(tensor_(2.0))) + + +@mark.parametrize("losses", all_inputs) +def test_expected_structure(losses: Tensor) -> None: + assert_returns_scalar(STCH(mu=1.0), losses) + + +@mark.parametrize("losses", all_inputs) +def test_grad_flow(losses: Tensor) -> None: + assert_grad_flow(STCH(mu=1.0), losses) + + +@mark.parametrize("losses", all_inputs) +def test_permutation_invariant(losses: Tensor) -> None: + # With uniform weights and no reference, STCH is symmetric in its inputs. + assert_permutation_invariant(STCH(mu=1.0), losses) + + +def test_does_not_overflow_for_large_values_and_small_mu() -> None: + # `weights * values / mu` would overflow to inf before logsumexp can stabilize it. The + # value-preserving centering keeps the result finite and equal to the dominant (max) term. + values = tensor_([1e30, 2e30, 3e30]) + out = STCH(mu=1e-10)(values) + assert out.isfinite() + torch.testing.assert_close(out, tensor_(1e30)) # 3e30 weighted by the uniform 1/3. + + +@mark.parametrize("mu", [0.0, -1.0]) +def test_raises_on_non_positive_mu(mu: float) -> None: + with raises(ValueError): + STCH(mu=mu) + + +def test_raises_on_weights_shape_mismatch() -> None: + scalarizer = STCH(mu=1.0, weights=tensor_([1.0, 1.0, 1.0])) + with raises(ValueError): + scalarizer(tensor_([1.0, 1.0])) + + +def test_raises_on_reference_shape_mismatch() -> None: + scalarizer = STCH(mu=1.0, reference=tensor_([1.0, 1.0, 1.0])) + with raises(ValueError): + scalarizer(tensor_([1.0, 1.0])) + + +def test_representations() -> None: + s = STCH(mu=0.5) + assert repr(s) == "STCH(mu=0.5, weights=None, reference=None)" + assert str(s) == "STCH"