diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py
index 4fbd79d0..b2da3efe 100644
--- a/src/torchjd/aggregation/_cagrad.py
+++ b/src/torchjd/aggregation/_cagrad.py
@@ -18,7 +18,7 @@
# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
-class CAGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting):
+class CAGradWeighting(_WithOptionalDeps, _GramianWeighting, _NonDifferentiable):
_REQUIRED_DEPS = ["numpy", "cvxpy", "clarabel"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[cagrad]"'
"""
@@ -94,7 +94,7 @@ def norm_eps(self, value: float) -> None:
self._norm_eps = value
-class CAGrad(_NonDifferentiable, GramianWeightedAggregator):
+class CAGrad(GramianWeightedAggregator, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
`Conflict-Averse Gradient Descent for Multi-task Learning
diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py
index 54e2c3dc..7ca654b7 100644
--- a/src/torchjd/aggregation/_config.py
+++ b/src/torchjd/aggregation/_config.py
@@ -14,7 +14,7 @@
# Non-differentiable: the pseudoinverse and the normalization are not differentiable in this context.
-class ConFIG(_NonDifferentiable, Aggregator):
+class ConFIG(Aggregator, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.Aggregator` as defined in Equation 2 of `ConFIG:
Towards Conflict-free Training of Physics Informed Neural Networks
diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py
index e839d6e8..15b6aa87 100644
--- a/src/torchjd/aggregation/_dualproj.py
+++ b/src/torchjd/aggregation/_dualproj.py
@@ -11,7 +11,7 @@
# Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph.
-class DualProjWeighting(_NonDifferentiable, _GramianWeighting):
+class DualProjWeighting(_GramianWeighting, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.DualProj`.
@@ -53,7 +53,7 @@ def projector(self, value: DualConeProjector | None) -> None:
self._projector = projector_or_default(value)
-class DualProj(_NonDifferentiable, GramianWeightedAggregator):
+class DualProj(GramianWeightedAggregator, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` that averages the rows of the input
matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds
diff --git a/src/torchjd/aggregation/_fairgrad.py b/src/torchjd/aggregation/_fairgrad.py
index 4b2cc261..7bdb5ab3 100644
--- a/src/torchjd/aggregation/_fairgrad.py
+++ b/src/torchjd/aggregation/_fairgrad.py
@@ -21,7 +21,7 @@
# Non-differentiable: the scipy solver operates on numpy arrays, breaking the autograd graph.
-class FairGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting):
+class FairGradWeighting(_WithOptionalDeps, _GramianWeighting, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the
weights of :class:`~torchjd.aggregation.FairGrad`, as defined in Equation 4 of `Fair Resource
@@ -78,7 +78,7 @@ def alpha(self, value: float) -> None:
self._alpha = value
-class FairGrad(_NonDifferentiable, GramianWeightedAggregator):
+class FairGrad(GramianWeightedAggregator, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` using the step decision of Algorithm 1
of `Fair Resource Allocation in Multi-Task Learning
diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py
index 31590ebf..ec91a8cf 100644
--- a/src/torchjd/aggregation/_graddrop.py
+++ b/src/torchjd/aggregation/_graddrop.py
@@ -14,7 +14,7 @@ def _identity(P: Tensor) -> Tensor:
# Non-differentiable: the sign-based random masking is not differentiable.
-class GradDrop(_NonDifferentiable, Aggregator):
+class GradDrop(Aggregator, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.Aggregator` that applies the gradient combination
steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign:
diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py
index 61469407..8a031e64 100644
--- a/src/torchjd/aggregation/_gradvac.py
+++ b/src/torchjd/aggregation/_gradvac.py
@@ -13,7 +13,7 @@
# Non-differentiable: weights are modified in-place during the gradient correction loop.
-class GradVacWeighting(_NonDifferentiable, Stateful, _GramianWeighting):
+class GradVacWeighting(_GramianWeighting, Stateful, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
@@ -128,7 +128,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
self._state_key = key
-class GradVac(_NonDifferentiable, Stateful, GramianWeightedAggregator):
+class GradVac(GramianWeightedAggregator, Stateful, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of
diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py
index 47504e83..8cbf65f5 100644
--- a/src/torchjd/aggregation/_imtl_g.py
+++ b/src/torchjd/aggregation/_imtl_g.py
@@ -9,7 +9,7 @@
# Non-differentiable: differentiating through pinv(gramian) would give incorrect gradients.
-class IMTLGWeighting(_NonDifferentiable, _GramianWeighting):
+class IMTLGWeighting(_GramianWeighting, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.IMTLG`.
@@ -25,7 +25,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
return weights
-class IMTLG(_NonDifferentiable, GramianWeightedAggregator):
+class IMTLG(GramianWeightedAggregator, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` generalizing the method described in
`Towards Impartial Multi-task Learning `_.
diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py
index 29bf5592..b906140b 100644
--- a/src/torchjd/aggregation/_mixins.py
+++ b/src/torchjd/aggregation/_mixins.py
@@ -19,9 +19,8 @@ class _NonDifferentiable(nn.Module):
the call in :func:`torch.no_grad`.
.. warning::
- This mixin must appear **before** any :class:`torch.nn.Module` base class in the inheritance
- list. Placing it after will silently have no effect, because :meth:`__call__` would be
- resolved to :class:`torch.nn.Module` before reaching this mixin.
+ List this mixin *after* the primary base. Placing it *before* makes it shadow the primary
+ class's :meth:`__call__` signature in generated documentation.
"""
def __call__(self, *args: Any, **kwargs: Any) -> Any:
diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py
index 1932b079..879b4221 100644
--- a/src/torchjd/aggregation/_nash_mtl.py
+++ b/src/torchjd/aggregation/_nash_mtl.py
@@ -21,7 +21,7 @@
# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
-class _NashMTLWeighting(_WithOptionalDeps, _NonDifferentiable, Stateful, _MatrixWeighting):
+class _NashMTLWeighting(_WithOptionalDeps, _MatrixWeighting, Stateful, _NonDifferentiable):
_REQUIRED_DEPS = ["numpy", "cvxpy", "ecos"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"'
"""
@@ -204,7 +204,7 @@ def reset(self) -> None:
self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)
-class NashMTL(_NonDifferentiable, Stateful, WeightedAggregator):
+class NashMTL(WeightedAggregator, Stateful, _NonDifferentiable):
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of
diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py
index ffce10d3..8a20b496 100644
--- a/src/torchjd/aggregation/_pcgrad.py
+++ b/src/torchjd/aggregation/_pcgrad.py
@@ -11,7 +11,7 @@
# Non-differentiable: weights are modified in-place during the gradient projection loop.
-class PCGradWeighting(_NonDifferentiable, _GramianWeighting):
+class PCGradWeighting(_GramianWeighting, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.PCGrad`.
@@ -47,7 +47,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
return weights.to(device)
-class PCGrad(_NonDifferentiable, GramianWeightedAggregator):
+class PCGrad(GramianWeightedAggregator, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
`Gradient Surgery for Multi-Task Learning `_.
diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py
index 691232eb..e85b15f4 100644
--- a/src/torchjd/aggregation/_upgrad.py
+++ b/src/torchjd/aggregation/_upgrad.py
@@ -12,7 +12,7 @@
# Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph.
-class UPGradWeighting(_NonDifferentiable, _GramianWeighting):
+class UPGradWeighting(_GramianWeighting, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.UPGrad`.
@@ -54,7 +54,7 @@ def projector(self, value: DualConeProjector | None) -> None:
self._projector = projector_or_default(value)
-class UPGrad(_NonDifferentiable, GramianWeightedAggregator):
+class UPGrad(GramianWeightedAggregator, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` that projects each row of the input
matrix onto the dual cone of all rows of this matrix, and that combines the result, as proposed