172 changes: 111 additions & 61 deletions README.md
@@ -1,98 +1,148 @@
# Transfusion VAE

## About The Project
This repository is a branch of [lavinal712/AutoencoderKL](https://github.com/lavinal712/AutoencoderKL) with some modifications for the Transfusion VAE.

In the [Transfusion](https://arxiv.org/abs/2408.11039) paper, the authors propose a new VAE for image generation. Unlike the original VAE, the Transfusion VAE has 8 latent dimensions, and its training loss is:

```math
\mathcal{L}_{\text{VAE}} = \mathcal{L}_{1} + \mathcal{L}_{\text{LPIPS}} + 0.5 \mathcal{L}_{\text{GAN}} + 0.2 \mathcal{L}_{\text{ID}} + 0.000001 \mathcal{L}_{\text{KL}}
```

where $`\mathcal{L}_{1}`$ is the L1 reconstruction loss, $`\mathcal{L}_{\text{LPIPS}}`$ is a perceptual loss based on LPIPS similarity, $`\mathcal{L}_{\text{GAN}}`$ is a patch-based discriminator loss, $`\mathcal{L}_{\text{ID}}`$ is a perceptual loss based on a MoCo v2 model, and $`\mathcal{L}_{\text{KL}}`$ is the standard KL regularization loss.
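
For reference, the weights above correspond to keys in `configs/transfusion_vae_32x32x8.yaml` (`disc_weight: 0.5`, `perceptual_weight_2: 0.2`, `kl_loss: 1.0e-6`). A minimal sketch of the weighted sum, with hypothetical scalar loss inputs:

```python
def transfusion_vae_loss(l1, lpips, gan, moco_id, kl):
    # Weights follow the formula above; each argument is assumed to be a
    # precomputed scalar loss for the current batch.
    return l1 + lpips + 0.5 * gan + 0.2 * moco_id + 1e-6 * kl
```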

## Visualization

| Input | Reconstruction |
|-------|----------------|
| ![Inputs](assets/inputs.png) | ![Reconstructions](assets/reconstructions.png) |

## Getting Started

### Installation

```bash
git clone https://github.com/lavinal712/AutoencoderKL.git -b transfusion_vae
cd AutoencoderKL
conda create -n autoencoderkl python=3.10 -y
conda activate autoencoderkl
pip install -r requirements.txt
```

### Model

You can load the pretrained model from [lavinal712/transfusion-vae](https://huggingface.co/lavinal712/transfusion-vae).

```python
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("lavinal712/transfusion-vae")
```
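
A minimal round-trip sketch, assuming the checkpoint exposes the standard `diffusers` `AutoencoderKL` interface (the preprocessing below is illustrative, not the repository's exact pipeline):

```python
import torch
from PIL import Image
from diffusers import AutoencoderKL
from torchvision import transforms

vae = AutoencoderKL.from_pretrained("lavinal712/transfusion-vae").eval()

# Preprocess an RGB image to a [-1, 1] tensor of shape (1, 3, 256, 256).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
x = preprocess(Image.open("assets/inputs.png").convert("RGB")).unsqueeze(0)

with torch.no_grad():
    posterior = vae.encode(x).latent_dist   # diagonal Gaussian over the latent
    z = posterior.sample()                  # (1, 8, 32, 32) for a 256x256 input
    x_rec = vae.decode(z).sample            # reconstruction in [-1, 1]
```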

### Data

We use a combined dataset of ImageNet, COCO, and FFHQ for training. The data is organized as follows:

```
ImageNet/
├── train/
│   ├── n01440764/
│   │   ├── n01440764_18.JPEG
│   │   ├── n01440764_36.JPEG
│   │   └── ...
│   ├── n01443537/
│   │   ├── n01443537_2.JPEG
│   │   ├── n01443537_16.JPEG
│   │   └── ...
│   ├── ...
├── val/
│   ├── n01440764/
│   │   ├── ILSVRC2012_val_00000293.JPEG
│   │   ├── ILSVRC2012_val_00002138.JPEG
│   │   └── ...
│   ├── n01443537/
│   │   ├── ILSVRC2012_val_00000236.JPEG
│   │   ├── ILSVRC2012_val_00000262.JPEG
│   │   └── ...
│   ├── ...
└── ...
```

```
COCO/
├── annotations/
│   ├── captions_train2017.json
│   ├── captions_val2017.json
│   ├── ...
├── test2017/
│   ├── 000000000001.jpg
│   ├── 000000000016.jpg
│   └── ...
├── train2017/
│   ├── 000000000009.jpg
│   ├── 000000000025.jpg
│   └── ...
├── val2017/
│   ├── 000000000139.jpg
│   ├── 000000000285.jpg
│   └── ...
└── ...
```

```
FFHQ/
├── images1024x1024/
│   ├── 00000/
│   │   ├── 00000.png
│   │   ├── 00001.png
│   │   └── ...
│   ├── 01000/
│   │   ├── 01000.png
│   │   ├── 01001.png
│   │   └── ...
│   ├── ...
│   ├── 69000/
│   │   ├── 69000.png
│   │   ├── 69001.png
│   │   └── ...
│   └── LICENSE.txt
├── ffhq-dataset-v2.json
└── ...
```
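
The combined loader used in training is `sgm.data.imagenet_coco_ffhq.ImageNetCOCOFFHQLoader` (see the config below). As a rough sketch of how such a dataset can be assembled with `torch.utils.data` (class and path names here are illustrative, not the repository's implementation):

```python
import os
from PIL import Image
from torch.utils.data import ConcatDataset, Dataset
from torchvision import transforms

class ImageTreeDataset(Dataset):
    """Recursively collects images under a root directory (illustrative)."""
    EXTENSIONS = (".jpeg", ".jpg", ".png")

    def __init__(self, root, size=256):
        self.paths = [
            os.path.join(dirpath, name)
            for dirpath, _, names in os.walk(root)
            for name in names
            if name.lower().endswith(self.EXTENSIONS)
        ]
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.RandomCrop(size),   # transform: random_crop in the config
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert("RGB")
        return {"jpg": self.transform(image)}  # key matches input_key: jpg

combined = ConcatDataset([
    ImageTreeDataset("/path/to/ImageNet/train"),
    ImageTreeDataset("/path/to/MSCOCO/train2017"),
    ImageTreeDataset("/path/to/FFHQ/images1024x1024"),
])
```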

### Training

We train the Transfusion VAE on the combined ImageNet, COCO, and FFHQ dataset for 50 epochs.

```bash
torchrun --nproc_per_node=4 --nnodes=1 main.py \
    --base configs/transfusion_vae_32x32x8.yaml \
    --train \
    --scale_lr False \
    --wandb True
```

### Evaluation

We evaluate the Transfusion VAE on ImageNet and COCO.

ImageNet 2012 (256x256, val, 50000 images)

| Model | rFID | PSNR | SSIM | LPIPS |
|-----------------|-------|--------|-------|-------|
| Transfusion-VAE | 0.408 | 28.723 | 0.845 | 0.081 |
| SD-VAE | 0.692 | 26.910 | 0.772 | 0.130 |

COCO 2017 (256x256, val, 5000 images)

| Model | rFID | PSNR | SSIM | LPIPS |
|-----------------|-------|--------|-------|-------|
| Transfusion-VAE | 2.749 | 28.556 | 0.855 | 0.078 |
| SD-VAE | 4.246 | 26.622 | 0.784 | 0.127 |
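
PSNR, SSIM, and LPIPS are per-image reconstruction metrics, while rFID compares the distribution of reconstructions against the validation images. A standalone sketch of how these can be accumulated with `torchmetrics` (illustrative; not the repository's evaluation script):

```python
import torch
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# All metrics below expect float batches in [0, 1], shaped (N, 3, H, W).
psnr = PeakSignalNoiseRatio(data_range=1.0)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
lpips = LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True)
rfid = FrechetInceptionDistance(normalize=True)

def update_metrics(originals: torch.Tensor, reconstructions: torch.Tensor) -> None:
    psnr.update(reconstructions, originals)
    ssim.update(reconstructions, originals)
    lpips.update(reconstructions, originals)
    rfid.update(originals, real=True)
    rfid.update(reconstructions, real=False)

# After looping over the validation set:
# results = {m.__class__.__name__: m.compute().item() for m in (psnr, ssim, lpips, rfid)}
```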

## Acknowledgements

Thanks to the following repositories for providing the MoCo v2 and MoCo v3 code and models, and for the inspiration behind the perceptual loss.

- [facebookresearch/moco](https://github.com/facebookresearch/moco)
- [facebookresearch/moco-v3](https://github.com/facebookresearch/moco-v3)
- [sypsyp97/convnext_perceptual_loss](https://github.com/sypsyp97/convnext_perceptual_loss)
Binary file added assets/inputs.png
Binary file added assets/reconstructions.png
70 changes: 70 additions & 0 deletions configs/convert_vae.yaml
@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 8
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 8
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
97 changes: 97 additions & 0 deletions configs/transfusion_vae_32x32x8.yaml
@@ -0,0 +1,97 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.autoencoder.AutoencoderKL
  params:
    input_key: jpg
    monitor: "val/loss/rec"
    disc_start_iter: 50001
    embed_dim: 8

    ddconfig:
      attn_type: vanilla-xformers
      double_z: true
      z_channels: 8
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [1, 2, 4, 4]
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0

    lossconfig:
      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
      params:
        perceptual_weight: 1.0
        disc_start: 50001
        disc_weight: 0.5
        learn_logvar: false

        perceptual_name: mocov2_800ep
        perceptual_config:
          target: sgm.modules.autoencoding.moco.model.moco.MoCo
          params:
            arch: resnet50
            dim: 128
            K: 65536
            m: 0.999
            T: 0.2
            mlp: true
        perceptual_weight_2: 0.2

        regularization_weights:
          kl_loss: 1.0e-6

data:
  target: sgm.data.imagenet_coco_ffhq.ImageNetCOCOFFHQLoader
  params:
    batch_size: 12
    num_workers: 4
    prefetch_factor: 2
    shuffle: true

    train:
      imagenet_dir: /path/to/ImageNet
      coco_dir: /path/to/MSCOCO
      ffhq_dir: /path/to/FFHQ
      size: 256
      transform: random_crop
    validation:
      imagenet_dir: /path/to/ImageNet
      coco_dir: /path/to/MSCOCO
      ffhq_dir: /path/to/FFHQ
      size: 256
      transform: random_crop

lightning:
  strategy:
    target: pytorch_lightning.strategies.DDPStrategy
    params:
      find_unused_parameters: True

  modelcheckpoint:
    params:
      every_n_epochs: 1

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 50000

    image_logger:
      target: main.ImageLogger
      params:
        enable_autocast: False
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    precision: bf16
    devices: 0, 1, 2, 3
    limit_val_batches: 50
    benchmark: True
    accumulate_grad_batches: 1
    check_val_every_n_epoch: 1
    max_epochs: 50