diff --git a/README.md b/README.md index 42b33fe..a283dd5 100644 --- a/README.md +++ b/README.md @@ -1,98 +1,148 @@ -# AutoencoderKL +# Transfusion VAE -## About The Project +This repository is a branch of [lavinal712/AutoencoderKL](https://github.com/lavinal712/AutoencoderKL) with some modifications for the Transfusion VAE. -There are many great training scripts for VAE on Github. However, some repositories are not maintained and some are not updated to the latest version of PyTorch. Therefore, I decided to create this repository to provide a simple and easy-to-use training script for VAE by Lightning. Beside, the code is easy to transfer to other projects for time saving. +In the paper [Transfusion](https://arxiv.org/abs/2408.11039), the authors proposed a new VAE for image generation. Unlike the original Stable Diffusion VAE, the Transfusion VAE has an 8-channel latent space, and its training loss is: -- Support training and finetuning both [Stable Diffusion](https://github.com/CompVis/stable-diffusion) VAE and [Flux](https://github.com/black-forest-labs/flux) VAE. -- Support evaluating reconstruction quality (FID, PSNR, SSIM, LPIPS). -- A practical guidance of training VAE. -- Easy to modify the code for your own research. +```math
\mathcal{L}_{\text{VAE}} = \mathcal{L}_{1} + \mathcal{L}_{\text{LPIPS}} + 0.5 \mathcal{L}_{\text{GAN}} + 0.2 \mathcal{L}_{\text{ID}} + 0.000001 \mathcal{L}_{\text{KL}} +``` - -## Getting Started +where $`\mathcal{L}_{1}`$ is the L1 loss, $`\mathcal{L}_{\text{LPIPS}}`$ is the perceptual loss based on LPIPS similarity, $`\mathcal{L}_{\text{GAN}}`$ is a patch-based discriminator loss, $`\mathcal{L}_{\text{ID}}`$ is a perceptual loss based on a MoCo v2 model, and $`\mathcal{L}_{\text{KL}}`$ is the standard KL-regularization loss. + +## Visualization -To get a local copy up and running follow these simple example steps. +| Input | Reconstruction | +|--------------------------------------- |-----------------------------------------------------------| +| ![assets/inputs.png](assets/inputs.png) | ![assets/reconstructions.png](assets/reconstructions.png) | + +## Getting Started ### Installation -```bash -git clone https://github.com/lavinal712/AutoencoderKL.git +``` +git clone https://github.com/lavinal712/AutoencoderKL.git -b transfusion_vae cd AutoencoderKL conda create -n autoencoderkl python=3.10 -y conda activate autoencoderkl pip install -r requirements.txt ``` -### Training +### Model -To start training, you need to prepare a config file. You can refer to the config files in the `configs` folder. +You can load the pretrained model from [lavinal712/transfusion-vae](https://huggingface.co/lavinal712/transfusion-vae). -If you want to train on your own dataset, you should write your own data loader in `sgm/data` and modify the parameters in the config file. +```python
from diffusers import AutoencoderKL -Finetuning a VAE model is simple. You just need to specify the `ckpt_path` and `trainable_ae_params` in the config file. To keep the latent space of the original model, it is recommended to set decoder to be trainable. +vae = AutoencoderKL.from_pretrained("lavinal712/transfusion-vae") +``` -Then, you can start training by running the following command. +### Data -```bash -NUM_GPUS=4 -NUM_NODES=1 +We use a combined dataset of ImageNet, COCO, and FFHQ for training. 
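+For reference, below is a minimal sketch of how the combined loader added in this branch (`sgm/data/imagenet_coco_ffhq.py`) can be instantiated directly; during training it is normally built by `main.py` from the `data:` block of `configs/transfusion_vae_32x32x8.yaml`, and the dataset paths here are placeholders:
+
+```python
+from omegaconf import OmegaConf
+
+from sgm.data.imagenet_coco_ffhq import ImageNetCOCOFFHQLoader
+
+# Per-split settings, mirroring the data.params.train block of the config.
+split_cfg = OmegaConf.create({
+    "imagenet_dir": "/path/to/ImageNet",
+    "coco_dir": "/path/to/MSCOCO",
+    "ffhq_dir": "/path/to/FFHQ",
+    "size": 256,
+    "transform": "random_crop",
+})
+
+loader = ImageNetCOCOFFHQLoader(
+    batch_size=12,
+    train=split_cfg,
+    validation=split_cfg,
+    num_workers=4,
+    shuffle=True,
+)
+batch = next(iter(loader.train_dataloader()))
+print(batch["jpg"].shape)  # torch.Size([12, 3, 256, 256]), images normalized to [-1, 1]
+```
+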
Here is the structure of data: -torchrun --nproc_per_node=${NUM_GPUS} --nnodes=${NUM_NODES} main.py \ - --base configs/autoencoder_kl_32x32x4.yaml \ - --train \ - --logdir logs/autoencoder_kl_32x32x4 \ - --scale_lr True \ - --wandb False \ +``` +ImageNet/ +├── train/ +│ ├── n01440764/ +│ │ ├── n01440764_18.JPEG +│ │ ├── n01440764_36.JPEG +│ │ └── ... +│ ├── n01443537/ +│ │ ├── n01443537_2.JPEG +│ │ ├── n01443537_16.JPEG +│ │ └── ... +│ ├── ... +├── val/ +│ ├── n01440764/ +│ │ ├── ILSVRC2012_val_00000293.JPEG +│ │ ├── ILSVRC2012_val_00002138.JPEG +│ │ └── ... +│ ├── n01443537/ +│ │ ├── ILSVRC2012_val_00000236.JPEG +│ │ ├── ILSVRC2012_val_00000262.JPEG +│ │ └── ... +│ ├── ... +└── ... ``` -### Evaluation +``` +COCO/ +├── annotations/ +│ ├── captions_train2017.json +│ ├── captions_val2017.json +│ ├── ... +├── test2017/ +│ ├── 000000000001.jpg +│ ├── 000000000016.jpg +│ └── ... +├── train2017/ +│ ├── 000000000009.jpg +│ ├── 000000000025.jpg +│ └── ... +├── val2017/ +│ ├── 000000000139.jpg +│ ├── 000000000285.jpg +│ └── ... +└── ... +``` -We provide a script to evaluate the reconstruction quality of the trained model. `--resume` provides a convenient way to load the checkpoint from the log directory. +``` +FFHQ/ +├── images1024x1024/ +│ ├── 00000/ +│ │ ├── 00000.png +│ │ ├── 00001.png +│ │ └── ... +│ ├── 01000/ +│ │ ├── 01000.png +│ │ ├── 01001.png +│ │ └── ... +│ ├── ... +│ ├── 69000/ +│ │ ├── 69000.png +│ │ ├── 69001.png +│ │ └── ... +│ └── LICENSE.txt +├── ffhq-dataset-v2.json +└── ... +``` -We introduce multi-GPU and multi-thread method for faster evaluation. +### Training -The default dataset is ImageNet. You can change the dataset by modifying the `--datadir` in the command line and the evaluation script. +We train the Transfusion VAE on combined dataset of ImageNet, COCO and FFHQ for 50 epochs. ```bash -NUM_GPUS=4 -NUM_NODES=1 - -torchrun --nproc_per_node=${NUM_GPUS} --nnodes=${NUM_NODES} eval.py \ - --resume logs/autoencoder_kl_32x32x4 \ - --base configs/autoencoder_kl_32x32x4.yaml \ - --logdir eval/autoencoder_kl_32x32x4 \ - --datadir /path/to/ImageNet \ - --image_size 256 \ - --batch_size 16 \ - --num_workers 16 \ +torchrun --nproc_per_node=4 --nnodes=1 main.py \ + --base configs/transfusion_vae_32x32x8.yaml \ + --train \ + --scale_lr False \ + --wandb True \ ``` -### Converting to diffusers - -[huggingface/diffusers](https://github.com/huggingface/diffusers) is a library for diffusion models. It provides a script [convert_vae_pt_to_diffusers.py -](https://github.com/huggingface/diffusers/blob/main/scripts/convert_vae_pt_to_diffusers.py) to convert a PyTorch Lightning model to a diffusers model. +### Evaluation -Currently, the script is not updated for all kinds of VAE models, just for SD VAE. +We evaluate the Transfusion VAE on ImageNet and COCO. -```bash -python convert_vae_pt_to_diffusers.py \ - --vae_path logs/autoencoder_kl_32x32x4/checkpoints/last.ckpt \ - --dump_path autoencoder_kl_32x32x4 \ -``` +ImageNet 2012 (256x256, val, 50000 images) -## Guidance +| Model | rFID | PSNR | SSIM | LPIPS | +|-----------------|-------|--------|-------|-------| +| Transfusion-VAE | 0.408 | 28.723 | 0.845 | 0.081 | +| SD-VAE | 0.692 | 26.910 | 0.772 | 0.130 | -Here are some guidance for training VAE. If there are any mistakes, please let me know. +COCO 2017 (256x256, val, 5000 images) -- Learning rate: In LDM repository [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion), the base learning rate is set to 4.5e-6 in the config file. 
However, the batch size is 12, accumulated gradient is 2 and `scale_lr` is set to `True`. Therefore, the effective learning rate is 4.5e-6 * 12 * 2 * 1 = 1.08e-4. It is better to set the learning rate from 1.0e-4 to 1.0e-5. In finetuning stage, it can be smaller than the first stage. - - `scale_lr`: It is better to set `scale_lr` to `False` when training on a large dataset. -- Discriminator: You should open the discriminator in the end of the training, when the VAE has good reconstruction performance. In default, `disc_start` is set to 50001. -- Perceptual loss: LPIPS is a good metric for evaluating the quality of the reconstructed images. Some models use other perceptual loss functions to gain better performance, such as [sypsyp97/convnext_perceptual_loss](https://github.com/sypsyp97/convnext_perceptual_loss). +| Model | rFID | PSNR | SSIM | LPIPS | +|-----------------|-------|--------|-------|-------| +| Transfusion-VAE | 2.749 | 28.556 | 0.855 | 0.078 | +| SD-VAE | 4.246 | 26.622 | 0.784 | 0.127 | -## Acknowledgments +## Acknowledgements -Thanks for the following repositories. Without their code, this project would not be possible. +Thanks to the following repositories for providing the code and models of Moco v2 and Moco v3, and a repository for the inspiration of the perceptual loss. -- [Stability-AI/generative-models](https://github.com/Stability-AI/generative-models). We heavily borrow the code from this repository, just modifing a few parameters for our concept. -- [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion). We follow the hyperparameter settings of this repository in config files. +- [facebookresearch/moco](https://github.com/facebookresearch/moco) +- [facebookresearch/moco-v3](https://github.com/facebookresearch/moco-v3) +- [sypsyp97/convnext_perceptual_loss](https://github.com/sypsyp97/convnext_perceptual_loss) diff --git a/assets/inputs.png b/assets/inputs.png new file mode 100644 index 0000000..096bd97 Binary files /dev/null and b/assets/inputs.png differ diff --git a/assets/reconstructions.png b/assets/reconstructions.png new file mode 100644 index 0000000..2742975 Binary files /dev/null and b/assets/reconstructions.png differ diff --git a/configs/convert_vae.yaml b/configs/convert_vae.yaml new file mode 100644 index 0000000..447f7e4 --- /dev/null +++ b/configs/convert_vae.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 8 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 8 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/configs/transfusion_vae_32x32x8.yaml b/configs/transfusion_vae_32x32x8.yaml new file mode 100644 index 0000000..a51bf57 --- /dev/null +++ b/configs/transfusion_vae_32x32x8.yaml @@ -0,0 +1,97 @@ +model: + base_learning_rate: 1.0e-4 + target: sgm.models.autoencoder.AutoencoderKL + params: + input_key: jpg + monitor: "val/loss/rec" + disc_start_iter: 50001 + embed_dim: 8 + + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 8 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + + lossconfig: + target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator + params: + perceptual_weight: 1.0 + disc_start: 50001 + disc_weight: 0.5 + learn_logvar: false + + perceptual_name: mocov2_800ep + perceptual_config: + target: sgm.modules.autoencoding.moco.model.moco.MoCo + params: + arch: resnet50 + dim: 128 + K: 65536 + m: 0.999 + T: 0.2 + mlp: true + perceptual_weight_2: 0.2 + + regularization_weights: + kl_loss: 1.0e-6 + +data: + target: sgm.data.imagenet_coco_ffhq.ImageNetCOCOFFHQLoader + params: + batch_size: 12 + num_workers: 4 + prefetch_factor: 2 + shuffle: true + + train: + imagenet_dir: /path/to/ImageNet + coco_dir: /path/to/MSCOCO + ffhq_dir: /path/to/FFHQ + size: 256 + transform: random_crop + validation: + imagenet_dir: /path/to/ImageNet + coco_dir: /path/to/MSCOCO + ffhq_dir: /path/to/FFHQ + size: 256 + transform: random_crop + +lightning: + strategy: + target: pytorch_lightning.strategies.DDPStrategy + params: + find_unused_parameters: True + + modelcheckpoint: + params: + every_n_epochs: 1 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 50000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + precision: bf16 + devices: 0, 1, 2, 3 + limit_val_batches: 50 + benchmark: True + accumulate_grad_batches: 1 + check_val_every_n_epoch: 1 + max_epochs: 50 diff --git a/convert_vae_pt_to_diffusers.py b/convert_vae_pt_to_diffusers.py new file mode 100644 index 0000000..ceb8cf6 --- /dev/null +++ b/convert_vae_pt_to_diffusers.py @@ -0,0 +1,160 @@ +import argparse +import io + +import requests +import torch +import yaml + +from diffusers import AutoencoderKL +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + assign_to_checkpoint, + conv_attn_to_linear, + create_vae_diffusers_config, + renew_vae_attention_paths, + renew_vae_resnet_paths, +) + + +def custom_convert_ldm_vae_checkpoint(checkpoint, config): + vae_state_dict = checkpoint + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = 
vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for 
key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def vae_pt_to_vae_diffuser( + checkpoint_path: str, + output_path: str, +): + # Only support V1 + r = requests.get( + " https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + ) + io_obj = io.BytesIO(r.content) + + original_config = yaml.safe_load(io_obj) + original_config = yaml.safe_load(open("configs/convert_vae.yaml")) + image_size = 256 + device = "cuda" if torch.cuda.is_available() else "cpu" + if checkpoint_path.endswith("safetensors"): + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"] + + # Convert the VAE model. 
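+    # Notes on the two helpers used below: create_vae_diffusers_config() reads the
+    # first_stage_config ddconfig from configs/convert_vae.yaml (z_channels=8 for this branch)
+    # and turns it into AutoencoderKL constructor kwargs, while custom_convert_ldm_vae_checkpoint()
+    # renames the LDM-style state_dict keys (encoder.down.*, decoder.up.*, mid.*) to the
+    # diffusers layout (down_blocks.*, up_blocks.*, mid_block.*).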
+ vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + vae.save_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.") + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.") + + args = parser.parse_args() + + vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path) diff --git a/main.py b/main.py index 79a1fd1..19d95d2 100644 --- a/main.py +++ b/main.py @@ -132,7 +132,7 @@ def str2bool(v): parser.add_argument( "--projectname", type=str, - default="autoencoderkl", + default="transfusion_vae", ) parser.add_argument( "-l", diff --git a/requirements_transfusion_vae.txt b/requirements_transfusion_vae.txt new file mode 100644 index 0000000..e1cd83f --- /dev/null +++ b/requirements_transfusion_vae.txt @@ -0,0 +1,239 @@ +absl-py==2.1.0 +accelerate==0.34.2 +addict==2.4.0 +aiofiles==23.2.1 +aiohappyeyeballs==2.4.2 +aiohttp==3.10.6 +aiosignal==1.3.1 +albucore==0.0.20 +albumentations==0.4.3 +altair==5.4.1 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.6.0 +asttokens==2.4.1 +async-timeout==4.0.3 +attrs==24.2.0 +av==13.0.0 +backcall==0.2.0 +basicsr==1.4.2 +beartype==0.19.0 +beautifulsoup4==4.12.3 +bencode.py==2.0.0 +bitstring==3.1.5 +black==23.7.0 +bleach==6.2.0 +blinker==1.8.2 +braceexpand==0.1.7 +cachetools==5.5.0 +certifi==2024.8.30 +chardet==5.1.0 +charset-normalizer==3.3.2 +clean-fid==0.1.35 +click==8.1.7 +clip @ git+https://github.com/openai/CLIP.git +cmake==3.30.3 +coloredlogs==15.0.1 +contourpy==1.3.0 +cycler==0.12.1 +datasets==3.1.0 +decorator==4.4.2 +decord==0.6.0 +deepspeed==0.15.1 +defusedxml==0.7.1 +diffusers==0.31.0 +dill==0.3.8 +docker-pycreds==0.4.0 +docopt==0.6.2 +einops==0.8.0 +einx==0.3.0 +entrypoints==0.4 +eval_type_backport==0.2.0 +exceptiongroup==1.2.2 +executing==2.1.0 +fairscale==0.4.13 +fastapi==0.115.0 +fastjsonschema==2.21.0 +ffmpy==0.4.0 +filelock==3.16.1 +fire==0.6.0 +flatbuffers==24.3.25 +fonttools==4.54.1 +frozendict==2.4.4 +frozenlist==1.4.1 +fsspec==2024.9.0 +ftfy==6.2.3 +future==1.0.0 +gitdb==4.0.11 +GitPython==3.1.43 +grpcio==1.66.1 +h11==0.14.0 +hjson==3.1.0 +httpcore==1.0.5 +httpx==0.27.2 +huggingface-hub==0.25.1 +humanfriendly==10.0 +idna==3.10 +imageio==2.6.0 +imageio-ffmpeg==0.5.1 +imgaug==0.2.6 +importlib_metadata==8.5.0 +importlib_resources==6.4.5 +invisible-watermark==0.2.0 +ipython==8.12.3 +jedi==0.19.1 +Jinja2==3.1.4 +jsonschema==4.23.0 +jsonschema-specifications==2023.12.1 +jupyter_client==8.6.3 +jupyter_core==5.7.2 +jupyterlab_pygments==0.3.0 +kiwisolver==1.4.7 +kornia==0.8.0 +kornia_rs==0.1.8 +lazy_loader==0.4 +lightning-utilities==0.11.7 +lit==18.1.8 +llvmlite==0.43.0 +lmdb==1.5.1 +loralib==0.1.2 +lpips==0.1.4 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.7.0 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +mistune==3.0.2 +moviepy==1.0.3 +mpmath==1.3.0 +multidict==6.1.0 +multiprocess==0.70.16 +mypy-extensions==1.0.0 +narwhals==1.8.3 +natsort==8.4.0 +nbclient==0.10.0 +nbconvert==7.16.4 +nbformat==5.10.4 +networkx==3.3 +ninja==1.11.1.1 +numba==0.60.0 +numpy==1.26.4 +nvitop==1.3.2 +omegaconf==2.3.0 +onnxruntime==1.19.2 +open-clip-torch==2.24.0 +opencv-python==4.6.0.66 +opencv-python-headless==4.10.0.84 
+orjson==3.10.7 +packaging==24.2 +pandas==2.2.3 +pandocfilters==1.5.1 +parso==0.8.4 +pathspec==0.12.1 +pexpect==4.9.0 +pickleshare==0.7.5 +pillow==11.1.0 +pip==24.2 +pipreqs==0.5.0 +platformdirs==4.3.6 +pooch==1.8.2 +proglog==0.1.10 +prompt_toolkit==3.0.48 +protobuf==3.20.3 +psutil==6.0.0 +ptyprocess==0.7.0 +pudb==2024.1.3 +pure_eval==0.2.3 +py-cpuinfo==9.0.0 +pyarrow==17.0.0 +pydantic==2.9.2 +pydantic_core==2.23.4 +pydeck==0.9.1 +pyDeprecate==0.3.1 +pydub==0.25.1 +Pygments==2.18.0 +PyMatting==1.1.12 +pyparsing==3.1.4 +PyPubSub==3.3.0 +python-dateutil==2.9.0.post0 +python-multipart==0.0.10 +pytorch-fid==0.3.0 +pytorch-lightning==2.0.1 +pytz==2024.2 +PyWavelets==1.7.0 +PyYAML==6.0.2 +pyzmq==26.2.0 +referencing==0.35.1 +regex==2024.9.11 +rembg==2.0.59 +requests==2.32.3 +rich==13.8.1 +rpds-py==0.20.0 +ruff==0.6.8 +safetensors==0.5.2 +scikit-image==0.20.0 +scipy==1.14.1 +semantic-version==2.10.0 +sentencepiece==0.2.0 +sentry-sdk==2.14.0 +setproctitle==1.3.3 +setuptools==75.1.0 +shellingham==1.5.4 +simsimd==6.1.1 +six==1.16.0 +smmap==5.0.1 +sniffio==1.3.1 +soupsieve==2.6 +sqlparse==0.5.1 +stack-data==0.6.3 +starlette==0.38.6 +stringzilla==3.10.10 +sympy==1.13.1 +taming-transformers==0.0.1 +taming-transformers-rom1504==0.0.6 +tb-nightly==2.19.0a20240926 +tenacity==8.5.0 +tensorboard==2.18.0 +tensorboard-data-server==0.7.2 +tensorboardX==2.6 +termcolor==2.4.0 +test_tube==0.7.5 +tifffile==2024.9.20 +timm==1.0.14 +tinycss2==1.4.0 +tokenizers==0.19.1 +toml==0.10.2 +tomli==2.0.1 +tomlkit==0.12.0 +torch==2.5.1 +torchaudio==2.5.1 +torchmetrics==1.6.1 +torchvision==0.20.1 +tornado==6.4.1 +tqdm==4.67.1 +traitlets==5.14.3 +transformers==4.43.0 +triton==3.1.0 +typer==0.12.5 +typing_extensions==4.12.2 +tzdata==2024.2 +urllib3==1.26.20 +urwid==2.6.15 +urwid_readline==0.15.1 +uvicorn==0.30.6 +vector-quantize-pytorch==1.17.8 +wandb==0.19.6 +watchdog==4.0.2 +wcwidth==0.2.13 +webdataset==0.2.100 +webencodings==0.5.1 +websockets==11.0.3 +Werkzeug==3.0.4 +wheel==0.44.0 +xformers==0.0.29.post1 +xxhash==3.5.0 +yapf==0.40.2 +yarg==0.1.9 +yarl==1.13.0 +zipp==3.20.2 diff --git a/scale_factor.py b/scale_factor.py new file mode 100644 index 0000000..0b81da4 --- /dev/null +++ b/scale_factor.py @@ -0,0 +1,225 @@ +import argparse +import glob +import os +import subprocess +from concurrent.futures import ThreadPoolExecutor +from itertools import chain + +import numpy as np +import torch +import torch.distributed as dist +from natsort import natsorted +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from skimage.metrics import peak_signal_noise_ratio +from skimage.metrics import structural_similarity +from torchvision import transforms +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm + +from sgm.data.imagenet import ImageNetDataset +from sgm.util import instantiate_from_config +from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS + + +def get_parser(**parser_kwargs): + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument( + "-r", + "--resume", + type=str, + const=True, + default="", + nargs="?", + help="resume from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. 
" + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=list(), + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=0, + help="seed for initialization", + ) + parser.add_argument( + "-d", + "--datadir", + type=str, + default="data", + help="directory for testing data", + ) + parser.add_argument( + "-iz", + "--image_size", + type=int, + default=256, + help="image size for testing data", + ) + parser.add_argument( + "-bz", + "--batch_size", + type=int, + default=1, + help="batch size for sampling data", + ) + parser.add_argument( + "-nw", + "--num_workers", + type=int, + default=0, + help="number of workers for sampling data", + ) + if version.parse(torch.__version__) >= version.parse("2.0.0"): + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="single checkpoint file to resume from", + ) + return parser + + +def get_checkpoint_name(logdir): + ckpt = os.path.join(logdir, "checkpoints", "last**.ckpt") + ckpt = natsorted(glob.glob(ckpt)) + print('available "last" checkpoints:') + print(ckpt) + if len(ckpt) > 1: + print("got most recent checkpoint") + ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1] + print(f"Most recent ckpt is {ckpt}") + with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f: + f.write(ckpt + "\n") + try: + version = int(ckpt.split("/")[-1].split("-v")[-1].split(".")[0]) + except Exception as e: + print("version confusion but not bad") + print(e) + version = 1 + # version = last_version + 1 + else: + # in this case, we only have one "last.ckpt" + ckpt = ckpt[0] + version = 1 + melk_ckpt_name = f"last-v{version}.ckpt" + print(f"Current melk ckpt name: {melk_ckpt_name}") + return ckpt, melk_ckpt_name + + +if __name__ == "__main__": + parser = get_parser() + + opt, unknown = parser.parse_known_args() + + if not opt.resume and not opt.resume_from_checkpoint: + raise ValueError( + "-r/--resume or --resume_from_checkpoint must be specified." + ) + if opt.resume: + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + paths = opt.resume.split("/") + # idx = len(paths)-paths[::-1].index("logs")+1 + # logdir = "/".join(paths[:idx]) + logdir = "/".join(paths[:-2]) + ckpt = opt.resume + _, melk_ckpt_name = get_checkpoint_name(logdir) + else: + assert os.path.isdir(opt.resume), opt.resume + logdir = opt.resume.rstrip("/") + ckpt, melk_ckpt_name = get_checkpoint_name(logdir) + + print("#" * 100) + print(f'Resuming from checkpoint "{ckpt}"') + print("#" * 100) + + opt.resume_from_checkpoint = ckpt + base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) + opt.base = base_configs + opt.base + + # Setup PyTorch: + assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. 
sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + + # Setup DDP: + dist.init_process_group("nccl") + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = opt.seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + # model + model = instantiate_from_config(config.model) + model.apply_ckpt(opt.resume_from_checkpoint) + model.to(device) + model.eval() + + perceptual_model = LPIPS().eval() + perceptual_model.to(device) + + # data + transform = transforms.Compose([ + transforms.Resize(opt.image_size), + transforms.CenterCrop(opt.image_size), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + dataset = ImageNetDataset(opt.datadir, split="val", transform=transform) + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=False, + seed=opt.seed, + drop_last=False, + ) + loader = DataLoader( + dataset, + batch_size=opt.batch_size, + shuffle=False, + sampler=sampler, + num_workers=opt.num_workers, + pin_memory=True, + drop_last=False, + ) + print(f"Dataset contains {len(dataset):,} images ({opt.datadir})") + + z_list = [] + for batch in tqdm(loader): + x = batch["jpg"].to(device) + + with torch.no_grad(): + z = model.encode(x) + x_hat = model.decode(z) + + for i in range(x.shape[0]): + z_list.append(z[i].detach().cpu()) + + world_size = dist.get_world_size() + gather_z_list = [None for _ in range(world_size)] + dist.all_gather_object(gather_z_list, z_list) + + if rank == 0: + z_list = list(chain(*gather_z_list)) + z_list = torch.stack(z_list, dim=0) + print("scale factor: ", (1 / torch.std(z_list)).item()) + + dist.barrier() + dist.destroy_process_group() diff --git a/sgm/data/imagenet_coco_ffhq.py b/sgm/data/imagenet_coco_ffhq.py new file mode 100644 index 0000000..e5a841d --- /dev/null +++ b/sgm/data/imagenet_coco_ffhq.py @@ -0,0 +1,189 @@ +import json +import os +from typing import Optional + +import numpy as np +import pytorch_lightning as pl +import torch +from omegaconf import DictConfig +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +from torchvision.datasets import ImageFolder + + +class ImageNetCOCOFFHQDataset(Dataset): + def __init__( + self, + imagenet_dir, + coco_dir, + ffhq_dir, + split="train", + transform=None, + ): + self.imagenet_dir = imagenet_dir + self.coco_dir = coco_dir + self.ffhq_dir = ffhq_dir + self.split = split + self.transform = transform + self.imagenet_dataset = ImageFolder(os.path.join(imagenet_dir, split), transform=self.transform) + self.coco_dataset = self._get_coco_dataset(coco_dir, split) + self.ffhq_dataset = self._get_ffhq_dataset(ffhq_dir, split) + + def _get_coco_dataset(self, root_dir, split): + data_json = os.path.join(root_dir, "annotations", f"captions_{split}2017.json") + with open(data_json, "r") as json_file: + self.json_data = json.load(json_file) + self.img_id_to_filepath = dict() + self.img_id_to_captions = dict() + + imagedirs = self.json_data["images"] + self.labels = {"image_ids": list()} + for imgdir in imagedirs: + self.img_id_to_filepath[imgdir["id"]] = os.path.join( + root_dir, f"{split}2017", imgdir["file_name"] + ) + self.img_id_to_captions[imgdir["id"]] = list() + 
self.labels["image_ids"].append(imgdir["id"]) + + capdirs = self.json_data["annotations"] + for capdir in capdirs: + # there are in average 5 captions per image + self.img_id_to_captions[capdir["image_id"]].append(np.array(capdir["caption"])) + + dataset = [] + for img_id in self.labels["image_ids"]: + dataset.append((self.img_id_to_filepath[img_id], img_id)) + + return dataset + + def _get_ffhq_dataset(self, root_dir, split): + data_json = os.path.join(root_dir, "ffhq-dataset-v2.json") + with open(data_json, "r") as json_file: + self.metadata = json.load(json_file) + split_map = {"train": "training", "val": "validation"} + + dataset = [] + for i, data in self.metadata.items(): + category = data["category"] + if category == split_map[split]: + image_path = os.path.join(root_dir, data["image"]["file_path"]) + dataset.append((image_path, int(i))) + + return dataset + + def __len__(self): + return len(self.imagenet_dataset) + len(self.coco_dataset) + len(self.ffhq_dataset) + + def __getitem__(self, idx): + if idx < len(self.imagenet_dataset): + return {"jpg": self.imagenet_dataset[idx][0], "cls": self.imagenet_dataset[idx][1]} + elif idx < len(self.imagenet_dataset) + len(self.coco_dataset): + idx = idx - len(self.imagenet_dataset) + img_path = self.coco_dataset[idx][0] + image = Image.open(img_path) + if image.mode != "RGB": + image = image.convert("RGB") + image = self.transform(image) + return {"jpg": image, "cls": self.coco_dataset[idx][1]} + else: + idx = idx - len(self.imagenet_dataset) - len(self.coco_dataset) + img_path = self.ffhq_dataset[idx][0] + image = Image.open(img_path) + if image.mode != "RGB": + image = image.convert("RGB") + image = self.transform(image) + return {"jpg": image, "cls": self.ffhq_dataset[idx][1]} + + +class ImageNetCOCOFFHQLoader(pl.LightningDataModule): + def __init__( + self, + batch_size, + train: DictConfig = None, + validation: Optional[DictConfig] = None, + num_workers: int = 0, + prefetch_factor: int = 2, + shuffle: bool = False, + shuffle_test_loader: bool = False, + shuffle_val_dataloader: bool = False, + ): + super().__init__() + + self.batch_size = batch_size + self.num_workers = num_workers if num_workers is not None else batch_size * 2 + self.prefetch_factor = prefetch_factor + self.shuffle = shuffle + self.shuffle_test_loader = shuffle_test_loader + self.shuffle_val_dataloader = shuffle_val_dataloader + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] + ) + if train.get("transform", None): + size = train.get("size", 256) + if train.get("transform", None) == "center_crop": + transform = transforms.Compose([ + transforms.Resize(size), + transforms.CenterCrop(size), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + elif train.get("transform", None) == "random_crop": + transform = transforms.Compose([ + transforms.Resize(size), + transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + else: + raise ValueError(f"Invalid transform: {train.get('transform', None)}") + + self.train_dataset = ImageNetCOCOFFHQDataset( + imagenet_dir=train.imagenet_dir, + coco_dir=train.coco_dir, + ffhq_dir=train.ffhq_dir, + split="train", + transform=transform, + ) + if validation is not None: + self.test_dataset = ImageNetCOCOFFHQDataset( + imagenet_dir=validation.imagenet_dir, + coco_dir=validation.coco_dir, + ffhq_dir=validation.ffhq_dir, + 
split="val", + transform=transform, + ) + else: + print("Warning: No Validation Datasetdefined, using that one from training") + self.test_dataset = self.train_dataset + + def prepare_data(self): + pass + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle_test_loader, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ) + + def val_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle_val_dataloader, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ) diff --git a/sgm/modules/autoencoding/losses/discriminator_loss.py b/sgm/modules/autoencoding/losses/discriminator_loss.py index 5333c0d..036a4df 100644 --- a/sgm/modules/autoencoding/losses/discriminator_loss.py +++ b/sgm/modules/autoencoding/losses/discriminator_loss.py @@ -1,17 +1,29 @@ from typing import Dict, Iterator, List, Optional, Tuple, Union +import kornia import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F import torchvision +import torchvision.models as models from einops import rearrange from matplotlib import colormaps from matplotlib import pyplot as plt +from omegaconf import DictConfig, OmegaConf from ....util import default, instantiate_from_config from ..lpips.loss.lpips import LPIPS from ..lpips.model.model import weights_init from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss +from ..moco.model import vits +from ..moco.util import get_ckpt_path + + +ARCH_MAP = { + "mocov3_vits": "vit_small", + "mocov3_vitb": "vit_base", +} class GeneralLPIPSWithDiscriminator(nn.Module): @@ -24,6 +36,9 @@ def __init__( disc_factor: float = 1.0, disc_weight: float = 1.0, perceptual_weight: float = 1.0, + perceptual_name: Optional[str] = None, + perceptual_config: Optional[Dict] = None, + perceptual_weight_2: float = 1.0, pixel_loss: str = "l1", disc_loss: str = "hinge", scale_input_to_tgt_size: bool = False, @@ -45,14 +60,38 @@ def __init__( self.scale_input_to_tgt_size = scale_input_to_tgt_size assert pixel_loss in ["l1", "l2"] assert disc_loss in ["hinge", "vanilla"] - self.pixel_loss = lambda x, y: torch.abs(x - y) if pixel_loss == "l1" else torch.pow(x - y, 2) self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight + self.perceptual_name = default(perceptual_name, "mocov2_800ep") + perceptual_config = default( + perceptual_config, + { + "target": "sgm.modules.autoencoding.moco.model.moco.MoCo", + "params": { + "dim": 128, + "K": 65536, + "m": 0.999, + "T": 0.2, + "mlp": True, + }, + }, + ) + arch = perceptual_config["params"].pop("arch", "resnet50") + if isinstance(perceptual_config, DictConfig): + perceptual_config = OmegaConf.to_container(perceptual_config) + if arch == "resnet50": + perceptual_config["params"]["base_encoder"] = models.__dict__[arch] + else: + perceptual_config["params"]["base_encoder"] = vits.__dict__[ARCH_MAP[arch]] + self.perceptual_loss_2 = instantiate_from_config(perceptual_config).eval() + self.load_perceptual_from_pretrained(self.perceptual_name) + self.perceptual_weight_2 = perceptual_weight_2 # output log variance self.logvar = nn.Parameter( torch.full((), logvar_init), requires_grad=learn_logvar ) self.learn_logvar = learn_logvar + self.pixel_loss = lambda x, 
y: torch.abs(x - y) if pixel_loss == "l1" else torch.pow(x - y, 2) self.use_mean = use_mean discriminator_config = default( @@ -87,6 +126,16 @@ def __init__( self.additional_log_keys = set(default(additional_log_keys, [])) self.additional_log_keys.update(set(self.regularization_weights.keys())) + self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]), persistent=False) + self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]), persistent=False) + + def load_perceptual_from_pretrained(self, name="mocov2_800ep"): + ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/moco") + sd = torch.load(ckpt, map_location=torch.device("cpu"), weights_only=True)["state_dict"] + sd = {k.replace("module.", ""): v for k, v in sd.items() if k.startswith("module.")} + self.perceptual_loss_2.load_state_dict(sd, strict=False) + print("loaded pretrained MoCo loss from {}".format(ckpt)) + def get_trainable_parameters(self) -> Iterator[nn.Parameter]: return self.discriminator.parameters() @@ -209,6 +258,18 @@ def calculate_adaptive_weight( d_weight = d_weight * self.discriminator_weight return d_weight + def rescale(self, x: torch.Tensor) -> torch.Tensor: + x = kornia.geometry.resize( + x, + (224, 224), + interpolation="bicubic", + align_corners=True, + antialias=True, + ) + x = (x + 1.0) / 2.0 + x = kornia.enhance.normalize(x, mean=self.mean, std=self.std) + return x + def forward( self, inputs: torch.Tensor, @@ -238,8 +299,19 @@ def forward( inputs.contiguous(), reconstructions.contiguous() ) rec_loss = rec_loss + self.perceptual_weight * p_loss + if self.perceptual_weight_2 > 0: + rescale_inputs = self.rescale(inputs) + rescale_reconstructions = self.rescale(reconstructions) + if "vit" in self.perceptual_name: + rescale_inputs_perceptual_features = self.perceptual_loss_2.base_encoder(rescale_inputs) + rescale_reconstructions_perceptual_features = self.perceptual_loss_2.base_encoder(rescale_reconstructions) + else: + rescale_inputs_perceptual_features = self.perceptual_loss_2.encoder_q(rescale_inputs) + rescale_reconstructions_perceptual_features = self.perceptual_loss_2.encoder_q(rescale_reconstructions) + p_loss_2 = F.mse_loss(rescale_inputs_perceptual_features, rescale_reconstructions_perceptual_features) + rec_loss = rec_loss + self.perceptual_weight_2 * p_loss_2 - nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights, self.use_mean) + nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights, use_mean=self.use_mean) # now the GAN part if optimizer_idx == 0: @@ -265,6 +337,21 @@ def forward( if k in self.additional_log_keys: log[f"{split}/{k}"] = regularization_log[k].detach().float().mean() + # metrics + if not self.training: + metrics_inputs = (inputs.clamp(-1.0, 1.0) + 1.0) / 2.0 + metrics_reconstructions = (reconstructions.clamp(-1.0, 1.0) + 1.0) / 2.0 + psnr = kornia.metrics.psnr(metrics_inputs, metrics_reconstructions, max_val=1.0) + ssim = kornia.metrics.ssim(metrics_inputs, metrics_reconstructions, window_size=11, max_val=1.0) + lpips = p_loss if self.perceptual_weight > 0 else \ + self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + + log.update({ + f"{split}/metrics/psnr": psnr.detach().mean(), + f"{split}/metrics/ssim": ssim.detach().mean(), + f"{split}/metrics/lpips": lpips.detach().mean(), + }) + log.update( { f"{split}/loss/total": loss.clone().detach().mean(), diff --git a/sgm/modules/autoencoding/moco/.gitignore b/sgm/modules/autoencoding/moco/.gitignore new file mode 100644 index 0000000..7c40d2b --- /dev/null +++ 
b/sgm/modules/autoencoding/moco/.gitignore @@ -0,0 +1,5 @@ +moco_v1_200ep_pretrain.pth +moco_v2_200ep_pretrain.pth +moco_v2_800ep_pretrain.pth +vit-s-300ep.pth +vit-b-300ep.pth diff --git a/sgm/modules/autoencoding/moco/__init__.py b/sgm/modules/autoencoding/moco/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sgm/modules/autoencoding/moco/model/__init__.py b/sgm/modules/autoencoding/moco/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sgm/modules/autoencoding/moco/model/moco.py b/sgm/modules/autoencoding/moco/model/moco.py new file mode 100644 index 0000000..7952a98 --- /dev/null +++ b/sgm/modules/autoencoding/moco/model/moco.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +class MoCo(nn.Module): + """ + Build a MoCo model with: a query encoder, a key encoder, and a queue + https://arxiv.org/abs/1911.05722 + """ + + def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False): + """ + dim: feature dimension (default: 128) + K: queue size; number of negative keys (default: 65536) + m: moco momentum of updating key encoder (default: 0.999) + T: softmax temperature (default: 0.07) + """ + super(MoCo, self).__init__() + + self.K = K + self.m = m + self.T = T + + # create the encoders + # num_classes is the output fc dimension + self.encoder_q = base_encoder(num_classes=dim) + self.encoder_k = base_encoder(num_classes=dim) + + if mlp: # hack: brute-force replacement + dim_mlp = self.encoder_q.fc.weight.shape[1] + self.encoder_q.fc = nn.Sequential( + nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc + ) + self.encoder_k.fc = nn.Sequential( + nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc + ) + + for param_q, param_k in zip( + self.encoder_q.parameters(), self.encoder_k.parameters() + ): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + # create the queue + self.register_buffer("queue", torch.randn(dim, K)) + self.queue = nn.functional.normalize(self.queue, dim=0) + + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + + @torch.no_grad() + def _momentum_update_key_encoder(self): + """ + Momentum update of the key encoder + """ + for param_q, param_k in zip( + self.encoder_q.parameters(), self.encoder_k.parameters() + ): + param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) + + @torch.no_grad() + def _dequeue_and_enqueue(self, keys): + # gather keys before updating queue + keys = concat_all_gather(keys) + + batch_size = keys.shape[0] + + ptr = int(self.queue_ptr) + assert self.K % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue[:, ptr : ptr + batch_size] = keys.T + ptr = (ptr + batch_size) % self.K # move pointer + + self.queue_ptr[0] = ptr + + @torch.no_grad() + def _batch_shuffle_ddp(self, x): + """ + Batch shuffle, for making use of BatchNorm. + *** Only support DistributedDataParallel (DDP) model. 
*** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all).cuda() + + # broadcast to all gpus + torch.distributed.broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + + @torch.no_grad() + def _batch_unshuffle_ddp(self, x, idx_unshuffle): + """ + Undo batch shuffle. + *** Only support DistributedDataParallel (DDP) model. *** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this] + + def forward(self, im_q, im_k): + """ + Input: + im_q: a batch of query images + im_k: a batch of key images + Output: + logits, targets + """ + + # compute query features + q = self.encoder_q(im_q) # queries: NxC + q = nn.functional.normalize(q, dim=1) + + # compute key features + with torch.no_grad(): # no gradient to keys + self._momentum_update_key_encoder() # update the key encoder + + # shuffle for making use of BN + im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k) + + k = self.encoder_k(im_k) # keys: NxC + k = nn.functional.normalize(k, dim=1) + + # undo shuffle + k = self._batch_unshuffle_ddp(k, idx_unshuffle) + + # compute logits + # Einstein sum is more intuitive + # positive logits: Nx1 + l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) + # negative logits: NxK + l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()]) + + # logits: Nx(1+K) + logits = torch.cat([l_pos, l_neg], dim=1) + + # apply temperature + logits /= self.T + + # labels: positive key indicators + labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() + + # dequeue and enqueue + self._dequeue_and_enqueue(k) + + return logits, labels + + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output diff --git a/sgm/modules/autoencoding/moco/model/mocov3.py b/sgm/modules/autoencoding/moco/model/mocov3.py new file mode 100644 index 0000000..268bd83 --- /dev/null +++ b/sgm/modules/autoencoding/moco/model/mocov3.py @@ -0,0 +1,137 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
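+# Vendored from the MoCo v3 reference implementation (https://github.com/facebookresearch/moco-v3).
+# In this branch the frozen base encoder can act as the feature extractor for the ID perceptual
+# loss in GeneralLPIPSWithDiscriminator when a mocov3_* checkpoint is configured.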
+ +import torch +import torch.nn as nn + + +class MoCo(nn.Module): + """ + Build a MoCo model with a base encoder, a momentum encoder, and two MLPs + https://arxiv.org/abs/1911.05722 + """ + def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0): + """ + dim: feature dimension (default: 256) + mlp_dim: hidden dimension in MLPs (default: 4096) + T: softmax temperature (default: 1.0) + """ + super(MoCo, self).__init__() + + self.T = T + + # build encoders + self.base_encoder = base_encoder(num_classes=mlp_dim) + self.momentum_encoder = base_encoder(num_classes=mlp_dim) + + self._build_projector_and_predictor_mlps(dim, mlp_dim) + + for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): + param_m.data.copy_(param_b.data) # initialize + param_m.requires_grad = False # not update by gradient + + def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True): + mlp = [] + for l in range(num_layers): + dim1 = input_dim if l == 0 else mlp_dim + dim2 = output_dim if l == num_layers - 1 else mlp_dim + + mlp.append(nn.Linear(dim1, dim2, bias=False)) + + if l < num_layers - 1: + mlp.append(nn.BatchNorm1d(dim2)) + mlp.append(nn.ReLU(inplace=True)) + elif last_bn: + # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157 + # for simplicity, we further removed gamma in BN + mlp.append(nn.BatchNorm1d(dim2, affine=False)) + + return nn.Sequential(*mlp) + + def _build_projector_and_predictor_mlps(self, dim, mlp_dim): + pass + + @torch.no_grad() + def _update_momentum_encoder(self, m): + """Momentum update of the momentum encoder""" + for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): + param_m.data = param_m.data * m + param_b.data * (1. 
- m) + + def contrastive_loss(self, q, k): + # normalize + q = nn.functional.normalize(q, dim=1) + k = nn.functional.normalize(k, dim=1) + # gather all targets + k = concat_all_gather(k) + # Einstein sum is more intuitive + logits = torch.einsum('nc,mc->nm', [q, k]) / self.T + N = logits.shape[0] # batch size per GPU + labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda() + return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T) + + def forward(self, x1, x2, m): + """ + Input: + x1: first views of images + x2: second views of images + m: moco momentum + Output: + loss + """ + + # compute features + q1 = self.predictor(self.base_encoder(x1)) + q2 = self.predictor(self.base_encoder(x2)) + + with torch.no_grad(): # no gradient + self._update_momentum_encoder(m) # update the momentum encoder + + # compute momentum features as targets + k1 = self.momentum_encoder(x1) + k2 = self.momentum_encoder(x2) + + return self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1) + + +class MoCo_ResNet(MoCo): + def _build_projector_and_predictor_mlps(self, dim, mlp_dim): + hidden_dim = self.base_encoder.fc.weight.shape[1] + del self.base_encoder.fc, self.momentum_encoder.fc # remove original fc layer + + # projectors + self.base_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim) + self.momentum_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim) + + # predictor + self.predictor = self._build_mlp(2, dim, mlp_dim, dim, False) + + +class MoCo_ViT(MoCo): + def _build_projector_and_predictor_mlps(self, dim, mlp_dim): + hidden_dim = self.base_encoder.head.weight.shape[1] + del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer + + # projectors + self.base_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim) + self.momentum_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim) + + # predictor + self.predictor = self._build_mlp(2, dim, mlp_dim, dim) + + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output diff --git a/sgm/modules/autoencoding/moco/model/vits.py b/sgm/modules/autoencoding/moco/model/vits.py new file mode 100644 index 0000000..41d77ae --- /dev/null +++ b/sgm/modules/autoencoding/moco/model/vits.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
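+# ViT backbones (vit_small, vit_base and the ConvStem variants) vendored from
+# https://github.com/facebookresearch/moco-v3; discriminator_loss.py selects them via ARCH_MAP
+# when a mocov3_vits / mocov3_vitb perceptual model is requested.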
+ +import math +import torch +import torch.nn as nn +from functools import partial, reduce +from operator import mul + +from timm.layers.helpers import to_2tuple +from timm.models.vision_transformer import VisionTransformer, _cfg +from timm.models.layers import PatchEmbed + +__all__ = [ + 'vit_small', + 'vit_base', + 'vit_conv_small', + 'vit_conv_base', +] + + +class VisionTransformerMoCo(VisionTransformer): + def __init__(self, stop_grad_conv1=False, **kwargs): + super().__init__(**kwargs) + # Use fixed 2D sin-cos position embedding + self.build_2d_sincos_position_embedding() + + # weight initialization + for name, m in self.named_modules(): + if isinstance(m, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1])) + nn.init.uniform_(m.weight, -val, val) + else: + nn.init.xavier_uniform_(m.weight) + nn.init.zeros_(m.bias) + nn.init.normal_(self.cls_token, std=1e-6) + + if isinstance(self.patch_embed, PatchEmbed): + # xavier_uniform initialization + val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) + nn.init.uniform_(self.patch_embed.proj.weight, -val, val) + nn.init.zeros_(self.patch_embed.proj.bias) + + if stop_grad_conv1: + self.patch_embed.proj.weight.requires_grad = False + self.patch_embed.proj.bias.requires_grad = False + + def build_2d_sincos_position_embedding(self, temperature=10000.): + h, w = self.patch_embed.grid_size + grid_w = torch.arange(w, dtype=torch.float32) + grid_h = torch.arange(h, dtype=torch.float32) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h) + assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' + pos_dim = self.embed_dim // 4 + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1. / (temperature**omega) + out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) + out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) + pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] + + assert self.num_prefix_tokens == 1, 'Assuming one and only one token, [cls]' + pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32) + self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) + self.pos_embed.requires_grad = False + + +class ConvStem(nn.Module): + """ + ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. 
https://arxiv.org/abs/2106.14881 + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + super().__init__() + + assert patch_size == 16, 'ConvStem only supports patch size of 16' + assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' + + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + # build stem, similar to the design in https://arxiv.org/abs/2106.14881 + stem = [] + input_dim, output_dim = 3, embed_dim // 8 + for l in range(4): + stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) + stem.append(nn.BatchNorm2d(output_dim)) + stem.append(nn.ReLU(inplace=True)) + input_dim = output_dim + output_dim *= 2 + stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1)) + self.proj = nn.Sequential(*stem) + + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +def vit_small(**kwargs): + model = VisionTransformerMoCo( + patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + return model + +def vit_base(**kwargs): + model = VisionTransformerMoCo( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + return model + +def vit_conv_small(**kwargs): + # minus one ViT block + model = VisionTransformerMoCo( + patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) + model.default_cfg = _cfg() + return model + +def vit_conv_base(**kwargs): + # minus one ViT block + model = VisionTransformerMoCo( + patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) + model.default_cfg = _cfg() + return model diff --git a/sgm/modules/autoencoding/moco/util.py b/sgm/modules/autoencoding/moco/util.py new file mode 100644 index 0000000..4a89603 --- /dev/null +++ b/sgm/modules/autoencoding/moco/util.py @@ -0,0 +1,61 @@ +import hashlib +import os +import tarfile + +import requests +import torch +import torch.nn as nn +from tqdm import tqdm + +URL_MAP = { + "mocov1_200ep": "https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v1_200ep/moco_v1_200ep_pretrain.pth.tar", + "mocov2_200ep": "https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_200ep/moco_v2_200ep_pretrain.pth.tar", + "mocov2_800ep": "https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", + "mocov3_vits": "https://dl.fbaipublicfiles.com/moco-v3/vit-s-300ep/vit-s-300ep.pth.tar", + "mocov3_vitb": "https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/vit-b-300ep.pth.tar", +} + +CKPT_MAP = { + "mocov1_200ep": "moco_v1_200ep_pretrain.pth", + "mocov2_200ep": 
"moco_v2_200ep_pretrain.pth", + "mocov2_800ep": "moco_v2_800ep_pretrain.pth", + "mocov3_vits": "vit-s-300ep.pth", + "mocov3_vitb": "vit-b-300ep.pth", +} + +MD5_MAP = { + "mocov1_200ep": "b251726a57be750490c34a7602b59076", + "mocov2_200ep": "59fd9945645e27d585be89338ce9232e", + "mocov2_800ep": "a04e12f8b0e44fdcac1fb4e06f33727b", + "mocov3_vits": "f32f0062e8884e64bcd2044c67bd3e43", + "mocov3_vitb": "7fe6d104c5ba222fecc6dc838f2dbcf9", +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path