diff --git a/README.md b/README.md
index 35463e4..62075f1 100644
--- a/README.md
+++ b/README.md
@@ -1,114 +1,39 @@
 # AutoencoderKL
 
-## About The Project
+This repository is a fork of [lavinal712/AutoencoderKL](https://github.com/lavinal712/AutoencoderKL) with modifications for the NextStep-1 VAE.
 
-There are many great training scripts for VAE on Github. However, some repositories are not maintained and some are not updated to the latest version of PyTorch. Therefore, I decided to create this repository to provide a simple and easy-to-use training script for VAE by Lightning. Beside, the code is easy to transfer to other projects for time saving.
+## Sigma-VAE
 
-- Support training and finetuning both [Stable Diffusion](https://github.com/CompVis/stable-diffusion) VAE and [FLUX](https://github.com/black-forest-labs/flux) VAE.
-- Support evaluating reconstruction quality (FID, PSNR, SSIM, LPIPS).
-- A practical guidance of training VAE.
-- Easy to modify the code for your own research.
+Sigma-VAE was proposed in [Multimodal Latent Language Modeling with Next-Token Diffusion](https://arxiv.org/abs/2412.08635) to prevent variance collapse by enforcing a fixed variance in the latent space. The reconstruction pass is computed as:
 
-## Visualization
-
-This is the visualization of AutoencoderKL. From left to right, there are the original image, the reconstructed image and the difference between them. From top to bottom, there are the results of SD VAE, SDXL VAE and FLUX VAE.
-
-Image source: [https://www.bilibili.com/opus/762402574076739817](https://www.bilibili.com/opus/762402574076739817)
-
-![baka](assets/visualization.png)
-
-
-## Getting Started
-
-To get a local copy up and running follow these simple example steps.
-
-### Installation
-
-```bash
-git clone https://github.com/lavinal712/AutoencoderKL.git
-cd AutoencoderKL
-conda create -n autoencoderkl python=3.10 -y
-conda activate autoencoderkl
-pip install -r requirements.txt
+```math
+\begin{aligned}
+\mu &= \text{Encoder}_\phi(x) \\
+z &= \mu + \sigma \odot \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}(0, I) \text{ and } \sigma = C_\sigma \text{ is a fixed constant} \\
+\hat{x} &= \text{Decoder}_\psi(z)
+\end{aligned}
 ```
 
-### Training
-
-To start training, you need to prepare a config file. You can refer to the config files in the `configs` folder.
-
-If you want to train on your own dataset, you should write your own data loader in `sgm/data` and modify the parameters in the config file.
-
-Finetuning a VAE model is simple. You just need to specify the `ckpt_path` and `trainable_ae_params` in the config file. To keep the latent space of the original model, it is recommended to set decoder to be trainable.
+In [NextStep-1: Toward Autoregressive Image Generation with Continuous Tokens at Scale](https://arxiv.org/abs/2508.10711), the authors find that **a regularized latent space is critical for generation**. Specifically, applying a higher noise intensity during tokenizer training increases the generation loss but paradoxically improves the quality of the generated images. They attribute this to noise regularization cultivating a well-conditioned latent space: it makes the tokenizer decoder more robust to latent perturbations and yields a more dispersed latent distribution, a property prior work has also found beneficial for generation.
 
-Then, you can start training by running the following command.
-
-```bash
-NUM_GPUS=4
-NUM_NODES=1
-
-torchrun --nproc_per_node=${NUM_GPUS} --nnodes=${NUM_NODES} main.py \
-    --base configs/autoencoder_kl_32x32x4.yaml \
-    --train \
-    --logdir logs/autoencoder_kl_32x32x4 \
-    --scale_lr True \
-    --wandb False \
-```
+![distributions](assets/distributions.png)
 
-### Evaluation
-
-We provide a script to evaluate the reconstruction quality of the trained model. `--resume` provides a convenient way to load the checkpoint from the log directory.
-
-We introduce multi-GPU and multi-thread method for faster evaluation.
-
-The default dataset is ImageNet. You can change the dataset by modifying the `--datadir` in the command line and the evaluation script.
-
-```bash
-NUM_GPUS=4
-NUM_NODES=1
-
-torchrun --nproc_per_node=${NUM_GPUS} --nnodes=${NUM_NODES} eval.py \
-    --resume logs/autoencoder_kl_32x32x4 \
-    --base configs/autoencoder_kl_32x32x4.yaml \
-    --logdir eval/autoencoder_kl_32x32x4 \
-    --datadir /path/to/ImageNet \
-    --image_size 256 \
-    --batch_size 16 \
-    --num_workers 16 \
-```
-
-Here are the evaluation results on ImageNet.
-
-| Model         | rFID  | PSNR   | SSIM  | LPIPS |
-| ------------- | ----- | ------ | ----- | ----- |
-| sd-vae-ft-mse | 0.692 | 26.910 | 0.772 | 0.130 |
-| sdxl-vae      | 0.665 | 27.376 | 0.794 | 0.122 |
-| flux-vae      | 0.165 | 32.871 | 0.924 | 0.045 |
-
-### Converting to diffusers
-
-[huggingface/diffusers](https://github.com/huggingface/diffusers) is a library for diffusion models. It provides a script [convert_vae_pt_to_diffusers.py
-](https://github.com/huggingface/diffusers/blob/main/scripts/convert_vae_pt_to_diffusers.py) to convert a PyTorch Lightning model to a diffusers model.
-
-Currently, the script is not updated for all kinds of VAE models, just for SD VAE.
+## Visualization
 
-```bash
-python convert_vae_pt_to_diffusers.py \
-    --vae_path logs/autoencoder_kl_32x32x4/checkpoints/last.ckpt \
-    --dump_path autoencoder_kl_32x32x4 \
-```
+Below is a visualization of AutoencoderKL. From left to right: the original image, the reconstructed image, and a visualization of the latent space. From top to bottom: the results of SDXL VAE, FLUX VAE, and NextStep-1 VAE.
 
-## Guidance
+The SDXL latent space contains a lot of noise, while the noise is significantly reduced in FLUX and NextStep-1.
 
-Here are some guidance for training VAE. If there are any mistakes, please let me know.
+![baka](assets/pca.png)
 
-- Learning rate: In LDM repository [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion), the base learning rate is set to 4.5e-6 in the config file. However, the batch size is 12, accumulated gradient is 2 and `scale_lr` is set to `True`. Therefore, the effective learning rate is 4.5e-6 * 12 * 2 * 1 = 1.08e-4. It is better to set the learning rate from 1.0e-4 to 1.0e-5. In finetuning stage, it can be smaller than the first stage.
-  - `scale_lr`: It is better to set `scale_lr` to `False` when training on a large dataset.
-- Discriminator: You should open the discriminator in the end of the training, when the VAE has good reconstruction performance. In default, `disc_start` is set to 50001.
-- Perceptual loss: LPIPS is a good metric for evaluating the quality of the reconstructed images. Some models use other perceptual loss functions to gain better performance.
+| Model          | rFID  | PSNR   | SSIM  | LPIPS |
+| -------------- | ----- | ------ | ----- | ----- |
+| sdxl-vae       | 0.665 | 27.376 | 0.794 | 0.122 |
+| flux-vae       | 0.165 | 32.871 | 0.924 | 0.045 |
+| NextStep-1 VAE | 1.201 | 30.160 | 0.883 | 0.061 |
 
 ## Acknowledgments
 
-Thanks for the following repositories. Without their code, this project would not be possible.
+Thanks to the NextStep-1 team for open-sourcing the code and model weights.
 
-- [Stability-AI/generative-models](https://github.com/Stability-AI/generative-models). We heavily borrow the code from this repository, just modifing a few parameters for our concept.
-- [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion). We follow the hyperparameter settings of this repository in config files.
+- [stepfun-ai/NextStep-1](https://github.com/stepfun-ai/NextStep-1).
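To make the noise regularization described in the new README concrete, here is a minimal, self-contained PyTorch sketch. It is not part of the patch; the helper name and tensor shapes are illustrative. It mirrors what `AutoencoderKLSigma.forward` does later in this diff: draw one noise scale per sample from U(0, `noise_strength`) and add scaled Gaussian noise to the latents before decoding.

```python
import torch


def regularize_latents(z: torch.Tensor, noise_strength: float = 0.5) -> torch.Tensor:
    """Add NextStep-1 style noise regularization to a batch of latents.

    A per-sample noise scale is drawn from U(0, noise_strength) and used to
    scale Gaussian noise added to the latents (illustrative sketch).
    """
    if noise_strength <= 0.0:
        return z
    # One noise scale per sample, broadcast over (C, H, W).
    scale = torch.rand(z.shape[0], 1, 1, 1, device=z.device) * noise_strength
    return z + scale * torch.randn_like(z)


# Toy latent batch: 4 images, 16 channels, 32x32 latent grid (f8 on 256px inputs).
z = torch.randn(4, 16, 32, 32)
z_noisy = regularize_latents(z, noise_strength=0.5)
print(z_noisy.shape)  # torch.Size([4, 16, 32, 32])
```

Setting the scale to 0.0, as the config below suggests for evaluation, disables the perturbation entirely.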
diff --git a/assets/distributions.png b/assets/distributions.png
new file mode 100644
index 0000000..914d3ae
Binary files /dev/null and b/assets/distributions.png differ
diff --git a/assets/pca.png b/assets/pca.png
new file mode 100644
index 0000000..d39a7c4
Binary files /dev/null and b/assets/pca.png differ
diff --git a/configs/nextstep-1-f8ch16.yaml b/configs/nextstep-1-f8ch16.yaml
new file mode 100644
index 0000000..fc897da
--- /dev/null
+++ b/configs/nextstep-1-f8ch16.yaml
@@ -0,0 +1,97 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.autoencoder.AutoencoderKLSigma
+  params:
+    input_key: jpg
+    monitor: "val/loss/rec"
+    disc_start_iter: 50001
+    ckpt_path: stepfun-ai/NextStep-1-f8ch16-Tokenizer/checkpoint.pt
+    trainable_ae_params:
+      - ["decoder"]
+    noise_strength: 0.5  # 0.0 for evaluation
+
+    encoder_config:
+      target: sgm.modules.diffusionmodules.model.Encoder
+      params:
+        attn_type: vanilla-xformers
+        double_z: true
+        z_channels: 16
+        resolution: 256
+        in_channels: 3
+        out_ch: 3
+        ch: 128
+        ch_mult: [1, 2, 4, 4]
+        num_res_blocks: 2
+        attn_resolutions: []
+        dropout: 0.0
+
+    decoder_config:
+      target: sgm.modules.diffusionmodules.model.Decoder
+      params: ${model.params.encoder_config.params}
+
+    regularizer_config:
+      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+      params:
+        sample: true
+        deterministic: true
+        normalize_latents: true
+        patch_size: 2
+
+    loss_config:
+      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
+      params:
+        perceptual_weight: 0.25
+        disc_start: 50001
+        disc_weight: 0.0
+        learn_logvar: false
+        pixel_loss: "l2"
+        regularization_weights:
+          kl_loss: 1.0
+
+data:
+  target: sgm.data.imagenet.ImageNetLoader
+  params:
+    batch_size: 16
+    num_workers: 4
+    prefetch_factor: 2
+    shuffle: true
+
+    train:
+      root_dir: /path/to/ImageNet
+      size: 256
+      transform: true
+    validation:
+      root_dir: /path/to/ImageNet
+      size: 256
+      transform: true
+
+lightning:
+  strategy:
+    target: pytorch_lightning.strategies.DDPStrategy
+    params:
+      find_unused_parameters: True
+
+  modelcheckpoint:
+    params:
+      every_n_epochs: 1
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 50000
+
+    image_logger:
+      target: main.ImageLogger
+      params:
+        enable_autocast: False
+        batch_frequency: 1000
+        max_images: 8
+        increase_log_steps: True
+
+  trainer:
+    precision: bf16
+    devices: 0, 1, 2, 3
+    limit_val_batches: 50
+    benchmark: True
+    accumulate_grad_batches: 1
+    check_val_every_n_epoch: 1
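The `regularizer_config` above turns on `deterministic`, `normalize_latents`, and `patch_size: 2`. The following standalone sketch (again illustrative, not part of the patch) shows what that normalization amounts to: the latent mean is patchified, layer-normalized over the patch-channel dimension, and unpatchified, mirroring the `DiagonalGaussianRegularizer` changes further down in this diff.

```python
import torch
import torch.nn.functional as F


def patchify(x: torch.Tensor, p: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, C*p*p, H/p, W/p): fold each p x p patch into channels.
    b, c, h, w = x.shape
    h_, w_ = h // p, w // p
    x = x.reshape(b, c, h_, p, w_, p)
    x = torch.einsum("bchpwq->bcpqhw", x)
    return x.reshape(b, c * p ** 2, h_, w_)


def unpatchify(x: torch.Tensor, p: int) -> torch.Tensor:
    # Inverse of patchify: (B, C*p*p, H/p, W/p) -> (B, C, H, W).
    b, _, h_, w_ = x.shape
    c = x.shape[1] // (p ** 2)
    x = x.reshape(b, c, p, p, h_, w_)
    x = torch.einsum("bcpqhw->bchpwq", x)
    return x.reshape(b, c, h_ * p, w_ * p)


def normalize_mean(mean: torch.Tensor, p: int = 2) -> torch.Tensor:
    # Layer-norm the patchified mean over its channel dimension, then undo the patching.
    mean = patchify(mean, p)
    mean = mean.permute(0, 2, 3, 1)
    mean = F.layer_norm(mean, mean.shape[-1:], eps=1e-6)
    mean = mean.permute(0, 3, 1, 2)
    return unpatchify(mean, p)


mean = torch.randn(2, 16, 32, 32)  # toy latent means
out = normalize_mean(mean, p=2)
print(out.shape)  # torch.Size([2, 16, 32, 32])
```

Because the normalization is applied per 2x2 patch group of channels and then undone, the operation preserves the latent shape, so the rest of the pipeline is unaffected.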
diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py
index 4114c22..b4bb1c4 100644
--- a/sgm/models/autoencoder.py
+++ b/sgm/models/autoencoder.py
@@ -54,6 +54,8 @@ def apply_ckpt(self, ckpt: Union[None, str, dict]):
             sd = torch.load(ckpt, map_location="cpu")["state_dict"]
         elif ckpt.endswith("safetensors"):
             sd = load_safetensors(ckpt)
+        elif ckpt.endswith("pt"):
+            sd = torch.load(ckpt, map_location="cpu")
         else:
             raise NotImplementedError
@@ -527,3 +529,19 @@ def __init__(self, **kwargs):
             },
             **kwargs,
         )
+
+
+class AutoencoderKLSigma(AutoencodingEngine):
+    def __init__(self, noise_strength: float = 0.0, **kwargs):
+        super().__init__(**kwargs)
+        self.noise_strength = noise_strength
+
+    def forward(
+        self, x: torch.Tensor, **additional_decode_kwargs
+    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+        z, reg_log = self.encode(x, return_reg_log=True)
+        if self.noise_strength > 0.0:
+            p = torch.distributions.Uniform(0, self.noise_strength)
+            z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * torch.randn_like(z)
+        dec = self.decode(z, **additional_decode_kwargs)
+        return z, dec, reg_log
diff --git a/sgm/modules/autoencoding/regularizers/__init__.py b/sgm/modules/autoencoding/regularizers/__init__.py
index ff2b181..051e1b2 100644
--- a/sgm/modules/autoencoding/regularizers/__init__.py
+++ b/sgm/modules/autoencoding/regularizers/__init__.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Any, Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -11,16 +11,58 @@
 class DiagonalGaussianRegularizer(AbstractRegularizer):
-    def __init__(self, sample: bool = True):
+    def __init__(
+        self,
+        sample: bool = True,
+        deterministic: bool = False,
+        normalize_latents: bool = False,
+        patch_size: Optional[int] = None,
+    ):
         super().__init__()
         self.sample = sample
+        self.deterministic = deterministic
+        self.normalize_latents = normalize_latents
+        self.patch_size = patch_size
 
     def get_trainable_parameters(self) -> Any:
         yield from ()
 
+    def patchify(self, x: torch.Tensor) -> torch.Tensor:
+        b, c, h, w = x.shape
+        p = self.patch_size
+        h_, w_ = h // p, w // p
+
+        x = x.reshape(b, c, h_, p, w_, p)
+        x = torch.einsum("bchpwq->bcpqhw", x)
+        x = x.reshape(b, c * p ** 2, h_, w_)
+        return x
+
+    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
+        b, _, h_, w_ = x.shape
+        p = self.patch_size
+        c = x.shape[1] // (p ** 2)
+
+        x = x.reshape(b, c, p, p, h_, w_)
+        x = torch.einsum("bcpqhw->bchpwq", x)
+        x = x.reshape(b, c, h_ * p, w_ * p)
+        return x
+
     def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
         log = dict()
-        posterior = DiagonalGaussianDistribution(z)
+
+        # src: https://github.com/stepfun-ai/NextStep-1/blob/main/nextstep/models/modeling_flux_vae.py
+        mean, logvar = torch.chunk(z, 2, dim=1)
+        if self.patch_size is not None:
+            mean = self.patchify(mean)
+        if self.normalize_latents:
+            mean = mean.permute(0, 2, 3, 1)
+            mean = F.layer_norm(mean, mean.shape[-1:], eps=1e-6)
+            mean = mean.permute(0, 3, 1, 2)
+        if self.patch_size is not None:
+            mean = self.unpatchify(mean)
+        z = torch.cat([mean, logvar], dim=1).contiguous()
+
+        posterior = DiagonalGaussianDistribution(z, deterministic=self.deterministic)
         if self.sample:
             z = posterior.sample()
         else: