Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 22 additions & 97 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,114 +1,39 @@
# AutoencoderKL

## About The Project
This repository is a branch of [lavinal712/AutoencoderKL](https://github.com/lavinal712/AutoencoderKL) with some modifications for the NextStep-1 VAE.

There are many great training scripts for VAE on Github. However, some repositories are not maintained and some are not updated to the latest version of PyTorch. Therefore, I decided to create this repository to provide a simple and easy-to-use training script for VAE by Lightning. Beside, the code is easy to transfer to other projects for time saving.
## Sigma-VAE

- Support training and finetuning both [Stable Diffusion](https://github.com/CompVis/stable-diffusion) VAE and [FLUX](https://github.com/black-forest-labs/flux) VAE.
- Support evaluating reconstruction quality (FID, PSNR, SSIM, LPIPS).
- A practical guidance of training VAE.
- Easy to modify the code for your own research.
Sigma-VAE is proposed by [Multimodal Latent Language Modeling with Next-Token Diffusion](https://arxiv.org/abs/2412.08635) to prevent variance collapse by enforcing a fixed variance in the latent space. The reconstruction pass is computed as:

## Visualization

This is the visualization of AutoencoderKL. From left to right, there are the original image, the reconstructed image and the difference between them. From top to bottom, there are the results of SD VAE, SDXL VAE and FLUX VAE.

Image source: [https://www.bilibili.com/opus/762402574076739817](https://www.bilibili.com/opus/762402574076739817)

![baka](assets/visualization.png)

<!-- GETTING STARTED -->
## Getting Started

To get a local copy up and running follow these simple example steps.

### Installation

```bash
git clone https://github.com/lavinal712/AutoencoderKL.git
cd AutoencoderKL
conda create -n autoencoderkl python=3.10 -y
conda activate autoencoderkl
pip install -r requirements.txt
```math
\begin{aligned}
\mu &= \text{Encoder}_\phi(x) \\
z &= \mu + \sigma \odot \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}(0,1),\ \sigma \sim \mathcal{N}(0,C_\sigma) \\
\hat{x} &= \text{Decoder}_\psi(z)
\end{aligned}
```

### Training

To start training, you need to prepare a config file. You can refer to the config files in the `configs` folder.

If you want to train on your own dataset, you should write your own data loader in `sgm/data` and modify the parameters in the config file.

Finetuning a VAE model is simple. You just need to specify the `ckpt_path` and `trainable_ae_params` in the config file. To keep the latent space of the original model, it is recommended to set decoder to be trainable.
In [NextStep-1: Toward Autoregressive Image Generation with Continuous Tokens at Scale](https://arxiv.org/abs/2508.10711), they find that **a regularized latent space is critical for generation**. Specifically, applying higher noise intensity during tokenizer training increases generation loss but paradoxically improves the quality of the generated images. They attribute this phenomenon to noise regularization, cultivating a well-conditioned latent space. This process enhances two key properties: the tokenizer decoder’s robustness to latent perturbations and a more dispersed latent distribution, a property prior work has also found beneficial for generation.

Then, you can start training by running the following command.

```bash
NUM_GPUS=4
NUM_NODES=1

torchrun --nproc_per_node=${NUM_GPUS} --nnodes=${NUM_NODES} main.py \
--base configs/autoencoder_kl_32x32x4.yaml \
--train \
--logdir logs/autoencoder_kl_32x32x4 \
--scale_lr True \
--wandb False \
```
![distributions](assets/distributions.png)

### Evaluation

We provide a script to evaluate the reconstruction quality of the trained model. `--resume` provides a convenient way to load the checkpoint from the log directory.

We introduce multi-GPU and multi-thread method for faster evaluation.

The default dataset is ImageNet. You can change the dataset by modifying the `--datadir` in the command line and the evaluation script.

```bash
NUM_GPUS=4
NUM_NODES=1

torchrun --nproc_per_node=${NUM_GPUS} --nnodes=${NUM_NODES} eval.py \
--resume logs/autoencoder_kl_32x32x4 \
--base configs/autoencoder_kl_32x32x4.yaml \
--logdir eval/autoencoder_kl_32x32x4 \
--datadir /path/to/ImageNet \
--image_size 256 \
--batch_size 16 \
--num_workers 16 \
```

Here are the evaluation results on ImageNet.

| Model | rFID | PSNR | SSIM | LPIPS |
| ------------- | ----- | ------ | ----- | ----- |
| sd-vae-ft-mse | 0.692 | 26.910 | 0.772 | 0.130 |
| sdxl-vae | 0.665 | 27.376 | 0.794 | 0.122 |
| flux-vae | 0.165 | 32.871 | 0.924 | 0.045 |

### Converting to diffusers

[huggingface/diffusers](https://github.com/huggingface/diffusers) is a library for diffusion models. It provides a script [convert_vae_pt_to_diffusers.py
](https://github.com/huggingface/diffusers/blob/main/scripts/convert_vae_pt_to_diffusers.py) to convert a PyTorch Lightning model to a diffusers model.

Currently, the script is not updated for all kinds of VAE models, just for SD VAE.
## Visualization

```bash
python convert_vae_pt_to_diffusers.py \
--vae_path logs/autoencoder_kl_32x32x4/checkpoints/last.ckpt \
--dump_path autoencoder_kl_32x32x4 \
```
This is the visualization of AutoencoderKL. From left to right, there are the original image, the reconstructed image and the visualizations of the latent space. From top to bottom, there are the results of SDXL VAE, FLUX VAE and NextStep-1 VAE.

## Guidance
The latent space of SDXL contains a lot of noise, while the noise is significantly reduced in FLUX and NextStep-1.

Here are some guidance for training VAE. If there are any mistakes, please let me know.
![baka](assets/pca.png)

- Learning rate: In LDM repository [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion), the base learning rate is set to 4.5e-6 in the config file. However, the batch size is 12, accumulated gradient is 2 and `scale_lr` is set to `True`. Therefore, the effective learning rate is 4.5e-6 * 12 * 2 * 1 = 1.08e-4. It is better to set the learning rate from 1.0e-4 to 1.0e-5. In finetuning stage, it can be smaller than the first stage.
- `scale_lr`: It is better to set `scale_lr` to `False` when training on a large dataset.
- Discriminator: You should open the discriminator in the end of the training, when the VAE has good reconstruction performance. In default, `disc_start` is set to 50001.
- Perceptual loss: LPIPS is a good metric for evaluating the quality of the reconstructed images. Some models use other perceptual loss functions to gain better performance.
| Model | rFID | PSNR | SSIM | LPIPS |
| -------------- | ----- | ------ | ----- | ----- |
| sdxl-vae | 0.665 | 27.376 | 0.794 | 0.122 |
| flux-vae | 0.165 | 32.871 | 0.924 | 0.045 |
| NextStep-1 VAE | 1.201 | 30.160 | 0.883 | 0.061 |

## Acknowledgments

Thanks for the following repositories. Without their code, this project would not be possible.
Thank NextStep-1 for providing the open-source code and model weights.

- [Stability-AI/generative-models](https://github.com/Stability-AI/generative-models). We heavily borrow the code from this repository, just modifing a few parameters for our concept.
- [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion). We follow the hyperparameter settings of this repository in config files.
- [stepfun-ai/NextStep-1](https://github.com/stepfun-ai/NextStep-1).
Binary file added assets/distributions.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/pca.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
97 changes: 97 additions & 0 deletions configs/nextstep-1-f8ch16.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
model:
base_learning_rate: 1.0e-4
target: sgm.models.autoencoder.AutoencoderKLSigma
params:
input_key: jpg
monitor: "val/loss/rec"
disc_start_iter: 50001
ckpt_path: stepfun-ai/NextStep-1-f8ch16-Tokenizer/checkpoint.pt
trainable_ae_params:
- ["decoder"]
noise_strength: 0.5 # 0.0 for evaluation

encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla-xformers
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0

decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params: ${model.params.encoder_config.params}

regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
params:
sample: true
deterministic: true
normalize_latents: true
patch_size: 2

loss_config:
target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
params:
perceptual_weight: 0.25
disc_start: 50001
disc_weight: 0.0
learn_logvar: false
pixel_loss: "l2"
regularization_weights:
kl_loss: 1.0

data:
target: sgm.data.imagenet.ImageNetLoader
params:
batch_size: 16
num_workers: 4
prefetch_factor: 2
shuffle: true

train:
root_dir: /path/to/ImageNet
size: 256
transform: true
validation:
root_dir: /path/to/ImageNet
size: 256
transform: true

lightning:
strategy:
target: pytorch_lightning.strategies.DDPStrategy
params:
find_unused_parameters: True

modelcheckpoint:
params:
every_n_epochs: 1

callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 50000

image_logger:
target: main.ImageLogger
params:
enable_autocast: False
batch_frequency: 1000
max_images: 8
increase_log_steps: True

trainer:
precision: bf16
devices: 0, 1, 2, 3
limit_val_batches: 50
benchmark: True
accumulate_grad_batches: 1
check_val_every_n_epoch: 1
18 changes: 18 additions & 0 deletions sgm/models/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def apply_ckpt(self, ckpt: Union[None, str, dict]):
sd = torch.load(ckpt, map_location="cpu")["state_dict"]
elif ckpt.endswith("safetensors"):
sd = load_safetensors(ckpt)
elif ckpt.endswith("pt"):
sd = torch.load(ckpt, map_location="cpu")
else:
raise NotImplementedError

Expand Down Expand Up @@ -527,3 +529,19 @@ def __init__(self, **kwargs):
},
**kwargs,
)


class AutoencoderKLSigma(AutoencodingEngine):
def __init__(self, noise_strength: float = 0.0, **kwargs):
super().__init__(**kwargs)
self.noise_strength = noise_strength

def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
if self.noise_strength > 0.0:
p = torch.distributions.Uniform(0, self.noise_strength)
z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * torch.randn_like(z)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
48 changes: 45 additions & 3 deletions sgm/modules/autoencoding/regularizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Any, Tuple
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
Expand All @@ -11,16 +11,58 @@


class DiagonalGaussianRegularizer(AbstractRegularizer):
def __init__(self, sample: bool = True):
def __init__(
self,
sample: bool = True,
deterministic: bool = False,
normalize_latents: bool = False,
patch_size: Optional[int] = None,
):
super().__init__()
self.sample = sample
self.deterministic = deterministic
self.normalize_latents = normalize_latents
self.patch_size = patch_size

def get_trainable_parameters(self) -> Any:
yield from ()

def patchify(self, x: torch.Tensor) -> torch.Tensor:
b, c, h, w = x.shape
p = self.patch_size
h_, w_ = h // p, w // p

x = x.reshape(b, c, h_, p, w_, p)
x = torch.einsum("bchpwq->bcpqhw", x)
x = x.reshape(b, c * p ** 2, h_, w_)
return x

def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
b, _, h_, w_ = x.shape
p = self.patch_size
c = x.shape[1] // (p ** 2)

x = x.reshape(b, c, p, p, h_, w_)
x = torch.einsum("bcpqhw->bchpwq", x)
x = x.reshape(b, c, h_ * p, w_ * p)
return x

def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)

# src: https://github.com/stepfun-ai/NextStep-1/blob/main/nextstep/models/modeling_flux_vae.py
mean, logvar = torch.chunk(z, 2, dim=1)
if self.patch_size is not None:
mean = self.patchify(mean)
if self.normalize_latents:
mean = mean.permute(0, 2, 3, 1)
mean = F.layer_norm(mean, mean.shape[-1:], eps=1e-6)
mean = mean.permute(0, 3, 1, 2)
if self.patch_size is not None:
mean = self.unpatchify(mean)
z = torch.cat([mean, logvar], dim=1).contiguous()

posterior = DiagonalGaussianDistribution(z, deterministic=self.deterministic)
if self.sample:
z = posterior.sample()
else:
Expand Down