From f8bbb4dcb225886e082a881a086708b97390df7f Mon Sep 17 00:00:00 2001 From: ppraneth Date: Tue, 9 Jun 2026 08:11:48 +0530 Subject: [PATCH 1/5] add IMTL-L Signed-off-by: ppraneth --- CHANGELOG.md | 9 ++ docs/source/docs/scalarization/imtl.rst | 7 ++ docs/source/docs/scalarization/index.rst | 1 + src/torchjd/scalarization/__init__.py | 2 + src/torchjd/scalarization/_imtl.py | 83 +++++++++++++++++++ tests/unit/scalarization/test_imtl.py | 100 +++++++++++++++++++++++ 6 files changed, 202 insertions(+) create mode 100644 docs/source/docs/scalarization/imtl.rst create mode 100644 src/torchjd/scalarization/_imtl.py create mode 100644 tests/unit/scalarization/test_imtl.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 39f1095e..fda141c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). This changelog does not include internal changes that do not affect the user. +## [Unreleased] + +### Added + +- Added `IMTL` (the loss-balancing variant, IMTL-L, of Impartial Multi-Task Learning) from [Towards + Impartial Multi-Task Learning](https://openreview.net/pdf?id=IMPnRXEWpvr) (ICLR 2021), a stateful + `Scalarizer` that learns a per-task scale `s_i` and combines the values as + `Σ (exp(s_i) · L_i − s_i)`. + ## [0.13.0] - 2026-06-07 ### Added diff --git a/docs/source/docs/scalarization/imtl.rst b/docs/source/docs/scalarization/imtl.rst new file mode 100644 index 00000000..e0846dd3 --- /dev/null +++ b/docs/source/docs/scalarization/imtl.rst @@ -0,0 +1,7 @@ +:hide-toc: + +IMTL +==== + +.. autoclass:: torchjd.scalarization.IMTL + :members: __call__, reset diff --git a/docs/source/docs/scalarization/index.rst b/docs/source/docs/scalarization/index.rst index e1d358af..75b03a04 100644 --- a/docs/source/docs/scalarization/index.rst +++ b/docs/source/docs/scalarization/index.rst @@ -16,6 +16,7 @@ Abstract base class constant.rst geometric_mean.rst + imtl.rst mean.rst random.rst stch.rst diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/scalarization/__init__.py index a70d07bf..542112e9 100644 --- a/src/torchjd/scalarization/__init__.py +++ b/src/torchjd/scalarization/__init__.py @@ -21,6 +21,7 @@ from ._constant import Constant from ._geometric_mean import GeometricMean +from ._imtl import IMTL from ._mean import Mean from ._random import Random from ._scalarizer_base import Scalarizer @@ -31,6 +32,7 @@ __all__ = [ "Constant", "GeometricMean", + "IMTL", "Mean", "Random", "Scalarizer", diff --git a/src/torchjd/scalarization/_imtl.py b/src/torchjd/scalarization/_imtl.py new file mode 100644 index 00000000..3b808a8d --- /dev/null +++ b/src/torchjd/scalarization/_imtl.py @@ -0,0 +1,83 @@ +from collections.abc import Sequence + +import torch +from torch import Tensor, nn + +from torchjd._mixins import Stateful + +from ._scalarizer_base import Scalarizer + + +class IMTL(Scalarizer, Stateful): + r""" + :class:`~torchjd.Stateful` + :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using learned + per-task scales. ``IMTL`` is the loss-balancing variant (called IMTL-L in the paper) of Impartial + Multi-Task Learning, proposed in `Towards Impartial Multi-Task Learning + `_. + + Each value :math:`L_i` is assigned a learnable scale parameter :math:`s_i`, and the values are + combined as + + .. math:: + \sum_i \left( e^{s_i} L_i - s_i \right) + + where: + + - :math:`L_i` is the :math:`i`-th value (typically the loss of task :math:`i`); + - :math:`s_i` is the learnable scale parameter of task :math:`i`. + + The factor :math:`e^{s_i}` rescales each loss so that the scaled losses stay at a comparable + magnitude across tasks, while the :math:`- s_i` term is a regularizer that prevents the trivial + solution :math:`s_i \to -\infty`. The :math:`s_i` are stored as an ``nn.Parameter``, so the + parameters of this scalarizer must be passed to the optimizer to be learned jointly with the + model. Unlike :class:`~torchjd.scalarization.UW`, IMTL-L makes no distribution assumption and + applies to any kind of loss. The complementary gradient-balancing variant (IMTL-G) is provided + as the :class:`~torchjd.aggregation.IMTLG` aggregator. + + :param shape: The shape of the values to scalarize, used to create one scale per value. An + ``int`` ``n`` is interpreted as the shape ``(n,)``. + + The following example shows how to train a model with Impartial Multi-Task Learning (loss + balance), as described in the paper. + + >>> import torch + >>> from torch.nn import Linear + >>> + >>> from torchjd.scalarization import IMTL + >>> + >>> model = Linear(3, 2) + >>> scalarizer = IMTL(2) # Move to the right device with e.g. IMTL(2).to(device="cuda") + >>> optimizer = torch.optim.SGD([*model.parameters(), *scalarizer.parameters()], lr=0.1) + >>> + >>> features = torch.randn(8, 3) + >>> # Compute some dummy losses just for the sake of the example + >>> losses = model(features).pow(2).mean(dim=0) # One loss per output dimension. + >>> loss = scalarizer(losses) + >>> loss.backward() + >>> optimizer.step() + + .. note:: + The scales are initialized to ``0``, so at the start of training the scalarization reduces to + the plain sum of the values (since :math:`e^0 = 1`). Following the paper, IMTL-L is designed + to balance positive losses. + """ + + def __init__(self, shape: int | Sequence[int]) -> None: + super().__init__() + self.log_scale = nn.Parameter(torch.zeros(shape)) + + def forward(self, values: Tensor, /) -> Tensor: + if values.shape != self.log_scale.shape: + raise ValueError( + f"Parameter `values` should have shape {tuple(self.log_scale.shape)} (matching the " + f"shape of the scales). Found `values.shape = {tuple(values.shape)}`.", + ) + return (torch.exp(self.log_scale) * values - self.log_scale).sum() + + def reset(self) -> None: + with torch.no_grad(): + self.log_scale.zero_() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(shape={tuple(self.log_scale.shape)})" diff --git a/tests/unit/scalarization/test_imtl.py b/tests/unit/scalarization/test_imtl.py new file mode 100644 index 00000000..ad535584 --- /dev/null +++ b/tests/unit/scalarization/test_imtl.py @@ -0,0 +1,100 @@ +from contextlib import nullcontext as does_not_raise + +import torch +from pytest import mark, raises +from settings import DEVICE, DTYPE +from torch import Tensor +from utils.contexts import ExceptionContext +from utils.tensors import ones_, tensor_, zeros_ + +from torchjd.scalarization import IMTL + +from ._asserts import assert_grad_flow, assert_returns_scalar +from ._inputs import all_inputs + + +def _imtl(shape: int | tuple[int, ...]) -> IMTL: + """Builds an `IMTL` whose scales live on the test device and dtype.""" + return IMTL(shape).to(device=DEVICE, dtype=DTYPE) + + +def test_value() -> None: + # With scales initialized to 0, exp(0)=1 and -0=0, so the result is sum(values). + values = tensor_([1.0, 2.0, 4.0]) + torch.testing.assert_close(_imtl((3,))(values), tensor_(7.0)) + + +def test_int_shape_matches_tuple_shape() -> None: + values = tensor_([1.0, 2.0, 4.0]) + assert IMTL(3).log_scale.shape == (3,) + torch.testing.assert_close(_imtl(3)(values), _imtl((3,))(values)) + + +@mark.parametrize("values", all_inputs) +def test_expected_structure(values: Tensor) -> None: + assert_returns_scalar(_imtl(tuple(values.shape)), values) + + +@mark.parametrize("values", all_inputs) +def test_grad_flow(values: Tensor) -> None: + assert_grad_flow(_imtl(tuple(values.shape)), values) + + +@mark.parametrize("values", all_inputs) +def test_grad_flows_to_log_scale(values: Tensor) -> None: + scalarizer = _imtl(tuple(values.shape)) + scalarizer(values).backward() + assert scalarizer.log_scale.grad is not None + assert scalarizer.log_scale.grad.isfinite().all() + + +@mark.parametrize( + ["param_shape", "values_shape", "expectation"], + [ + ((5,), (5,), does_not_raise()), + ((3, 4), (3, 4), does_not_raise()), + ((), (), does_not_raise()), + ((5,), (4,), raises(ValueError)), + ((5,), (5, 1), raises(ValueError)), + ((3, 4), (4, 3), raises(ValueError)), + ], +) +def test_shape_check( + param_shape: tuple[int, ...], + values_shape: tuple[int, ...], + expectation: ExceptionContext, +) -> None: + scalarizer = _imtl(param_shape) + values = ones_(values_shape) + with expectation: + _ = scalarizer(values) + + +def test_reset_restores_initial_log_scale() -> None: + scalarizer = _imtl((3,)) + with torch.no_grad(): + scalarizer.log_scale.add_(1.0) + scalarizer.reset() + torch.testing.assert_close(scalarizer.log_scale.detach(), zeros_((3,))) + + +def test_does_not_raise_on_negative_input() -> None: + # IMTL-L is designed for positive losses but does not enforce a positivity precondition. + values = tensor_([-1.0, -2.0, 3.0]) + assert_returns_scalar(_imtl((3,)), values) + + +def test_is_trainable() -> None: + scalarizer = _imtl((2,)) + optimizer = torch.optim.SGD(scalarizer.parameters(), lr=0.1) + values = tensor_([2.0, 5.0]) + optimizer.zero_grad() + scalarizer(values).backward() + optimizer.step() + assert not torch.equal(scalarizer.log_scale.detach(), zeros_((2,))) + + +def test_representations() -> None: + assert repr(IMTL(3)) == "IMTL(shape=(3,))" + assert repr(IMTL((2, 3))) == "IMTL(shape=(2, 3))" + assert str(IMTL(3)) == "IMTL" From 43f4794cfb1378be02e0f36c70895db475c93406 Mon Sep 17 00:00:00 2001 From: ppraneth Date: Tue, 9 Jun 2026 08:31:55 +0530 Subject: [PATCH 2/5] add test cases Signed-off-by: ppraneth --- src/torchjd/scalarization/_imtl.py | 14 +++++++++++--- tests/unit/scalarization/test_imtl.py | 15 ++++++++++++++- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/torchjd/scalarization/_imtl.py b/src/torchjd/scalarization/_imtl.py index 3b808a8d..a83d886f 100644 --- a/src/torchjd/scalarization/_imtl.py +++ b/src/torchjd/scalarization/_imtl.py @@ -31,9 +31,17 @@ class IMTL(Scalarizer, Stateful): magnitude across tasks, while the :math:`- s_i` term is a regularizer that prevents the trivial solution :math:`s_i \to -\infty`. The :math:`s_i` are stored as an ``nn.Parameter``, so the parameters of this scalarizer must be passed to the optimizer to be learned jointly with the - model. Unlike :class:`~torchjd.scalarization.UW`, IMTL-L makes no distribution assumption and - applies to any kind of loss. The complementary gradient-balancing variant (IMTL-G) is provided - as the :class:`~torchjd.aggregation.IMTLG` aggregator. + model. + + Although it is derived without any distribution assumption (unlike + :class:`~torchjd.scalarization.UW`, which is derived from Gaussian/Laplace likelihoods), IMTL-L + is in fact almost equivalent to :class:`~torchjd.scalarization.UW`: this scalarization equals + :math:`2\,\mathrm{UW}` evaluated at the negated parameter, so the two differ only by a constant + factor of two and the sign convention of the learned parameter, and share the same per-task + weighting and the same optima. + + The complementary gradient-balancing variant (IMTL-G) is provided as the + :class:`~torchjd.aggregation.IMTLG` aggregator. :param shape: The shape of the values to scalarize, used to create one scale per value. An ``int`` ``n`` is interpreted as the shape ``(n,)``. diff --git a/tests/unit/scalarization/test_imtl.py b/tests/unit/scalarization/test_imtl.py index ad535584..47e7c101 100644 --- a/tests/unit/scalarization/test_imtl.py +++ b/tests/unit/scalarization/test_imtl.py @@ -7,7 +7,7 @@ from utils.contexts import ExceptionContext from utils.tensors import ones_, tensor_, zeros_ -from torchjd.scalarization import IMTL +from torchjd.scalarization import IMTL, UW from ._asserts import assert_grad_flow, assert_returns_scalar from ._inputs import all_inputs @@ -94,6 +94,19 @@ def test_is_trainable() -> None: assert not torch.equal(scalarizer.log_scale.detach(), zeros_((2,))) +def test_equivalent_to_uw_up_to_factor_and_sign() -> None: + # Locks the documented relationship: IMTL(s) == 2 * UW(-s), i.e. the two scalarizations are + # equal up to a constant factor of 2 and the sign of the learned parameter. + values = tensor_([0.5, 2.0, 4.0]) + imtl = _imtl((3,)) + uw = UW((3,)).to(device=DEVICE, dtype=DTYPE) + with torch.no_grad(): + s = tensor_([0.3, -0.7, 1.2]) + imtl.log_scale.copy_(s) + uw.log_var.copy_(-s) + torch.testing.assert_close(imtl(values), 2.0 * uw(values)) + + def test_representations() -> None: assert repr(IMTL(3)) == "IMTL(shape=(3,))" assert repr(IMTL((2, 3))) == "IMTL(shape=(2, 3))" From 5086575a778a47ed99e0d25f093bafd924d7b76b Mon Sep 17 00:00:00 2001 From: Praneth Paruchuri Date: Tue, 9 Jun 2026 18:52:55 +0530 Subject: [PATCH 3/5] Update CHANGELOG.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com> --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d7fa233e..2e259ea0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ changelog does not include internal changes that do not affect the user. ### Added -- Added `IMTL` (the loss-balancing variant, IMTL-L, of Impartial Multi-Task Learning) from [Towards +- Added `IMTL-L` (the loss-balancing variant of Impartial Multi-Task Learning) from [Towards Impartial Multi-Task Learning](https://openreview.net/pdf?id=IMPnRXEWpvr) (ICLR 2021), a stateful `Scalarizer` that learns a per-task scale `s_i` and combines the values as `Σ (exp(s_i) · L_i − s_i)`. From c38bf4e3189e9804532b48a1301aa8c2a1f093a2 Mon Sep 17 00:00:00 2001 From: Praneth Paruchuri Date: Tue, 9 Jun 2026 18:53:07 +0530 Subject: [PATCH 4/5] Update src/torchjd/scalarization/_imtl.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com> --- src/torchjd/scalarization/_imtl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/scalarization/_imtl.py b/src/torchjd/scalarization/_imtl.py index a83d886f..690ead9e 100644 --- a/src/torchjd/scalarization/_imtl.py +++ b/src/torchjd/scalarization/_imtl.py @@ -12,7 +12,7 @@ class IMTL(Scalarizer, Stateful): r""" :class:`~torchjd.Stateful` :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using learned - per-task scales. ``IMTL`` is the loss-balancing variant (called IMTL-L in the paper) of Impartial + per-task scales. ``IMTL-L`` is the loss-balancing variant of Impartial Multi-Task Learning, proposed in `Towards Impartial Multi-Task Learning `_. From 9ce44564563be1d380bf18db05c3b5762d76ab46 Mon Sep 17 00:00:00 2001 From: ppraneth Date: Tue, 9 Jun 2026 19:05:22 +0530 Subject: [PATCH 5/5] minor changes Signed-off-by: ppraneth --- docs/source/docs/scalarization/imtl.rst | 7 ---- docs/source/docs/scalarization/imtl_l.rst | 7 ++++ docs/source/docs/scalarization/index.rst | 2 +- src/torchjd/scalarization/__init__.py | 4 +- .../scalarization/{_imtl.py => _imtl_l.py} | 6 +-- .../{test_imtl.py => test_imtl_l.py} | 42 +++++++++---------- 6 files changed, 34 insertions(+), 34 deletions(-) delete mode 100644 docs/source/docs/scalarization/imtl.rst create mode 100644 docs/source/docs/scalarization/imtl_l.rst rename src/torchjd/scalarization/{_imtl.py => _imtl_l.py} (95%) rename tests/unit/scalarization/{test_imtl.py => test_imtl_l.py} (70%) diff --git a/docs/source/docs/scalarization/imtl.rst b/docs/source/docs/scalarization/imtl.rst deleted file mode 100644 index e0846dd3..00000000 --- a/docs/source/docs/scalarization/imtl.rst +++ /dev/null @@ -1,7 +0,0 @@ -:hide-toc: - -IMTL -==== - -.. autoclass:: torchjd.scalarization.IMTL - :members: __call__, reset diff --git a/docs/source/docs/scalarization/imtl_l.rst b/docs/source/docs/scalarization/imtl_l.rst new file mode 100644 index 00000000..c95dca27 --- /dev/null +++ b/docs/source/docs/scalarization/imtl_l.rst @@ -0,0 +1,7 @@ +:hide-toc: + +IMTL-L +====== + +.. autoclass:: torchjd.scalarization.IMTLL + :members: __call__, reset diff --git a/docs/source/docs/scalarization/index.rst b/docs/source/docs/scalarization/index.rst index 75b03a04..f2d42be7 100644 --- a/docs/source/docs/scalarization/index.rst +++ b/docs/source/docs/scalarization/index.rst @@ -16,7 +16,7 @@ Abstract base class constant.rst geometric_mean.rst - imtl.rst + imtl_l.rst mean.rst random.rst stch.rst diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/scalarization/__init__.py index 542112e9..6b37d281 100644 --- a/src/torchjd/scalarization/__init__.py +++ b/src/torchjd/scalarization/__init__.py @@ -21,7 +21,7 @@ from ._constant import Constant from ._geometric_mean import GeometricMean -from ._imtl import IMTL +from ._imtl_l import IMTLL from ._mean import Mean from ._random import Random from ._scalarizer_base import Scalarizer @@ -32,7 +32,7 @@ __all__ = [ "Constant", "GeometricMean", - "IMTL", + "IMTLL", "Mean", "Random", "Scalarizer", diff --git a/src/torchjd/scalarization/_imtl.py b/src/torchjd/scalarization/_imtl_l.py similarity index 95% rename from src/torchjd/scalarization/_imtl.py rename to src/torchjd/scalarization/_imtl_l.py index 690ead9e..91445f77 100644 --- a/src/torchjd/scalarization/_imtl.py +++ b/src/torchjd/scalarization/_imtl_l.py @@ -8,7 +8,7 @@ from ._scalarizer_base import Scalarizer -class IMTL(Scalarizer, Stateful): +class IMTLL(Scalarizer, Stateful): r""" :class:`~torchjd.Stateful` :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using learned @@ -52,10 +52,10 @@ class IMTL(Scalarizer, Stateful): >>> import torch >>> from torch.nn import Linear >>> - >>> from torchjd.scalarization import IMTL + >>> from torchjd.scalarization import IMTLL >>> >>> model = Linear(3, 2) - >>> scalarizer = IMTL(2) # Move to the right device with e.g. IMTL(2).to(device="cuda") + >>> scalarizer = IMTLL(2) # Move to the right device with e.g. IMTLL(2).to(device="cuda") >>> optimizer = torch.optim.SGD([*model.parameters(), *scalarizer.parameters()], lr=0.1) >>> >>> features = torch.randn(8, 3) diff --git a/tests/unit/scalarization/test_imtl.py b/tests/unit/scalarization/test_imtl_l.py similarity index 70% rename from tests/unit/scalarization/test_imtl.py rename to tests/unit/scalarization/test_imtl_l.py index 47e7c101..cdb6790d 100644 --- a/tests/unit/scalarization/test_imtl.py +++ b/tests/unit/scalarization/test_imtl_l.py @@ -7,42 +7,42 @@ from utils.contexts import ExceptionContext from utils.tensors import ones_, tensor_, zeros_ -from torchjd.scalarization import IMTL, UW +from torchjd.scalarization import IMTLL, UW from ._asserts import assert_grad_flow, assert_returns_scalar from ._inputs import all_inputs -def _imtl(shape: int | tuple[int, ...]) -> IMTL: - """Builds an `IMTL` whose scales live on the test device and dtype.""" - return IMTL(shape).to(device=DEVICE, dtype=DTYPE) +def _imtl_l(shape: int | tuple[int, ...]) -> IMTLL: + """Builds an `IMTLL` whose scales live on the test device and dtype.""" + return IMTLL(shape).to(device=DEVICE, dtype=DTYPE) def test_value() -> None: # With scales initialized to 0, exp(0)=1 and -0=0, so the result is sum(values). values = tensor_([1.0, 2.0, 4.0]) - torch.testing.assert_close(_imtl((3,))(values), tensor_(7.0)) + torch.testing.assert_close(_imtl_l((3,))(values), tensor_(7.0)) def test_int_shape_matches_tuple_shape() -> None: values = tensor_([1.0, 2.0, 4.0]) - assert IMTL(3).log_scale.shape == (3,) - torch.testing.assert_close(_imtl(3)(values), _imtl((3,))(values)) + assert IMTLL(3).log_scale.shape == (3,) + torch.testing.assert_close(_imtl_l(3)(values), _imtl_l((3,))(values)) @mark.parametrize("values", all_inputs) def test_expected_structure(values: Tensor) -> None: - assert_returns_scalar(_imtl(tuple(values.shape)), values) + assert_returns_scalar(_imtl_l(tuple(values.shape)), values) @mark.parametrize("values", all_inputs) def test_grad_flow(values: Tensor) -> None: - assert_grad_flow(_imtl(tuple(values.shape)), values) + assert_grad_flow(_imtl_l(tuple(values.shape)), values) @mark.parametrize("values", all_inputs) def test_grad_flows_to_log_scale(values: Tensor) -> None: - scalarizer = _imtl(tuple(values.shape)) + scalarizer = _imtl_l(tuple(values.shape)) scalarizer(values).backward() assert scalarizer.log_scale.grad is not None assert scalarizer.log_scale.grad.isfinite().all() @@ -64,14 +64,14 @@ def test_shape_check( values_shape: tuple[int, ...], expectation: ExceptionContext, ) -> None: - scalarizer = _imtl(param_shape) + scalarizer = _imtl_l(param_shape) values = ones_(values_shape) with expectation: _ = scalarizer(values) def test_reset_restores_initial_log_scale() -> None: - scalarizer = _imtl((3,)) + scalarizer = _imtl_l((3,)) with torch.no_grad(): scalarizer.log_scale.add_(1.0) scalarizer.reset() @@ -81,11 +81,11 @@ def test_reset_restores_initial_log_scale() -> None: def test_does_not_raise_on_negative_input() -> None: # IMTL-L is designed for positive losses but does not enforce a positivity precondition. values = tensor_([-1.0, -2.0, 3.0]) - assert_returns_scalar(_imtl((3,)), values) + assert_returns_scalar(_imtl_l((3,)), values) def test_is_trainable() -> None: - scalarizer = _imtl((2,)) + scalarizer = _imtl_l((2,)) optimizer = torch.optim.SGD(scalarizer.parameters(), lr=0.1) values = tensor_([2.0, 5.0]) optimizer.zero_grad() @@ -95,19 +95,19 @@ def test_is_trainable() -> None: def test_equivalent_to_uw_up_to_factor_and_sign() -> None: - # Locks the documented relationship: IMTL(s) == 2 * UW(-s), i.e. the two scalarizations are + # Locks the documented relationship: IMTL-L(s) == 2 * UW(-s), i.e. the two scalarizations are # equal up to a constant factor of 2 and the sign of the learned parameter. values = tensor_([0.5, 2.0, 4.0]) - imtl = _imtl((3,)) + imtl_l = _imtl_l((3,)) uw = UW((3,)).to(device=DEVICE, dtype=DTYPE) with torch.no_grad(): s = tensor_([0.3, -0.7, 1.2]) - imtl.log_scale.copy_(s) + imtl_l.log_scale.copy_(s) uw.log_var.copy_(-s) - torch.testing.assert_close(imtl(values), 2.0 * uw(values)) + torch.testing.assert_close(imtl_l(values), 2.0 * uw(values)) def test_representations() -> None: - assert repr(IMTL(3)) == "IMTL(shape=(3,))" - assert repr(IMTL((2, 3))) == "IMTL(shape=(2, 3))" - assert str(IMTL(3)) == "IMTL" + assert repr(IMTLL(3)) == "IMTLL(shape=(3,))" + assert repr(IMTLL((2, 3))) == "IMTLL(shape=(2, 3))" + assert str(IMTLL(3)) == "IMTLL"