diff --git a/README.md b/README.md index 3d9b1d58..91e867ee 100644 --- a/README.md +++ b/README.md @@ -1,273 +1,34 @@ - - - - Fallback image description - +# torchjd-gradnorm ---- - -[![Doc](https://img.shields.io/badge/Doc-torchjd.org-blue?logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiIHN0YW5kYWxvbmU9Im5vIj8%2BCjwhLS0gQ3JlYXRlZCB1c2luZyBLcml0YTogaHR0cDovL2tyaXRhLm9yZyAtLT4KCjxzdmcKICAgd2lkdGg9IjIwNDcuNzJwdCIKICAgaGVpZ2h0PSIyMDQ3LjcycHQiCiAgIHZpZXdCb3g9IjAgMCAyMDQ3LjcyIDIwNDcuNzIiCiAgIHZlcnNpb249IjEuMSIKICAgaWQ9InN2ZzEiCiAgIHNvZGlwb2RpOmRvY25hbWU9IlRvcmNoSkRfbG9nb19jaXJjdWxhci5zdmciCiAgIGlua3NjYXBlOnZlcnNpb249IjEuMy4yICgwOTFlMjBlZjBmLCAyMDIzLTExLTI1KSIKICAgeG1sbnM6aW5rc2NhcGU9Imh0dHA6Ly93d3cuaW5rc2NhcGUub3JnL25hbWVzcGFjZXMvaW5rc2NhcGUiCiAgIHhtbG5zOnNvZGlwb2RpPSJodHRwOi8vc29kaXBvZGkuc291cmNlZm9yZ2UubmV0L0RURC9zb2RpcG9kaS0wLmR0ZCIKICAgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIgogICB4bWxuczpzdmc9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KICA8c29kaXBvZGk6bmFtZWR2aWV3CiAgICAgaWQ9Im5hbWVkdmlldzEiCiAgICAgcGFnZWNvbG9yPSIjZmZmZmZmIgogICAgIGJvcmRlcmNvbG9yPSIjNjY2NjY2IgogICAgIGJvcmRlcm9wYWNpdHk9IjEuMCIKICAgICBpbmtzY2FwZTpzaG93cGFnZXNoYWRvdz0iMiIKICAgICBpbmtzY2FwZTpwYWdlb3BhY2l0eT0iMC4wIgogICAgIGlua3NjYXBlOnBhZ2VjaGVja2VyYm9hcmQ9IjAiCiAgICAgaW5rc2NhcGU6ZGVza2NvbG9yPSIjZDFkMWQxIgogICAgIGlua3NjYXBlOmRvY3VtZW50LXVuaXRzPSJwdCIKICAgICBpbmtzY2FwZTp6b29tPSIwLjE2Mjk4NjE1IgogICAgIGlua3NjYXBlOmN4PSIxMzk1LjgyNDEiCiAgICAgaW5rc2NhcGU6Y3k9Ijg3NC4zMDczOSIKICAgICBpbmtzY2FwZTp3aW5kb3ctd2lkdGg9IjI1NjAiCiAgICAgaW5rc2NhcGU6d2luZG93LWhlaWdodD0iMTM3MSIKICAgICBpbmtzY2FwZTp3aW5kb3cteD0iMCIKICAgICBpbmtzY2FwZTp3aW5kb3cteT0iMCIKICAgICBpbmtzY2FwZTp3aW5kb3ctbWF4aW1pemVkPSIxIgogICAgIGlua3NjYXBlOmN1cnJlbnQtbGF5ZXI9InN2ZzEiIC8%2BCiAgPGRlZnMKICAgICBpZD0iZGVmczEiIC8%2BCiAgPHBhdGgKICAgICBpZD0ic2hhcGUxIgogICAgIGZpbGw9IiMwMDAwMDAiCiAgICAgZmlsbC1ydWxlPSJldmVub2RkIgogICAgIGQ9Ik0yNTUuMjE1IDg5OS44NzVMMjU1Ljk2NCAyNTUuOTY0TDc2Ny44OTMgMjU1Ljk2NEw3NjcuODkzIDBMMCAwTDAuMDMxMjUzMyA4
OTguODQ0QzAuMDMxNzMwNSA4OTguODE0IDg0LjU3MjYgODk5Ljg3NSAyNTUuMjE1IDg5OS44NzVaIgogICAgIHN0eWxlPSJmaWxsOiMxYTgxZWI7ZmlsbC1vcGFjaXR5OjEiCiAgICAgdHJhbnNmb3JtPSJtYXRyaXgoMS4wMDAwMDAwMTQzMDcwNyAwIDAgMS4wMDAwMDAwMTQzMDcwNyAxMjcuOTgyMjI2NTIyMDU2IDEyNy45ODIyMjY1MjIwNTYpIiAvPgogIDxwYXRoCiAgICAgaWQ9InNoYXBlMDEiCiAgICAgdHJhbnNmb3JtPSJtYXRyaXgoLTEuMDAwMDAwMDA5MjIxODUgMCAwIC0xLjAwMDAwMDAwOTIyMTg1IDE5MTkuOTEzNjE3Mzk4NzEgMTkxMC4zMzcxOTY5MzEyNSkiCiAgICAgZmlsbD0iIzAwMDAwMCIKICAgICBmaWxsLXJ1bGU9ImV2ZW5vZGQiCiAgICAgZD0iTTc2OC4wNzQgMTc3Mi42MUMtMjgyLjAwNCAxNTk4LjY1IC0yMjkuNzEyIDE1MS44MjEgNzY4LjA3NCAwQzc2Ny4wODMgMjkuOTMzNyA3NjguMDk2IDE0Mi43NiA3NjguMDc0IDI2MC44ODZDNDEuNDc0NiA0NTYuOTAzIDEzNy40MjMgMTM4MC4wNiA3NjguMDc0IDE1MTMuNjQiCiAgICAgc3R5bGU9ImZpbGw6IzFhODFlYjtmaWxsLW9wYWNpdHk6MSIgLz4KICA8cGF0aAogICAgIGlkPSJzaGFwZTAyIgogICAgIGZpbGw9IiMwMDAwMDAiCiAgICAgZmlsbC1ydWxlPSJldmVub2RkIgogICAgIGQ9Ik03NjcuOTA5IDg4Ny4zMzhDMjYzLjQwMiA4MDMuOTI2IDAuMDc1OTQyMSAzODcuOTY0IDAgMC4wODU2NDk3QzE0LjY4NjggLTAuMDI4NTQ5OSA5OS4wNTUxIC0wLjAyODU0OTkgMjU1LjAxMSAwLjA4NTY0OTdDMjU1LjMxMSAyODEuMTE0IDQ0OC43ODYgNTYyLjE2MyA3NjcuOTA5IDYyNi40OTkiCiAgICAgc3R5bGU9ImZpbGw6IzFhODFlYjtmaWxsLW9wYWNpdHk6MSIKICAgICB0cmFuc2Zvcm09Im1hdHJpeCgwLjk5OTk5OTk2MDczODQ0IDAgMCAwLjk5OTk5OTk2MDczODQ0IDEyNy45NjY1OTE0OTQzMjggMTAyMy43NzIxNDc4MzE0KSIgLz4KICA8ZWxsaXBzZQogICAgIHN0eWxlPSJmaWxsOiMxYTgxZWI7c3Ryb2tlLXdpZHRoOjEuMDY3OTtmaWxsLW9wYWNpdHk6MSIKICAgICBpZD0icGF0aDEiCiAgICAgY3g9IjEwMjYuMzYxIgogICAgIGN5PSIxMDE0LjIyMTEiCiAgICAgcng9IjE4My4yNTU0MyIKICAgICByeT0iMTgzLjUxNTU4IiAvPgo8L3N2Zz4K)](https://torchjd.org) -[![Static Badge](https://img.shields.io/badge/%F0%9F%92%AC_ChatBot-chat.torchjd.org-blue?logo=%F0%9F%92%AC)](https://chat.torchjd.org) -[![Tests](https://github.com/SimplexLab/TorchJD/actions/workflows/checks.yml/badge.svg)](https://github.com/SimplexLab/TorchJD/actions/workflows/checks.yml) -[![codecov](https://codecov.io/gh/SimplexLab/TorchJD/graph/badge.svg?token=8AUCZE76QH)](https://codecov.io/gh/SimplexLab/TorchJD) -[![PyPI - Python 
Version](https://img.shields.io/pypi/pyversions/torchjd)](https://pypi.org/project/torchjd/) -[![Static Badge](https://img.shields.io/badge/PyTorch-%3E%3D2.3-blue?logo=pytorch&logoColor=white)](https://pytorch.org/) -[![Static Badge](https://img.shields.io/badge/Discord%20-%20community%20-%20%235865F2?logo=discord&logoColor=%23FFFFFF&label=Discord)](https://discord.gg/76KkRnb3nk) - -TorchJD is a library extending autograd to enable -[Jacobian descent](https://arxiv.org/pdf/2406.16232) with PyTorch. It can be used to train neural -networks with multiple objectives. In particular, it supports multi-task learning, with a wide -variety of aggregators from the literature. It also enables the instance-wise risk minimization -paradigm. The full documentation is available at [torchjd.org](https://torchjd.org), with several -usage examples. - -## Jacobian descent (JD) -Jacobian descent is an extension of gradient descent supporting the optimization of vector-valued -functions. This algorithm can be used to train neural networks with multiple loss functions. In this -context, JD iteratively updates the parameters of the model using the Jacobian matrix of the vector -of losses (the matrix stacking each individual loss' gradient). For more details, please refer to -Section 2.1 of the [paper](https://arxiv.org/pdf/2406.16232). - -### How does this compare to averaging the different losses and using gradient descent? - -Averaging the losses and computing the gradient of the mean is mathematically equivalent to -computing the Jacobian and averaging its rows. However, this approach has limitations. If two -gradients are conflicting (they have a negative inner product), simply averaging them can result in -an update vector that is conflicting with one of the two gradients. Averaging the losses and making -a step of gradient descent can thus lead to an increase of one of the losses. 
- -This is illustrated in the following picture, in which the two objectives' gradients $g_1$ and $g_2$ -are conflicting, and averaging them gives an update direction that is detrimental to the first -objective. Note that in this picture, the dual cone, represented in green, is the set of vectors -that have a non-negative inner product with both $g_1$ and $g_2$. - -![image](docs/source/_static/gradients_cone_projections_upgrad_mean.svg) - -With Jacobian descent, $g_1$ and $g_2$ are computed individually and carefully aggregated using an -aggregator $\mathcal A$. In this example, the aggregator is the Unconflicting Projection of -Gradients $\mathcal A_{\text{UPGrad}}$: it -projects each gradient onto the dual cone, and averages the projections. This ensures that the -update will always be beneficial to each individual objective (given a sufficiently small step -size). In addition to $\mathcal A_{\text{UPGrad}}$, TorchJD supports -[more than 10 aggregators from the literature](https://torchjd.org/stable/docs/aggregation). - -## Installation - -TorchJD can be installed directly with pip: -```bash -pip install "torchjd[quadprog_projector]" -``` - -This includes the dependencies required by UPGrad and DualProj. Some other aggregators may have -additional dependencies. Please refer to the -[installation documentation](https://torchjd.org/stable/installation) for them. - -## Usage - -Compared to standard `torch`, `torchjd` simply changes the way to obtain the `.grad` fields of your -model parameters. - -### Using the `autojac` engine - -The autojac engine is for computing and aggregating Jacobians efficiently. - -#### 1. `backward` + `jac_to_grad` -In standard `torch`, you generally combine your `losses` into a single scalar `loss`, and call -`loss.backward()` to compute the gradient of the loss with respect to each model parameter and to -store it in the `.grad` fields of those parameters. 
The basic usage of `torchjd` is to replace this -`loss.backward()` by a call to -[`torchjd.autojac.backward(losses)`](https://torchjd.org/stable/docs/autojac/backward/). Instead of -computing the gradient of a scalar loss, it will compute the Jacobian of a vector of losses, and -store it in the `.jac` fields of the model parameters. You then have to call -[`torchjd.autojac.jac_to_grad`](https://torchjd.org/stable/docs/autojac/jac_to_grad/) to aggregate -this Jacobian using the specified -[`Aggregator`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Aggregator), and to -store the result into the `.grad` fields of the model parameters. See this -[usage example](https://torchjd.org/stable/examples/basic_usage/) for more details. - -#### 2. `mtl_backward` + `jac_to_grad` -In the case of multi-task learning, an alternative to -[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) is -[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/). It computes -the gradient of each task-specific loss with respect to the corresponding task's parameters, and -stores it in their `.grad` fields. It also computes the Jacobian of the vector of losses with -respect to the shared parameters and stores it in their `.jac` field. Then, the -[`torchjd.autojac.jac_to_grad`](https://torchjd.org/stable/docs/autojac/jac_to_grad/) function can -be called to aggregate this Jacobian and replace the `.jac` fields by `.grad` fields for the shared -parameters. - -The following example shows how to use TorchJD to train a multi-task model with Jacobian descent, -using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/). 
- -```diff - import torch - from torch.nn import Linear, MSELoss, ReLU, Sequential - from torch.optim import SGD - -+ from torchjd.autojac import jac_to_grad, mtl_backward -+ from torchjd.aggregation import UPGrad - - shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) - task1_module = Linear(3, 1) - task2_module = Linear(3, 1) - params = [ - *shared_module.parameters(), - *task1_module.parameters(), - *task2_module.parameters(), - ] - - loss_fn = MSELoss() - optimizer = SGD(params, lr=0.1) -+ aggregator = UPGrad() - - inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 - task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task - task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task +A GradNorm implementation for multi-task learning, designed to balance tasks by normalizing gradient magnitudes. - for input, target1, target2 in zip(inputs, task1_targets, task2_targets): - features = shared_module(input) - output1 = task1_module(features) - output2 = task2_module(features) - loss1 = loss_fn(output1, target1) - loss2 = loss_fn(output2, target2) +## Quick Start +```python +from torchjd_gradnorm import GradNormScalarizer +# Initialize and use in your training loop! -- loss = loss1 + loss2 -- loss.backward() -+ mtl_backward([loss1, loss2], features=features) -+ jac_to_grad(shared_module.parameters(), aggregator) - optimizer.step() - optimizer.zero_grad() -``` +### 2. The "Publish" Commands +Open your terminal (where you run your code) and run these commands one by one to "give" your code to GitHub: -> [!NOTE] -> In this example, the Jacobian is only with respect to the shared parameters. The task-specific -> parameters are simply updated via the gradient of their task’s loss with respect to them. +1. **Initialize the git folder:** + `git init` +2. **Add all your files:** + `git add .` +3. 
**Save your work with a message:** + `git commit -m "Initial release of GradNorm scalarizer"` +4. **Connect to GitHub:** (Replace `YOUR_USERNAME` with your actual GitHub name) + `git remote add origin https://github.com/YOUR_USERNAME/torchjd-gradnorm.git` +5. **Send it to the cloud:** + `git push -u origin main` -> [!TIP] -> Once your model parameters all have a `.grad` field, it's the role of the -> [optimizer](https://docs.pytorch.org/docs/stable/optim.html#torch.optim.Optimizer) to update the -> parameters values. This is exactly the same as in standard `torch`. -#### 3. `jac` -If you're simply interested in computing Jacobians without storing them in the `.jac` fields, you -can also use the [`torchjd.autojac.jac`](https://torchjd.org/stable/docs/autojac/jac/) function, -that is analog to -[`torch.autograd.grad`](https://docs.pytorch.org/docs/stable/generated/torch.autograd.grad.html), -except that it computes the Jacobian of a vector of losses rather than the gradient of a scalar -loss. +### 3. Professional Finishing Touch: The `__init__.py` +To make sure other people can actually *import* your code, create an empty file named **`__init__.py`** inside your `torchjd_gradnorm` folder. This is a special file that tells Python, "This folder is an importable package!" -### Using the `autogram` engine - -The Gramian of the Jacobian, defined as the Jacobian multiplied by its transpose, contains all the -dot products between individual gradients. It thus contains all the information about conflict and -gradient imbalance. It turns out that most aggregators from the literature -(e.g. [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/)) make a linear combination of -the rows of the Jacobian, whose weights only depend on the Gramian of the Jacobian. - -An alternative implementation of Jacobian descent is thus to: -- Compute this Gramian incrementally (layer by layer), without ever storing the full Jacobian in - memory. 
-- Extract the weights from it using a - [`Weighting`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Weighting). -- Combine the losses using those weights and make a step of gradient descent on the combined loss. - -The main advantage of this approach is to save memory because the Jacobian (that is typically large) -never has to be stored in memory. The -[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine/) is precisely made to -compute the Gramian of the Jacobian efficiently. - -The following example shows how to use the `autogram` engine to minimize the vector of per-instance -losses with Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/). - -```diff - import torch - from torch.nn import Linear, MSELoss, ReLU, Sequential - from torch.optim import SGD - -+ from torchjd.autogram import Engine -+ from torchjd.aggregation import UPGradWeighting - - model = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU(), Linear(3, 1), ReLU()) - -- loss_fn = MSELoss() -+ loss_fn = MSELoss(reduction="none") - optimizer = SGD(model.parameters(), lr=0.1) - -+ weighting = UPGradWeighting() -+ engine = Engine(model, batch_dim=0) - - inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 - targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task - - for input, target in zip(inputs, targets): - output = model(input).squeeze(dim=1) # shape [16] -- loss = loss_fn(output, target) # shape [1] -+ losses = loss_fn(output, target) # shape [16] - -- loss.backward() -+ gramian = engine.compute_gramian(losses) # shape: [16, 16] -+ weights = weighting(gramian) # shape: [16] -+ losses.backward(weights) - optimizer.step() - optimizer.zero_grad() -``` - -You can even go one step further by considering the multiple tasks and each element of the batch -independently (Instance-Wise Multitask Learning). See [this example](https://torchjd.org/stable/examples/iwmtl/) for more details. 
- -More usage examples can be found [here](https://torchjd.org/stable/examples/). - -## Supported Aggregators and Weightings -TorchJD provides many existing aggregators from the literature, listed in the following table. - - -| Aggregator | Weighting | Publication | -|------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/#torchjd.aggregation.UPGrad) (recommended) | [UPGradWeighting](https://torchjd.org/stable/docs/aggregation/upgrad/#torchjd.aggregation.UPGradWeighting) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) | -| [AlignedMTL](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTL) | [AlignedMTLWeighting](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTLWeighting) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) | -| [CAGrad](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGrad) | [CAGradWeighting](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGradWeighting) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) | -| [ConFIG](https://torchjd.org/stable/docs/aggregation/config#torchjd.aggregation.ConFIG) | - | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) | -| [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | 
[ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - | -| - | [CRMOGMWeighting](https://torchjd.org/stable/docs/aggregation/cr_mogm/#torchjd.aggregation.CRMOGMWeighting) | [On the Convergence of Stochastic Multi-Objective Gradient Manipulation and Beyond](https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf) | -| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) | -| [FairGrad](https://torchjd.org/stable/docs/aggregation/fairgrad#torchjd.aggregation.FairGrad) | [FairGradWeighting](https://torchjd.org/stable/docs/aggregation/fairgrad#torchjd.aggregation.FairGradWeighting) | [Fair Resource Allocation in Multi-Task Learning](https://arxiv.org/pdf/2402.15638) | -| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) | -| [GradVac](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVac) | [GradVacWeighting](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVacWeighting) | [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874) | -| [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) | -| [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | 
[KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) | -| [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - | -| [MGDA](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDA) | [MGDAWeighting](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDAWeighting) | [Multiple-gradient descent algorithm (MGDA) for multiobjective optimization](https://comptes-rendus.academie-sciences.fr/mathematique/articles/10.1016/j.crma.2012.03.014/) | -| - | [MoDoWeighting](https://torchjd.org/stable/docs/aggregation/modo/#torchjd.aggregation.MoDoWeighting) | [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) | -| [NashMTL](https://torchjd.org/stable/docs/aggregation/nash_mtl#torchjd.aggregation.NashMTL) | - | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017) | -| [PCGrad](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGrad) | [PCGradWeighting](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGradWeighting) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) | -| [Random](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.Random) | [RandomWeighting](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.RandomWeighting) | [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/pdf/2111.10603) | -| [Sum](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.Sum) | 
[SumWeighting](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.SumWeighting) | - | -| [Trimmed Mean](https://torchjd.org/stable/docs/aggregation/trimmed_mean#torchjd.aggregation.TrimmedMean) | - | [Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](https://proceedings.mlr.press/v80/yin18a/yin18a.pdf) | - -## Release Methodology - -We try to make a release whenever we have something worth sharing to users (bug fix, minor or large -feature, etc.). TorchJD follows [semantic versioning](https://semver.org/). Since the library is -still in beta (`0.x.y`), we sometimes make interface changes in minor versions. We prioritize the -long-term quality of the library, which occasionally means introducing breaking changes. Whenever a -release contains breaking changes, the [changelog](CHANGELOG.md) and the GitHub release notes always -include clear instructions on how to migrate. - -## Contribution -Please read the [Contribution page](CONTRIBUTING.md). - -Thanks to our amazing contributors for making this project possible: +--- - +### You are officially a Publisher! +Once that `git push` command finishes, you can go to your GitHub profile, and you will see your `torchjd-gradnorm` repository sitting there, live for the world to see. -## Citation -If you use TorchJD for your research, please cite: -``` -@article{jacobian_descent, - title={Jacobian Descent For Multi-Objective Optimization}, - author={Quinton, Pierre and Rey, Valérian}, - journal={arXiv preprint arXiv:2406.16232}, - year={2024} -} -``` +**When you have finished the push, let me know! I want to know how it feels to have your own AI tool published online. 
Are you ready to run these commands?** diff --git a/src/torchjd/README.md b/src/torchjd/README.md new file mode 100644 index 00000000..2edfa644 --- /dev/null +++ b/src/torchjd/README.md @@ -0,0 +1,15 @@ +# TorchJD: GradNorm Integration + +This fork adds the `GradNormScalarizer` to the `TorchJD` library to support dynamic loss balancing in multi-task learning. + +## Key Features +- Dynamic gradient norm balancing. +- Easy integration with the existing `Scalarizer` interface. + +## Usage +```python +from torchjd.scalarization import GradNormScalarizer + +# Initialize the scalarizer +scalarizer = GradNormScalarizer(num_tasks=3) +``` diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 8a8633df..d3d8275a 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -3,7 +3,6 @@ import torch from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge - from torchjd._linalg import flatten, movedim, reshape from torchjd.linalg import PSDMatrix diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index e4488497..a11be738 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -3,11 +3,11 @@ from torch import Tensor from torch.utils._pytree import PyTree - from torchjd._linalg import compute_gramian -from torchjd.autogram._jacobian_computer import JacobianComputer from torchjd.linalg import Matrix, PSDMatrix +from torchjd.autogram._jacobian_computer import JacobianComputer + class GramianComputer(ABC): @abstractmethod diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 4caa4455..7a59bf3a 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -7,7 +7,6 @@ from torch.nn import Parameter from torch.overrides import is_tensor_like from torch.utils._pytree import PyTree, tree_flatten, tree_map, 
tree_map_only - from torchjd.linalg import Matrix # Note about import from protected _pytree module: diff --git a/src/torchjd/setup.py b/src/torchjd/setup.py new file mode 100644 index 00000000..b23db8c6 --- /dev/null +++ b/src/torchjd/setup.py @@ -0,0 +1,8 @@ +from setuptools import find_packages, setup + +setup( + name="torchjd", + version="0.1.0", + package_dir={"": "src"}, + packages=find_packages(where="src"), +) diff --git a/src/torchjd/__init__.py b/src/torchjd/src/torchjd/__init__.py similarity index 100% rename from src/torchjd/__init__.py rename to src/torchjd/src/torchjd/__init__.py diff --git a/src/torchjd/_linalg/__init__.py b/src/torchjd/src/torchjd/_linalg/__init__.py similarity index 100% rename from src/torchjd/_linalg/__init__.py rename to src/torchjd/src/torchjd/_linalg/__init__.py diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/src/torchjd/_linalg/_dual_cone.py similarity index 100% rename from src/torchjd/_linalg/_dual_cone.py rename to src/torchjd/src/torchjd/_linalg/_dual_cone.py diff --git a/src/torchjd/_linalg/_generalized_gramian.py b/src/torchjd/src/torchjd/_linalg/_generalized_gramian.py similarity index 100% rename from src/torchjd/_linalg/_generalized_gramian.py rename to src/torchjd/src/torchjd/_linalg/_generalized_gramian.py diff --git a/src/torchjd/_linalg/_gramian.py b/src/torchjd/src/torchjd/_linalg/_gramian.py similarity index 100% rename from src/torchjd/_linalg/_gramian.py rename to src/torchjd/src/torchjd/_linalg/_gramian.py diff --git a/src/torchjd/_linalg/_matrix.py b/src/torchjd/src/torchjd/_linalg/_matrix.py similarity index 100% rename from src/torchjd/_linalg/_matrix.py rename to src/torchjd/src/torchjd/_linalg/_matrix.py diff --git a/src/torchjd/_mixins.py b/src/torchjd/src/torchjd/_mixins.py similarity index 100% rename from src/torchjd/_mixins.py rename to src/torchjd/src/torchjd/_mixins.py diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/src/torchjd/aggregation/__init__.py similarity index 
100% rename from src/torchjd/aggregation/__init__.py rename to src/torchjd/src/torchjd/aggregation/__init__.py diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/src/torchjd/aggregation/_aggregator_bases.py similarity index 100% rename from src/torchjd/aggregation/_aggregator_bases.py rename to src/torchjd/src/torchjd/aggregation/_aggregator_bases.py diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/src/torchjd/aggregation/_aligned_mtl.py similarity index 100% rename from src/torchjd/aggregation/_aligned_mtl.py rename to src/torchjd/src/torchjd/aggregation/_aligned_mtl.py diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/src/torchjd/aggregation/_cagrad.py similarity index 100% rename from src/torchjd/aggregation/_cagrad.py rename to src/torchjd/src/torchjd/aggregation/_cagrad.py diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/src/torchjd/aggregation/_config.py similarity index 100% rename from src/torchjd/aggregation/_config.py rename to src/torchjd/src/torchjd/aggregation/_config.py diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/src/torchjd/aggregation/_constant.py similarity index 100% rename from src/torchjd/aggregation/_constant.py rename to src/torchjd/src/torchjd/aggregation/_constant.py diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/src/torchjd/aggregation/_cr_mogm.py similarity index 100% rename from src/torchjd/aggregation/_cr_mogm.py rename to src/torchjd/src/torchjd/aggregation/_cr_mogm.py diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/src/torchjd/aggregation/_dualproj.py similarity index 100% rename from src/torchjd/aggregation/_dualproj.py rename to src/torchjd/src/torchjd/aggregation/_dualproj.py diff --git a/src/torchjd/aggregation/_fairgrad.py b/src/torchjd/src/torchjd/aggregation/_fairgrad.py similarity index 100% rename from src/torchjd/aggregation/_fairgrad.py rename to src/torchjd/src/torchjd/aggregation/_fairgrad.py diff --git 
a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/src/torchjd/aggregation/_graddrop.py similarity index 100% rename from src/torchjd/aggregation/_graddrop.py rename to src/torchjd/src/torchjd/aggregation/_graddrop.py diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/src/torchjd/aggregation/_gradvac.py similarity index 100% rename from src/torchjd/aggregation/_gradvac.py rename to src/torchjd/src/torchjd/aggregation/_gradvac.py diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/src/torchjd/aggregation/_imtl_g.py similarity index 100% rename from src/torchjd/aggregation/_imtl_g.py rename to src/torchjd/src/torchjd/aggregation/_imtl_g.py diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/src/torchjd/aggregation/_krum.py similarity index 100% rename from src/torchjd/aggregation/_krum.py rename to src/torchjd/src/torchjd/aggregation/_krum.py diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/src/torchjd/aggregation/_mean.py similarity index 100% rename from src/torchjd/aggregation/_mean.py rename to src/torchjd/src/torchjd/aggregation/_mean.py diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/src/torchjd/aggregation/_mgda.py similarity index 100% rename from src/torchjd/aggregation/_mgda.py rename to src/torchjd/src/torchjd/aggregation/_mgda.py diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/src/torchjd/aggregation/_mixins.py similarity index 100% rename from src/torchjd/aggregation/_mixins.py rename to src/torchjd/src/torchjd/aggregation/_mixins.py diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/src/torchjd/aggregation/_modo.py similarity index 100% rename from src/torchjd/aggregation/_modo.py rename to src/torchjd/src/torchjd/aggregation/_modo.py diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/src/torchjd/aggregation/_nash_mtl.py similarity index 100% rename from src/torchjd/aggregation/_nash_mtl.py rename to src/torchjd/src/torchjd/aggregation/_nash_mtl.py diff --git 
a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/src/torchjd/aggregation/_pcgrad.py similarity index 100% rename from src/torchjd/aggregation/_pcgrad.py rename to src/torchjd/src/torchjd/aggregation/_pcgrad.py diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/src/torchjd/aggregation/_random.py similarity index 100% rename from src/torchjd/aggregation/_random.py rename to src/torchjd/src/torchjd/aggregation/_random.py diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/src/torchjd/aggregation/_sum.py similarity index 100% rename from src/torchjd/aggregation/_sum.py rename to src/torchjd/src/torchjd/aggregation/_sum.py diff --git a/src/torchjd/aggregation/_trimmed_mean.py b/src/torchjd/src/torchjd/aggregation/_trimmed_mean.py similarity index 100% rename from src/torchjd/aggregation/_trimmed_mean.py rename to src/torchjd/src/torchjd/aggregation/_trimmed_mean.py diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/src/torchjd/aggregation/_upgrad.py similarity index 100% rename from src/torchjd/aggregation/_upgrad.py rename to src/torchjd/src/torchjd/aggregation/_upgrad.py diff --git a/src/torchjd/aggregation/_utils/__init__.py b/src/torchjd/src/torchjd/aggregation/_utils/__init__.py similarity index 100% rename from src/torchjd/aggregation/_utils/__init__.py rename to src/torchjd/src/torchjd/aggregation/_utils/__init__.py diff --git a/src/torchjd/aggregation/_utils/pref_vector.py b/src/torchjd/src/torchjd/aggregation/_utils/pref_vector.py similarity index 100% rename from src/torchjd/aggregation/_utils/pref_vector.py rename to src/torchjd/src/torchjd/aggregation/_utils/pref_vector.py diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/src/torchjd/aggregation/_weighting_bases.py similarity index 100% rename from src/torchjd/aggregation/_weighting_bases.py rename to src/torchjd/src/torchjd/aggregation/_weighting_bases.py diff --git a/src/torchjd/src/torchjd/aggregation/gradnormaggree.py 
b/src/torchjd/src/torchjd/aggregation/gradnormaggree.py new file mode 100644 index 00000000..e69de29b diff --git a/src/torchjd/autojac/__init__.py b/src/torchjd/src/torchjd/autojac/__init__.py similarity index 100% rename from src/torchjd/autojac/__init__.py rename to src/torchjd/src/torchjd/autojac/__init__.py diff --git a/src/torchjd/autojac/_accumulation.py b/src/torchjd/src/torchjd/autojac/_accumulation.py similarity index 100% rename from src/torchjd/autojac/_accumulation.py rename to src/torchjd/src/torchjd/autojac/_accumulation.py diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/src/torchjd/autojac/_backward.py similarity index 100% rename from src/torchjd/autojac/_backward.py rename to src/torchjd/src/torchjd/autojac/_backward.py diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/src/torchjd/autojac/_jac.py similarity index 100% rename from src/torchjd/autojac/_jac.py rename to src/torchjd/src/torchjd/autojac/_jac.py diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/src/torchjd/autojac/_jac_to_grad.py similarity index 100% rename from src/torchjd/autojac/_jac_to_grad.py rename to src/torchjd/src/torchjd/autojac/_jac_to_grad.py diff --git a/src/torchjd/autojac/_mtl_backward.py b/src/torchjd/src/torchjd/autojac/_mtl_backward.py similarity index 100% rename from src/torchjd/autojac/_mtl_backward.py rename to src/torchjd/src/torchjd/autojac/_mtl_backward.py diff --git a/src/torchjd/autojac/_transform/__init__.py b/src/torchjd/src/torchjd/autojac/_transform/__init__.py similarity index 100% rename from src/torchjd/autojac/_transform/__init__.py rename to src/torchjd/src/torchjd/autojac/_transform/__init__.py diff --git a/src/torchjd/autojac/_transform/_accumulate.py b/src/torchjd/src/torchjd/autojac/_transform/_accumulate.py similarity index 100% rename from src/torchjd/autojac/_transform/_accumulate.py rename to src/torchjd/src/torchjd/autojac/_transform/_accumulate.py diff --git a/src/torchjd/autojac/_transform/_base.py 
b/src/torchjd/src/torchjd/autojac/_transform/_base.py similarity index 100% rename from src/torchjd/autojac/_transform/_base.py rename to src/torchjd/src/torchjd/autojac/_transform/_base.py diff --git a/src/torchjd/autojac/_transform/_diagonalize.py b/src/torchjd/src/torchjd/autojac/_transform/_diagonalize.py similarity index 100% rename from src/torchjd/autojac/_transform/_diagonalize.py rename to src/torchjd/src/torchjd/autojac/_transform/_diagonalize.py diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/src/torchjd/autojac/_transform/_differentiate.py similarity index 100% rename from src/torchjd/autojac/_transform/_differentiate.py rename to src/torchjd/src/torchjd/autojac/_transform/_differentiate.py diff --git a/src/torchjd/autojac/_transform/_grad.py b/src/torchjd/src/torchjd/autojac/_transform/_grad.py similarity index 100% rename from src/torchjd/autojac/_transform/_grad.py rename to src/torchjd/src/torchjd/autojac/_transform/_grad.py diff --git a/src/torchjd/autojac/_transform/_init.py b/src/torchjd/src/torchjd/autojac/_transform/_init.py similarity index 100% rename from src/torchjd/autojac/_transform/_init.py rename to src/torchjd/src/torchjd/autojac/_transform/_init.py diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/src/torchjd/autojac/_transform/_jac.py similarity index 100% rename from src/torchjd/autojac/_transform/_jac.py rename to src/torchjd/src/torchjd/autojac/_transform/_jac.py diff --git a/src/torchjd/autojac/_transform/_materialize.py b/src/torchjd/src/torchjd/autojac/_transform/_materialize.py similarity index 100% rename from src/torchjd/autojac/_transform/_materialize.py rename to src/torchjd/src/torchjd/autojac/_transform/_materialize.py diff --git a/src/torchjd/autojac/_transform/_ordered_set.py b/src/torchjd/src/torchjd/autojac/_transform/_ordered_set.py similarity index 100% rename from src/torchjd/autojac/_transform/_ordered_set.py rename to 
src/torchjd/src/torchjd/autojac/_transform/_ordered_set.py diff --git a/src/torchjd/autojac/_transform/_select.py b/src/torchjd/src/torchjd/autojac/_transform/_select.py similarity index 100% rename from src/torchjd/autojac/_transform/_select.py rename to src/torchjd/src/torchjd/autojac/_transform/_select.py diff --git a/src/torchjd/autojac/_transform/_stack.py b/src/torchjd/src/torchjd/autojac/_transform/_stack.py similarity index 100% rename from src/torchjd/autojac/_transform/_stack.py rename to src/torchjd/src/torchjd/autojac/_transform/_stack.py diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/src/torchjd/autojac/_utils.py similarity index 100% rename from src/torchjd/autojac/_utils.py rename to src/torchjd/src/torchjd/autojac/_utils.py diff --git a/src/torchjd/linalg/__init__.py b/src/torchjd/src/torchjd/linalg/__init__.py similarity index 100% rename from src/torchjd/linalg/__init__.py rename to src/torchjd/src/torchjd/linalg/__init__.py diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/src/torchjd/scalarization/__init__.py similarity index 100% rename from src/torchjd/scalarization/__init__.py rename to src/torchjd/src/torchjd/scalarization/__init__.py diff --git a/src/torchjd/scalarization/_constant.py b/src/torchjd/src/torchjd/scalarization/_constant.py similarity index 100% rename from src/torchjd/scalarization/_constant.py rename to src/torchjd/src/torchjd/scalarization/_constant.py diff --git a/src/torchjd/scalarization/_geometric_mean.py b/src/torchjd/src/torchjd/scalarization/_geometric_mean.py similarity index 100% rename from src/torchjd/scalarization/_geometric_mean.py rename to src/torchjd/src/torchjd/scalarization/_geometric_mean.py diff --git a/src/torchjd/scalarization/_imtl_l.py b/src/torchjd/src/torchjd/scalarization/_imtl_l.py similarity index 100% rename from src/torchjd/scalarization/_imtl_l.py rename to src/torchjd/src/torchjd/scalarization/_imtl_l.py diff --git a/src/torchjd/scalarization/_mean.py 
b/src/torchjd/src/torchjd/scalarization/_mean.py similarity index 100% rename from src/torchjd/scalarization/_mean.py rename to src/torchjd/src/torchjd/scalarization/_mean.py diff --git a/src/torchjd/scalarization/_random.py b/src/torchjd/src/torchjd/scalarization/_random.py similarity index 100% rename from src/torchjd/scalarization/_random.py rename to src/torchjd/src/torchjd/scalarization/_random.py diff --git a/src/torchjd/scalarization/_scalarizer_base.py b/src/torchjd/src/torchjd/scalarization/_scalarizer_base.py similarity index 100% rename from src/torchjd/scalarization/_scalarizer_base.py rename to src/torchjd/src/torchjd/scalarization/_scalarizer_base.py diff --git a/src/torchjd/scalarization/_stch.py b/src/torchjd/src/torchjd/scalarization/_stch.py similarity index 100% rename from src/torchjd/scalarization/_stch.py rename to src/torchjd/src/torchjd/scalarization/_stch.py diff --git a/src/torchjd/scalarization/_sum.py b/src/torchjd/src/torchjd/scalarization/_sum.py similarity index 100% rename from src/torchjd/scalarization/_sum.py rename to src/torchjd/src/torchjd/scalarization/_sum.py diff --git a/src/torchjd/scalarization/_uw.py b/src/torchjd/src/torchjd/scalarization/_uw.py similarity index 100% rename from src/torchjd/scalarization/_uw.py rename to src/torchjd/src/torchjd/scalarization/_uw.py diff --git a/src/torchjd/src/torchjd/scalarization/gradnorm.py b/src/torchjd/src/torchjd/scalarization/gradnorm.py new file mode 100644 index 00000000..300f2c2f --- /dev/null +++ b/src/torchjd/src/torchjd/scalarization/gradnorm.py @@ -0,0 +1,33 @@ +import torch +from torch import Tensor, nn + +from ._scalarizer_base import Scalarizer + + +class GradNormScalarizer(Scalarizer): + def __init__(self, num_tasks: int, alpha: float = 1.5) -> None: + super().__init__() + self.num_tasks = num_tasks + self.weights = nn.Parameter(torch.ones(num_tasks)) + self.alpha = alpha + self.register_buffer("initial_losses", None) + + def forward(self, values: Tensor, model: 
nn.Module = None) -> Tensor: + if self.initial_losses is None: + self.initial_losses = values.detach().clone() + + if model is not None: + norms = self._compute_gradient_norms(values, model) + loss_ratios = values / self.initial_losses + target_norm = torch.mean(norms) * (loss_ratios**self.alpha) + self.weights.data = target_norm / norms + + return (values * self.weights).sum() + + def _compute_gradient_norms(self, values: Tensor, model: nn.Module) -> Tensor: + norms = [] + for loss in values: + grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True) + norm = torch.norm(torch.cat([g.view(-1) for g in grads])) + norms.append(norm) + return torch.stack(norms)