diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 4fbd79d0..b2da3efe 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -18,7 +18,7 @@ # Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph. -class CAGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting): +class CAGradWeighting(_WithOptionalDeps, _GramianWeighting, _NonDifferentiable): _REQUIRED_DEPS = ["numpy", "cvxpy", "clarabel"] _INSTALL_HINT = 'Install them with: pip install "torchjd[cagrad]"' """ @@ -94,7 +94,7 @@ def norm_eps(self, value: float) -> None: self._norm_eps = value -class CAGrad(_NonDifferentiable, GramianWeightedAggregator): +class CAGrad(GramianWeightedAggregator, _NonDifferentiable): """ :class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of `Conflict-Averse Gradient Descent for Multi-task Learning diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index 54e2c3dc..7ca654b7 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -14,7 +14,7 @@ # Non-differentiable: the pseudoinverse and the normalization are not differentiable in this context. -class ConFIG(_NonDifferentiable, Aggregator): +class ConFIG(Aggregator, _NonDifferentiable): """ :class:`~torchjd.aggregation.Aggregator` as defined in Equation 2 of `ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index e839d6e8..15b6aa87 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -11,7 +11,7 @@ # Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph. 
-class DualProjWeighting(_NonDifferentiable, _GramianWeighting): +class DualProjWeighting(_GramianWeighting, _NonDifferentiable): r""" :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.DualProj`. @@ -53,7 +53,7 @@ def projector(self, value: DualConeProjector | None) -> None: self._projector = projector_or_default(value) -class DualProj(_NonDifferentiable, GramianWeightedAggregator): +class DualProj(GramianWeightedAggregator, _NonDifferentiable): r""" :class:`~torchjd.aggregation.GramianWeightedAggregator` that averages the rows of the input matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds diff --git a/src/torchjd/aggregation/_fairgrad.py b/src/torchjd/aggregation/_fairgrad.py index 4b2cc261..7bdb5ab3 100644 --- a/src/torchjd/aggregation/_fairgrad.py +++ b/src/torchjd/aggregation/_fairgrad.py @@ -21,7 +21,7 @@ # Non-differentiable: the scipy solver operates on numpy arrays, breaking the autograd graph. 
-class FairGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting): +class FairGradWeighting(_WithOptionalDeps, _GramianWeighting, _NonDifferentiable): r""" :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.FairGrad`, as defined in Equation 4 of `Fair Resource @@ -78,7 +78,7 @@ def alpha(self, value: float) -> None: self._alpha = value -class FairGrad(_NonDifferentiable, GramianWeightedAggregator): +class FairGrad(GramianWeightedAggregator, _NonDifferentiable): r""" :class:`~torchjd.aggregation.GramianWeightedAggregator` using the step decision of Algorithm 1 of `Fair Resource Allocation in Multi-Task Learning diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index 31590ebf..ec91a8cf 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -14,7 +14,7 @@ def _identity(P: Tensor) -> Tensor: # Non-differentiable: the sign-based random masking is not differentiable. -class GradDrop(_NonDifferentiable, Aggregator): +class GradDrop(Aggregator, _NonDifferentiable): """ :class:`~torchjd.aggregation.Aggregator` that applies the gradient combination steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign: diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 61469407..8a031e64 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -13,7 +13,7 @@ # Non-differentiable: weights are modified in-place during the gradient correction loop. 
-class GradVacWeighting(_NonDifferentiable, Stateful, _GramianWeighting): +class GradVacWeighting(_GramianWeighting, Stateful, _NonDifferentiable): r""" :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] @@ -128,7 +128,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None: self._state_key = key -class GradVac(_NonDifferentiable, Stateful, GramianWeightedAggregator): +class GradVac(GramianWeightedAggregator, Stateful, _NonDifferentiable): r""" :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 47504e83..8cbf65f5 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -9,7 +9,7 @@ # Non-differentiable: differentiating through pinv(gramian) would give incorrect gradients. -class IMTLGWeighting(_NonDifferentiable, _GramianWeighting): +class IMTLGWeighting(_GramianWeighting, _NonDifferentiable): """ :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.IMTLG`. @@ -25,7 +25,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: return weights -class IMTLG(_NonDifferentiable, GramianWeightedAggregator): +class IMTLG(GramianWeightedAggregator, _NonDifferentiable): """ :class:`~torchjd.aggregation.GramianWeightedAggregator` generalizing the method described in `Towards Impartial Multi-task Learning `_. diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index 29bf5592..b906140b 100644 --- a/src/torchjd/aggregation/_mixins.py +++ b/src/torchjd/aggregation/_mixins.py @@ -19,9 +19,8 @@ class _NonDifferentiable(nn.Module): the call in :func:`torch.no_grad`. .. warning:: - This mixin must appear **before** any :class:`torch.nn.Module` base class in the inheritance - list. 
Placing it after will silently have no effect, because :meth:`__call__` would be - resolved to :class:`torch.nn.Module` before reaching this mixin. + List this mixin *after* the primary base: placing it *before* will cause it to shadow the + primary class's :meth:`__call__` signature in generated documentation. """ def __call__(self, *args: Any, **kwargs: Any) -> Any: diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 1932b079..879b4221 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -21,7 +21,7 @@ # Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph. -class _NashMTLWeighting(_WithOptionalDeps, _NonDifferentiable, Stateful, _MatrixWeighting): +class _NashMTLWeighting(_WithOptionalDeps, _MatrixWeighting, Stateful, _NonDifferentiable): _REQUIRED_DEPS = ["numpy", "cvxpy", "ecos"] _INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"' """ @@ -204,7 +204,7 @@ def reset(self) -> None: self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32) -class NashMTL(_NonDifferentiable, Stateful, WeightedAggregator): +class NashMTL(WeightedAggregator, Stateful, _NonDifferentiable): """ :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index ffce10d3..8a20b496 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -11,7 +11,7 @@ # Non-differentiable: weights are modified in-place during the gradient projection loop. -class PCGradWeighting(_NonDifferentiable, _GramianWeighting): +class PCGradWeighting(_GramianWeighting, _NonDifferentiable): """ :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.PCGrad`.
@@ -47,7 +47,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: return weights.to(device) -class PCGrad(_NonDifferentiable, GramianWeightedAggregator): +class PCGrad(GramianWeightedAggregator, _NonDifferentiable): """ :class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of `Gradient Surgery for Multi-Task Learning `_. diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 691232eb..e85b15f4 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -12,7 +12,7 @@ # Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph. -class UPGradWeighting(_NonDifferentiable, _GramianWeighting): +class UPGradWeighting(_GramianWeighting, _NonDifferentiable): r""" :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.UPGrad`. @@ -54,7 +54,7 @@ def projector(self, value: DualConeProjector | None) -> None: self._projector = projector_or_default(value) -class UPGrad(_NonDifferentiable, GramianWeightedAggregator): +class UPGrad(GramianWeightedAggregator, _NonDifferentiable): r""" :class:`~torchjd.aggregation.GramianWeightedAggregator` that projects each row of the input matrix onto the dual cone of all rows of this matrix, and that combines the result, as proposed