Merged
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -91,7 +91,7 @@ changelog does not include internal changes that do not affect the user.
Suggested change: `mtl_backward(losses=losses, features=features)` =>
`mtl_backward(losses, features=features)`. The `features` parameter remains usable as positional
or keyword. All other parameters are now keyword-only.
- - `Aggregator.__call__`: The `matrix` parameter is now positonal-only. Suggested change:
+ - `Aggregator.__call__`: The `matrix` parameter is now positional-only. Suggested change:
`aggregator(matrix=matrix)` => `aggregator(matrix)`.
- `Weighting.__call__`: The `stat` parameter is now positional-only. Suggested change:
`weighting(stat=gramian)` => `weighting(gramian)`.
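
The calling conventions described in these entries can be illustrated with a minimal sketch in plain Python. `ToyAggregator`, `toy_mtl_backward` and the `retain_graph` flag are hypothetical stand-ins used only for illustration, not TorchJD's implementation. The sketch also assumes, based on the suggested change, that the first parameter of `mtl_backward` is positional-only. Only the `/` and `*` markers mirror what the changelog describes.

```python
class ToyAggregator:
    """Hypothetical stand-in for an aggregator; not TorchJD's implementation."""

    def __call__(self, matrix, /):
        # The "/" makes `matrix` positional-only.
        return [sum(col) / len(matrix) for col in zip(*matrix)]


def toy_mtl_backward(losses, /, features, *, retain_graph=False):
    # Assumption of this sketch: `losses` is positional-only, `features` can be
    # passed either way, and everything after "*" is keyword-only.
    return len(losses), len(features), retain_graph


aggregator = ToyAggregator()
result = aggregator([[1.0, 2.0], [3.0, 4.0]])  # positional call is accepted

try:
    aggregator(matrix=[[1.0, 2.0]])  # keyword call is rejected
    keyword_rejected = False
except TypeError:
    keyword_rejected = True

summary = toy_mtl_backward([0.1, 0.2], features=[0.3])
```

Code that still passes `matrix=` or `stat=` by keyword would fail with a `TypeError` of this kind, which is why the changelog lists a suggested change for each entry.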
@@ -177,7 +177,7 @@ changelog does not include internal changes that do not affect the user.

- Made some aggregators (`CAGrad`, `ConFIG`, `DualProj`, `GradDrop`, `IMTLG`, `NashMTL`, `PCGrad`
and `UPGrad`) raise a `NonDifferentiableError` whenever one tries to differentiate through them.
- Before this change, trying to differentiate through them leaded to wrong gradients or unclear
+ Before this change, trying to differentiate through them led to wrong gradients or unclear
errors.

### Added
2 changes: 1 addition & 1 deletion README.md
@@ -294,7 +294,7 @@ TorchJD provides many existing aggregators from the literature, listed in the fo

## Release Methodology

- We try to make a release whenever have something worth sharing to users (bug fix, minor or large
+ We try to make a release whenever we have something worth sharing to users (bug fix, minor or large
feature, etc.). TorchJD follows [semantic versioning](https://semver.org/). Since the library is
still in beta (`0.x.y`), we sometimes make interface changes in minor versions. We prioritize the
long-term quality of the library, which occasionally means introducing breaking changes. Whenever a
2 changes: 1 addition & 1 deletion docs/source/examples/basic_usage.rst
@@ -69,7 +69,7 @@ Perform the Jacobian descent backward pass:

The first function will populate the ``.jac`` field of each model parameter with the corresponding
Jacobian, and the second one will aggregate these Jacobians and store the result in the ``.grad``
- field of the parameters. It also deletes the ``.jac`` fields save some memory.
+ field of the parameters. It also deletes the ``.jac`` fields to save some memory.
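
The two-step flow described in this paragraph (fill `.jac`, then aggregate into `.grad` and drop `.jac`) can be sketched without TorchJD or PyTorch. Everything below is hypothetical: `ToyParam`, `mean_rows` and `toy_jac_to_grad` are illustration names, and aggregating the column-wise concatenation of all Jacobians at once is an assumption of this sketch rather than a statement about the library's internals.

```python
class ToyParam:
    """Hypothetical parameter holder with .jac and .grad fields."""

    def __init__(self, jac):
        self.jac = jac    # one row per objective, one column per scalar weight
        self.grad = None


def mean_rows(matrix):
    # Simplest possible aggregator: average the rows of the Jacobian.
    return [sum(col) / len(matrix) for col in zip(*matrix)]


def toy_jac_to_grad(params, aggregate):
    # Concatenate the per-parameter Jacobians column-wise into one matrix.
    n_rows = len(params[0].jac)
    full = [[x for p in params for x in p.jac[i]] for i in range(n_rows)]
    vector = aggregate(full)

    # Split the aggregated vector back into per-parameter gradients.
    start = 0
    for p in params:
        width = len(p.jac[0])
        p.grad = vector[start:start + width]
        start += width
        del p.jac  # the Jacobian is no longer needed, so free it


a = ToyParam([[1.0, 2.0], [3.0, 4.0]])
b = ToyParam([[10.0], [20.0]])
toy_jac_to_grad([a, b], mean_rows)
```

After the call, `a.grad` and `b.grad` hold the aggregated update directions and neither object has a `jac` attribute any more.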

Update each parameter based on its ``.grad`` field, using the ``optimizer``:

10 changes: 5 additions & 5 deletions docs/source/examples/grouping.rst
@@ -6,21 +6,21 @@ The aggregation can be made independently on groups of parameters, at different
the parameters:

1. **Together** (baseline): one group covering all parameters. Corresponds to the `whole_model`
- stategy in the paper.
+ strategy in the paper.

2. **Per network**: one group per top-level sub-network (e.g. encoder and decoder separately).
- Corresponds to the `enc_dec` stategy in the paper.
+ Corresponds to the `enc_dec` strategy in the paper.

- 3. **Per layer**: one group per leaf module of the network. Corresponds to the `all_layer` stategy
+ 3. **Per layer**: one group per leaf module of the network. Corresponds to the `all_layer` strategy
in the paper.

4. **Per tensor**: one group per individual parameter tensor. Corresponds to the `all_matrix`
- stategy in the paper.
+ strategy in the paper.

In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
after :func:`~torchjd.autojac.backward` or :func:`~torchjd.autojac.mtl_backward`, with a dedicated
aggregator instance per group. For :class:`~torchjd.aggregation.Stateful` aggregators, each instance
- should independently maintains its own state (e.g. the EMA :math:`\hat{\phi}` state in
+ should independently maintain its own state (e.g. the EMA :math:`\hat{\phi}` state in
:class:`~torchjd.aggregation.GradVac`, matching the per-block targets from the original paper).
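
The grouping idea (one aggregation call and one dedicated aggregator instance per group) can be sketched in plain Python. The names `CountingMean`, `aggregate_by_group`, the two key functions and the Jacobian values are all hypothetical. The sketch does not call TorchJD. It only shows how the choice of grouping key decides which columns are aggregated together and how each group keeps its own state.

```python
class CountingMean:
    """Hypothetical stateful aggregator: averages rows and counts its calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, matrix):
        self.calls += 1
        return [sum(col) / len(matrix) for col in zip(*matrix)]


# Made-up Jacobians: two objectives (rows) for each named parameter.
jacobians = {
    "encoder.weight": [[1.0, 0.0], [3.0, 2.0]],
    "encoder.bias": [[2.0], [4.0]],
    "decoder.weight": [[0.0, 6.0], [2.0, 0.0]],
}


def per_network(name):
    return name.split(".")[0]  # "encoder" or "decoder"


def per_tensor(name):
    return name  # every parameter tensor is its own group


def aggregate_by_group(jacobians, group_key):
    groups = {}
    for name in jacobians:
        groups.setdefault(group_key(name), []).append(name)

    # One dedicated aggregator instance per group, so state is never shared.
    aggregators = {g: CountingMean() for g in groups}
    grads = {}
    for g, names in groups.items():
        n_rows = len(jacobians[names[0]])
        full = [[x for n in names for x in jacobians[n][i]] for i in range(n_rows)]
        vector = aggregators[g](full)
        start = 0
        for n in names:
            width = len(jacobians[n][0])
            grads[n] = vector[start:start + width]
            start += width
    return grads, aggregators


grads, aggregators = aggregate_by_group(jacobians, per_network)
_, tensor_aggregators = aggregate_by_group(jacobians, per_tensor)
```

With a plain mean the grouping does not change the resulting gradients, because the mean treats each column independently. Grouping matters for aggregators whose weights depend on the whole matrix they receive, and for stateful ones such as the EMA state mentioned above.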

.. note::
4 changes: 2 additions & 2 deletions src/torchjd/_linalg/_matrix.py
@@ -2,8 +2,8 @@

from torch import Tensor

- # Note: we're using classes and inherittance instead of NewType because it's possible to have
- # multiple inherittance but there is no type intersection. However, these classes should never be
+ # Note: we're using classes and inheritance instead of NewType because it's possible to have
+ # multiple inheritance but there is no type intersection. However, these classes should never be
# instantiated: they're only used for static type checking.
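
The design choice in this comment can be illustrated with a small sketch. `typing.NewType` takes a single base type, so it cannot express "both A and B", whereas a class inheriting from two marker classes can. The classes below are hypothetical stand-ins (`Tensor` here is a plain class so that the sketch runs without torch, and `PSDTensor` is an illustration name). They are not the contents of `_matrix.py`.

```python
from typing import cast


class Tensor:
    """Hypothetical stand-in for torch.Tensor, so the sketch runs without torch."""


class Matrix(Tensor):
    """Marker type for 2-D tensors; only meant for static type checking."""


class PSDTensor(Tensor):
    """Marker type for positive semi-definite tensors; only for type checking."""


class PSDMatrix(PSDTensor, Matrix):
    """Multiple inheritance plays the role of the intersection of both markers."""


def needs_matrix(m: Matrix) -> str:
    return "matrix ok"


def needs_psd(t: PSDTensor) -> str:
    return "psd ok"


raw = Tensor()
# cast() does nothing at runtime: it only tells the type checker what to assume.
gramian = cast(PSDMatrix, raw)
messages = (needs_matrix(gramian), needs_psd(gramian))
```

A type checker accepts `gramian` wherever a `Matrix` or a `PSDTensor` is expected, while at runtime the object is still the original tensor and no marker class is ever instantiated.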


2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_graddrop.py
@@ -20,7 +20,7 @@ class GradDrop(Aggregator):
Optimizing Deep Multitask Models with Gradient Sign Dropout
<https://arxiv.org/pdf/2010.06808.pdf>`_.

- :param f: The function to apply to the Gradient Positive Sign Purity. It should be monotically
+ :param f: The function to apply to the Gradient Positive Sign Purity. It should be monotonically
increasing. Defaults to identity.
:param leak: The tensor of leak values, determining how much each row is allowed to leak
through. Defaults to None, which means no leak.
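
For readers unfamiliar with the quantity that `f` is applied to, here is a minimal sketch of the Gradient Positive Sign Purity and the resulting sign-based masking, following the formulation in the GradDrop paper. It is not the code in `_graddrop.py`: the function names are hypothetical, the `leak` parameter is ignored, the uniform samples are passed in explicitly so that the example is deterministic, and returning 0.5 for an all-zero column is an assumption of this sketch.

```python
def sign_purity(column):
    # P = 0.5 * (1 + sum(g) / sum(|g|)), which lies in [0, 1].
    total_abs = sum(abs(g) for g in column)
    if total_abs == 0.0:
        return 0.5  # assumption of this sketch for an all-zero column
    return 0.5 * (1.0 + sum(column) / total_abs)


def toy_graddrop(matrix, uniforms, f=lambda p: p):
    """Hypothetical sketch: per column, keep only one sign of the entries."""
    result = []
    for column, u in zip(zip(*matrix), uniforms):
        p = f(sign_purity(column))
        if p > u:
            kept = [g for g in column if g > 0]   # keep the positive entries
        elif p < u:
            kept = [g for g in column if g < 0]   # keep the negative entries
        else:
            kept = []
        result.append(sum(kept))
    return result


matrix = [[1.0, -2.0], [3.0, 1.0]]
purities = [sign_purity(col) for col in zip(*matrix)]
aggregated = toy_graddrop(matrix, uniforms=[0.5, 0.5])
```

A monotonically increasing `f` preserves the intended behaviour: the more the entries of a column agree on a positive sign, the more likely the positive entries are the ones kept.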
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_imtl_g.py
@@ -29,7 +29,7 @@ class IMTLG(GramianWeightedAggregator):
:class:`~torchjd.aggregation.GramianWeightedAggregator` generalizing the method described in
`Towards Impartial Multi-task Learning <https://discovery.ucl.ac.uk/id/eprint/10120667/>`_.
This generalization, defined formally in `Jacobian Descent For Multi-Objective Optimization
- <https://arxiv.org/pdf/2406.16232>`_, supports matrices with some linearly dependant rows.
+ <https://arxiv.org/pdf/2406.16232>`_, supports matrices with some linearly dependent rows.
"""

gramian_weighting: IMTLGWeighting
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_pcgrad.py
@@ -48,7 +48,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:

class PCGrad(GramianWeightedAggregator):
"""
- :class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in algorithm 1 of
+ :class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
`Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>`_.
"""

2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_random.py
@@ -21,7 +21,7 @@ def forward(self, matrix: Tensor, /) -> Tensor:
class Random(WeightedAggregator):
"""
:class:`~torchjd.aggregation.WeightedAggregator` that computes a random combination of
- the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of
+ the rows of the provided matrices, as defined in Algorithm 2 of `Reasonable Effectiveness of
Random Weighting: A Litmus Test for Multi-Task Learning
<https://arxiv.org/pdf/2111.10603.pdf>`_.
"""
2 changes: 1 addition & 1 deletion src/torchjd/autogram/_jacobian_computer.py
@@ -120,7 +120,7 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]

vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1]

- # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the
+ # vjp_func is a function that computes the vjp w.r.t. the primals (tuple). Here the
# functional has a single primal which is dict(module.named_parameters()). We therefore take
# the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters.
gradients = vjp_func(grad_outputs_j_)[0]
2 changes: 1 addition & 1 deletion src/torchjd/autogram/_module_hook_manager.py
@@ -23,7 +23,7 @@

class ModuleHookManager:
"""
- Class responsible for handling hooks and Nodes that computes the Gramian reverse accumulation.
+ Class responsible for handling hooks and Nodes that compute the Gramian reverse accumulation.

:param target_edges: Registry for tracking gradient edges that serve as targets for the first
differentiation.
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_mtl_backward.py
@@ -204,7 +204,7 @@ def _create_task_transform(
backpropagate = Select(features)

# Transform that accumulates the gradient of the tensor w.r.t. the task-specific parameters into
- # their .grad fields and backpropagates the gradient of the tensor w.r.t. to the features.
+ # their .grad fields and backpropagates the gradient of the tensor w.r.t. the features.
backward_task = (backpropagate | accumulate) << grad << Select(tensor)
return backward_task

4 changes: 2 additions & 2 deletions src/torchjd/autojac/_transform/_grad.py
@@ -10,8 +10,8 @@
class Grad(Differentiate):
"""
Transform from Gradients to Gradients, computing the gradient of each output element with
- respect to each input tensor, and applying the linear transformations represented by provided
- the grad_outputs to the results.
+ respect to each input tensor, and applying the linear transformations represented by the
+ grad_outputs to the results.

:param outputs: Tensors to differentiate.
:param inputs: Tensors with respect to which we differentiate.
2 changes: 1 addition & 1 deletion tests/unit/autojac/test_mtl_backward.py
@@ -691,7 +691,7 @@ def test_repeated_task_params() -> None:

def test_grad_tensors_value_is_correct() -> None:
"""
- Tests that mtl_ackward correctly computes the element-wise product of grad_tensors and the
+ Tests that mtl_backward correctly computes the element-wise product of grad_tensors and the
tensors.
"""
