From e8647ed7ab77acc5be18375b29d9424955e183f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 20 May 2026 21:36:43 +0200 Subject: [PATCH 1/2] fix(aggregation): Fix __call__ docs by placing _NonDifferentiable last in MRO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before this change, every non-differentiable aggregator/weighting had _NonDifferentiable listed first in its base-class tuple, e.g. `class PCGradWeighting(_NonDifferentiable, _GramianWeighting)`. Because Python's MRO resolves `__call__` to the first class that defines it, Sphinx documented the method with _NonDifferentiable.__call__'s generic `(*args, **kwargs)` signature instead of the more specific one from _GramianWeighting or Aggregator. The fix is to list _NonDifferentiable after the primary base class. The cooperative super().__call__() chain then becomes: _GramianWeighting.__call__(gramian) → Weighting.__call__(gramian) → _NonDifferentiable.__call__(*args) [applies no_grad] → nn.Module.__call__(...) The no_grad wrapping is fully preserved because every class in the chain calls super().__call__(), so _NonDifferentiable is still reached — just later in the chain. The old warning in _NonDifferentiable said it must appear "before any nn.Module base class", which was imprecise: what actually matters is that it appears before nn.Module *itself* in the resolved MRO, which C3 linearization guarantees as long as super() chains are cooperative. All 2982 unit tests pass. Generated docs now show the correct parameter names (gramian / matrix) for every affected __call__. Co-Authored-By: Claude Sonnet 4.6 --- src/torchjd/aggregation/_cagrad.py | 4 ++-- src/torchjd/aggregation/_config.py | 2 +- src/torchjd/aggregation/_dualproj.py | 4 ++-- src/torchjd/aggregation/_fairgrad.py | 4 ++-- src/torchjd/aggregation/_graddrop.py | 2 +- src/torchjd/aggregation/_gradvac.py | 4 ++-- src/torchjd/aggregation/_imtl_g.py | 4 ++-- src/torchjd/aggregation/_mixins.py | 10 +++++++--- src/torchjd/aggregation/_nash_mtl.py | 4 ++-- src/torchjd/aggregation/_pcgrad.py | 4 ++-- src/torchjd/aggregation/_upgrad.py | 4 ++-- 11 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 4fbd79d0..b2da3efe 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -18,7 +18,7 @@ # Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph. -class CAGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting): +class CAGradWeighting(_WithOptionalDeps, _GramianWeighting, _NonDifferentiable): _REQUIRED_DEPS = ["numpy", "cvxpy", "clarabel"] _INSTALL_HINT = 'Install them with: pip install "torchjd[cagrad]"' """ @@ -94,7 +94,7 @@ def norm_eps(self, value: float) -> None: self._norm_eps = value -class CAGrad(_NonDifferentiable, GramianWeightedAggregator): +class CAGrad(GramianWeightedAggregator, _NonDifferentiable): """ :class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of `Conflict-Averse Gradient Descent for Multi-task Learning diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index 54e2c3dc..7ca654b7 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -14,7 +14,7 @@ # Non-differentiable: the pseudoinverse and the normalization are not differentiable in this context. -class ConFIG(_NonDifferentiable, Aggregator): +class ConFIG(Aggregator, _NonDifferentiable): """ :class:`~torchjd.aggregation.Aggregator` as defined in Equation 2 of `ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index e839d6e8..15b6aa87 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -11,7 +11,7 @@ # Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph. -class DualProjWeighting(_NonDifferentiable, _GramianWeighting): +class DualProjWeighting(_GramianWeighting, _NonDifferentiable): r""" :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.DualProj`. @@ -53,7 +53,7 @@ def projector(self, value: DualConeProjector | None) -> None: self._projector = projector_or_default(value) -class DualProj(_NonDifferentiable, GramianWeightedAggregator): +class DualProj(GramianWeightedAggregator, _NonDifferentiable): r""" :class:`~torchjd.aggregation.GramianWeightedAggregator` that averages the rows of the input matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds diff --git a/src/torchjd/aggregation/_fairgrad.py b/src/torchjd/aggregation/_fairgrad.py index 4b2cc261..7bdb5ab3 100644 --- a/src/torchjd/aggregation/_fairgrad.py +++ b/src/torchjd/aggregation/_fairgrad.py @@ -21,7 +21,7 @@ # Non-differentiable: the scipy solver operates on numpy arrays, breaking the autograd graph. -class FairGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting): +class FairGradWeighting(_WithOptionalDeps, _GramianWeighting, _NonDifferentiable): r""" :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.FairGrad`, as defined in Equation 4 of `Fair Resource @@ -78,7 +78,7 @@ def alpha(self, value: float) -> None: self._alpha = value -class FairGrad(_NonDifferentiable, GramianWeightedAggregator): +class FairGrad(GramianWeightedAggregator, _NonDifferentiable): r""" :class:`~torchjd.aggregation.GramianWeightedAggregator` using the step decision of Algorithm 1 of `Fair Resource Allocation in Multi-Task Learning diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index 31590ebf..ec91a8cf 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -14,7 +14,7 @@ def _identity(P: Tensor) -> Tensor: # Non-differentiable: the sign-based random masking is not differentiable. -class GradDrop(_NonDifferentiable, Aggregator): +class GradDrop(Aggregator, _NonDifferentiable): """ :class:`~torchjd.aggregation.Aggregator` that applies the gradient combination steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign: diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 61469407..8a031e64 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -13,7 +13,7 @@ # Non-differentiable: weights are modified in-place during the gradient correction loop. -class GradVacWeighting(_NonDifferentiable, Stateful, _GramianWeighting): +class GradVacWeighting(_GramianWeighting, Stateful, _NonDifferentiable): r""" :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] @@ -128,7 +128,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None: self._state_key = key -class GradVac(_NonDifferentiable, Stateful, GramianWeightedAggregator): +class GradVac(GramianWeightedAggregator, Stateful, _NonDifferentiable): r""" :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 47504e83..8cbf65f5 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -9,7 +9,7 @@ # Non-differentiable: differentiating through pinv(gramian) would give incorrect gradients. -class IMTLGWeighting(_NonDifferentiable, _GramianWeighting): +class IMTLGWeighting(_GramianWeighting, _NonDifferentiable): """ :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.IMTLG`. @@ -25,7 +25,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: return weights -class IMTLG(_NonDifferentiable, GramianWeightedAggregator): +class IMTLG(GramianWeightedAggregator, _NonDifferentiable): """ :class:`~torchjd.aggregation.GramianWeightedAggregator` generalizing the method described in `Towards Impartial Multi-task Learning `_. diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index 29bf5592..1c693061 100644 --- a/src/torchjd/aggregation/_mixins.py +++ b/src/torchjd/aggregation/_mixins.py @@ -19,9 +19,13 @@ class _NonDifferentiable(nn.Module): the call in :func:`torch.no_grad`. .. warning:: - This mixin must appear **before** any :class:`torch.nn.Module` base class in the inheritance - list. Placing it after will silently have no effect, because :meth:`__call__` would be - resolved to :class:`torch.nn.Module` before reaching this mixin. + This mixin must appear **after** the primary base class (e.g. + :class:`~torchjd.aggregation.Aggregator`, + :class:`~torchjd.aggregation._weighting_bases._GramianWeighting`) in the inheritance list, + so that the primary class's :meth:`__call__` is resolved first and its ``super().__call__`` + call chains through this mixin before reaching :class:`torch.nn.Module`. Placing this mixin + *before* the primary base will cause it to shadow the primary class's :meth:`__call__` + signature in generated documentation. """ def __call__(self, *args: Any, **kwargs: Any) -> Any: diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 1932b079..879b4221 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -21,7 +21,7 @@ # Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph. -class _NashMTLWeighting(_WithOptionalDeps, _NonDifferentiable, Stateful, _MatrixWeighting): +class _NashMTLWeighting(_WithOptionalDeps, _MatrixWeighting, Stateful, _NonDifferentiable): _REQUIRED_DEPS = ["numpy", "cvxpy", "ecos"] _INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"' """ @@ -204,7 +204,7 @@ def reset(self) -> None: self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32) -class NashMTL(_NonDifferentiable, Stateful, WeightedAggregator): +class NashMTL(WeightedAggregator, Stateful, _NonDifferentiable): """ :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index ffce10d3..8a20b496 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -11,7 +11,7 @@ # Non-differentiable: weights are modified in-place during the gradient projection loop. -class PCGradWeighting(_NonDifferentiable, _GramianWeighting): +class PCGradWeighting(_GramianWeighting, _NonDifferentiable): """ :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.PCGrad`. @@ -47,7 +47,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: return weights.to(device) -class PCGrad(_NonDifferentiable, GramianWeightedAggregator): +class PCGrad(GramianWeightedAggregator, _NonDifferentiable): """ :class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of `Gradient Surgery for Multi-Task Learning `_. diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 691232eb..e85b15f4 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -12,7 +12,7 @@ # Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph. -class UPGradWeighting(_NonDifferentiable, _GramianWeighting): +class UPGradWeighting(_GramianWeighting, _NonDifferentiable): r""" :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.UPGrad`. @@ -54,7 +54,7 @@ def projector(self, value: DualConeProjector | None) -> None: self._projector = projector_or_default(value) -class UPGrad(_NonDifferentiable, GramianWeightedAggregator): +class UPGrad(GramianWeightedAggregator, _NonDifferentiable): r""" :class:`~torchjd.aggregation.GramianWeightedAggregator` that projects each row of the input matrix onto the dual cone of all rows of this matrix, and that combines the result, as proposed From 2c563fd5050ca6d0cacff16378b0da042df0fbb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 20 May 2026 22:01:10 +0200 Subject: [PATCH 2/2] Simplify even more the docstring of _NonDifferentiable --- src/torchjd/aggregation/_mixins.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index 1c693061..b906140b 100644 --- a/src/torchjd/aggregation/_mixins.py +++ b/src/torchjd/aggregation/_mixins.py @@ -19,13 +19,8 @@ class _NonDifferentiable(nn.Module): the call in :func:`torch.no_grad`. .. warning:: - This mixin must appear **after** the primary base class (e.g. - :class:`~torchjd.aggregation.Aggregator`, - :class:`~torchjd.aggregation._weighting_bases._GramianWeighting`) in the inheritance list, - so that the primary class's :meth:`__call__` is resolved first and its ``super().__call__`` - call chains through this mixin before reaching :class:`torch.nn.Module`. Placing this mixin - *before* the primary base will cause it to shadow the primary class's :meth:`__call__` - signature in generated documentation. + Placing this mixin *before* the primary base will cause it to shadow the primary class's + :meth:`__call__` signature in generated documentation. """ def __call__(self, *args: Any, **kwargs: Any) -> Any: