diff --git a/README.md b/README.md index 42b33fe..55f84c3 100644 --- a/README.md +++ b/README.md @@ -1,98 +1,106 @@ -# AutoencoderKL +# Deep Compression Autoencoder -## About The Project +This repository is a branch of [lavinal712/AutoencoderKL](https://github.com/lavinal712/AutoencoderKL) with some modifications for the DC-AE. This code aims to provide a simple and easy-to-use training script regarding to [issue #173](https://github.com/mit-han-lab/efficientvit/issues/173). -There are many great training scripts for VAE on Github. However, some repositories are not maintained and some are not updated to the latest version of PyTorch. Therefore, I decided to create this repository to provide a simple and easy-to-use training script for VAE by Lightning. Beside, the code is easy to transfer to other projects for time saving. +We provide three different parameter configurations corresponding to the three phases of DC-AE training. -- Support training and finetuning both [Stable Diffusion](https://github.com/CompVis/stable-diffusion) VAE and [Flux](https://github.com/black-forest-labs/flux) VAE. -- Support evaluating reconstruction quality (FID, PSNR, SSIM, LPIPS). -- A practical guidance of training VAE. -- Easy to modify the code for your own research. +``` +configs/ +├── dc-ae-f32c32-in-1.0_phase1.yaml +├── dc-ae-f32c32-in-1.0_phase2.yaml +└── dc-ae-f32c32-in-1.0_phase3.yaml +``` - -## Getting Started +![phases](assets/phases.png) + +## Visualization + +Model: [mit-han-lab/dc-ae-f32c32-in-1.0](https://huggingface.co/models?other=dc-ae-f32c32-in-1.0) + +| Input | Reconstruction | +|--------------------------------------- |-----------------------------------------------------------| +| ![assets/inputs.png](assets/inputs.png) | ![assets/reconstructions.png](assets/reconstructions.png) | + +Evaluation (from [original paper](https://arxiv.org/abs/2410.10733)) -To get a local copy up and running follow these simple example steps. +| Model | rFID | PSNR | SSIM | LPIPS | +|----------------|------|-------|------|-------| +| DC-AE (f32c32) | 0.69 | 23.85 | 0.66 | 0.082 | +| SD-VAE | 0.69 | 26.91 | 0.77 | 0.130 | + +## Getting Started ### Installation -```bash -git clone https://github.com/lavinal712/AutoencoderKL.git +``` +git clone https://github.com/lavinal712/AutoencoderKL.git -b dc-ae cd AutoencoderKL conda create -n autoencoderkl python=3.10 -y conda activate autoencoderkl pip install -r requirements.txt ``` -### Training - -To start training, you need to prepare a config file. You can refer to the config files in the `configs` folder. +### Data -If you want to train on your own dataset, you should write your own data loader in `sgm/data` and modify the parameters in the config file. +We use the ImageNet dataset for training and validation. -Finetuning a VAE model is simple. You just need to specify the `ckpt_path` and `trainable_ae_params` in the config file. To keep the latent space of the original model, it is recommended to set decoder to be trainable. +``` +ImageNet/ +├── train/ +│ ├── n01440764/ +│ │ ├── n01440764_18.JPEG +│ │ ├── n01440764_36.JPEG +│ │ └── ... +│ ├── n01443537/ +│ │ ├── n01443537_2.JPEG +│ │ ├── n01443537_16.JPEG +│ │ └── ... +│ ├── ... +├── val/ +│ ├── n01440764/ +│ │ ├── ILSVRC2012_val_00000293.JPEG +│ │ ├── ILSVRC2012_val_00002138.JPEG +│ │ └── ... +│ ├── n01443537/ +│ │ ├── ILSVRC2012_val_00000236.JPEG +│ │ ├── ILSVRC2012_val_00000262.JPEG +│ │ └── ... +│ ├── ... +└── ... +``` -Then, you can start training by running the following command. 
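For reference, a minimal sketch of how this layout is consumed, based on the `sgm/data/imagenet.py` changes further down in this diff: the loader builds on `torchvision`'s `ImageFolder` with a bicubic resize, a center crop, and normalization to `[-1, 1]`. The dataset path below is a placeholder, and the crop size follows the phase configs (`size: 256` in phases 1 and 3, `size: 512` in phase 2).

```python
# Sketch of the preprocessing used by the ImageNet data module in this diff.
# The ImageNet path is a placeholder; point it at the layout shown above.
from PIL import Image
from torchvision import transforms
from torchvision.datasets import ImageFolder

size = 256  # phases 1 and 3; phase 2 uses 512
preprocess = transforms.Compose([
    transforms.Resize(size, interpolation=Image.BICUBIC),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

dataset = ImageFolder("/path/to/ImageNet/train", transform=preprocess)
x, _ = dataset[0]  # x: 3 x 256 x 256 tensor with values in [-1, 1]

# f32c32: five 2x downsampling stages, i.e. a spatial compression ratio of 2**5 = 32,
# so a 256x256 input is encoded into a 32-channel 8x8 latent.
print(x.shape, size // 32)
```

The compression-ratio comment mirrors the `spatial_compression_ratio` property of `AutoencoderDC` (`2 ** (num_stages - 1)`, which is 32 for the six-stage f32c32 configuration).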
+### Training ```bash -NUM_GPUS=4 -NUM_NODES=1 - -torchrun --nproc_per_node=${NUM_GPUS} --nnodes=${NUM_NODES} main.py \ - --base configs/autoencoder_kl_32x32x4.yaml \ +torchrun --nproc_per_node=4 --nnodes=1 main.py \ + --base configs/dc-ae-f32c32-in-1.0_phase1.yaml \ --train \ - --logdir logs/autoencoder_kl_32x32x4 \ - --scale_lr True \ - --wandb False \ + --scale_lr False \ + --wandb True \ ``` -### Evaluation - -We provide a script to evaluate the reconstruction quality of the trained model. `--resume` provides a convenient way to load the checkpoint from the log directory. - -We introduce multi-GPU and multi-thread method for faster evaluation. - -The default dataset is ImageNet. You can change the dataset by modifying the `--datadir` in the command line and the evaluation script. +Remember to specify the model checkpoint path in the next phases in the command line or in the config file. ```bash -NUM_GPUS=4 -NUM_NODES=1 - -torchrun --nproc_per_node=${NUM_GPUS} --nnodes=${NUM_NODES} eval.py \ - --resume logs/autoencoder_kl_32x32x4 \ - --base configs/autoencoder_kl_32x32x4.yaml \ - --logdir eval/autoencoder_kl_32x32x4 \ - --datadir /path/to/ImageNet \ - --image_size 256 \ - --batch_size 16 \ - --num_workers 16 \ +torchrun --nproc_per_node=4 --nnodes=1 main.py \ + --base configs/dc-ae-f32c32-in-1.0_phase2.yaml \ + --train \ + --scale_lr False \ + --wandb True \ + --resume_from_checkpoint /path/to/last.ckpt ``` -### Converting to diffusers - -[huggingface/diffusers](https://github.com/huggingface/diffusers) is a library for diffusion models. It provides a script [convert_vae_pt_to_diffusers.py -](https://github.com/huggingface/diffusers/blob/main/scripts/convert_vae_pt_to_diffusers.py) to convert a PyTorch Lightning model to a diffusers model. - -Currently, the script is not updated for all kinds of VAE models, just for SD VAE. - ```bash -python convert_vae_pt_to_diffusers.py \ - --vae_path logs/autoencoder_kl_32x32x4/checkpoints/last.ckpt \ - --dump_path autoencoder_kl_32x32x4 \ +torchrun --nproc_per_node=4 --nnodes=1 main.py \ + --base configs/dc-ae-f32c32-in-1.0_phase3.yaml \ + --train \ + --scale_lr False \ + --wandb True \ + --resume_from_checkpoint /path/to/last.ckpt ``` -## Guidance - -Here are some guidance for training VAE. If there are any mistakes, please let me know. - -- Learning rate: In LDM repository [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion), the base learning rate is set to 4.5e-6 in the config file. However, the batch size is 12, accumulated gradient is 2 and `scale_lr` is set to `True`. Therefore, the effective learning rate is 4.5e-6 * 12 * 2 * 1 = 1.08e-4. It is better to set the learning rate from 1.0e-4 to 1.0e-5. In finetuning stage, it can be smaller than the first stage. - - `scale_lr`: It is better to set `scale_lr` to `False` when training on a large dataset. -- Discriminator: You should open the discriminator in the end of the training, when the VAE has good reconstruction performance. In default, `disc_start` is set to 50001. -- Perceptual loss: LPIPS is a good metric for evaluating the quality of the reconstructed images. Some models use other perceptual loss functions to gain better performance, such as [sypsyp97/convnext_perceptual_loss](https://github.com/sypsyp97/convnext_perceptual_loss). - -## Acknowledgments +## Acknowledgements -Thanks for the following repositories. Without their code, this project would not be possible. 
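As noted in the Training section above, phases 2 and 3 start from the previous phase's `last.ckpt`, passed either via `--resume_from_checkpoint` or as a checkpoint path in the config. A minimal sketch for sanity-checking that file before launching the next phase, assuming the standard PyTorch Lightning checkpoint layout (the path is a placeholder):

```python
# Quick sanity check of a phase-1 checkpoint before starting phase 2 or 3.
# Assumes a standard PyTorch Lightning checkpoint; the path is a placeholder.
import torch

ckpt = torch.load("/path/to/last.ckpt", map_location="cpu")
print(ckpt.get("epoch"), ckpt.get("global_step"))

# AutoencoderDC registers its networks as `encoder` and `decoder`, so their
# weights should appear under these prefixes in the Lightning state dict.
state_dict = ckpt["state_dict"]
print(sum(k.startswith("encoder.") for k in state_dict),
      sum(k.startswith("decoder.") for k in state_dict))
```

Note that the phase 2 and phase 3 configs restrict `trainable_ae_params` (`encoder.project_out` and `decoder.project_in` in phase 2, `decoder.project_out` in phase 3), so only those projection layers are optimized in the later phases.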
+Thanks for the original introduction and implementation of [DC-AE](https://github.com/mit-han-lab/efficientvit) . -- [Stability-AI/generative-models](https://github.com/Stability-AI/generative-models). We heavily borrow the code from this repository, just modifing a few parameters for our concept. -- [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion). We follow the hyperparameter settings of this repository in config files. +- [mit-han-lab/efficientvit](https://github.com/mit-han-lab/efficientvit) diff --git a/assets/inputs.png b/assets/inputs.png new file mode 100644 index 0000000..1bf639c Binary files /dev/null and b/assets/inputs.png differ diff --git a/assets/phases.png b/assets/phases.png new file mode 100644 index 0000000..e3cd7f0 Binary files /dev/null and b/assets/phases.png differ diff --git a/assets/reconstructions.png b/assets/reconstructions.png new file mode 100644 index 0000000..fed5cdb Binary files /dev/null and b/assets/reconstructions.png differ diff --git a/configs/dc-ae-f32c32-in-1.0_phase1.yaml b/configs/dc-ae-f32c32-in-1.0_phase1.yaml new file mode 100644 index 0000000..52cc00c --- /dev/null +++ b/configs/dc-ae-f32c32-in-1.0_phase1.yaml @@ -0,0 +1,90 @@ +model: + base_learning_rate: 6.4e-5 + target: sgm.models.autoencoder.AutoencoderDC + params: + input_key: jpg + monitor: "val/loss/rec" + disc_start_iter: 1000000000 + + encoder_config: + target: sgm.modules.efficientvitmodules.model.EncoderConfig + params: + in_channels: 3 + latent_channels: 32 + block_type: [ResBlock, ResBlock, ResBlock, EViT_GLU, EViT_GLU, EViT_GLU] + width_list: [128, 256, 512, 512, 1024, 1024] + depth_list: [0, 4, 8, 2, 2, 2] + + decoder_config: + target: sgm.modules.efficientvitmodules.model.DecoderConfig + params: + in_channels: 3 + latent_channels: 32 + block_type: [ResBlock, ResBlock, ResBlock, EViT_GLU, EViT_GLU, EViT_GLU] + width_list: [128, 256, 512, 512, 1024, 1024] + depth_list: [0, 5, 10, 2, 2, 2] + norm: [bn2d, bn2d, bn2d, trms2d, trms2d, trms2d] + act: [relu, relu, relu, silu, silu, silu] + + loss_config: + target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator + params: + perceptual_weight: 0.25 + disc_start: 50001 + disc_weight: 0.0 + learn_logvar: false + pixel_loss: "l1" + + optimizer_config: + target: torch.optim.AdamW + params: + betas: [0.9, 0.999] + weight_decay: 0.1 + +data: + target: sgm.data.imagenet.ImageNetLoader + params: + batch_size: 24 + num_workers: 4 + prefetch_factor: 2 + shuffle: true + + train: + root_dir: /path/to/ImageNet + size: 256 + transform: true + validation: + root_dir: /path/to/ImageNet + size: 256 + transform: true + +lightning: + strategy: + target: pytorch_lightning.strategies.DDPStrategy + params: + find_unused_parameters: True + + modelcheckpoint: + params: + every_n_epochs: 1 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 50000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + precision: bf16 + devices: 0, 1, 2, 3 + limit_val_batches: 50 + benchmark: True + accumulate_grad_batches: 1 + check_val_every_n_epoch: 1 diff --git a/configs/dc-ae-f32c32-in-1.0_phase2.yaml b/configs/dc-ae-f32c32-in-1.0_phase2.yaml new file mode 100644 index 0000000..81ab5ca --- /dev/null +++ b/configs/dc-ae-f32c32-in-1.0_phase2.yaml @@ -0,0 +1,92 @@ +model: + base_learning_rate: 1.6e-5 + target: sgm.models.autoencoder.AutoencoderDC + params: + input_key: jpg + monitor: "val/loss/rec" + 
disc_start_iter: 1000000000 + trainable_ae_params: + - ["encoder.project_out", "decoder.project_in"] + + encoder_config: + target: sgm.modules.efficientvitmodules.model.EncoderConfig + params: + in_channels: 3 + latent_channels: 32 + block_type: [ResBlock, ResBlock, ResBlock, EViT_GLU, EViT_GLU, EViT_GLU] + width_list: [128, 256, 512, 512, 1024, 1024] + depth_list: [0, 4, 8, 2, 2, 2] + + decoder_config: + target: sgm.modules.efficientvitmodules.model.DecoderConfig + params: + in_channels: 3 + latent_channels: 32 + block_type: [ResBlock, ResBlock, ResBlock, EViT_GLU, EViT_GLU, EViT_GLU] + width_list: [128, 256, 512, 512, 1024, 1024] + depth_list: [0, 5, 10, 2, 2, 2] + norm: [bn2d, bn2d, bn2d, trms2d, trms2d, trms2d] + act: [relu, relu, relu, silu, silu, silu] + + loss_config: + target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator + params: + perceptual_weight: 0.25 + disc_start: 50001 + disc_weight: 0.0 + learn_logvar: false + pixel_loss: "l1" + + optimizer_config: + target: torch.optim.AdamW + params: + betas: [0.9, 0.999] + weight_decay: 0.001 + +data: + target: sgm.data.imagenet.ImageNetLoader + params: + batch_size: 4 + num_workers: 4 + prefetch_factor: 2 + shuffle: true + + train: + root_dir: /path/to/ImageNet + size: 512 + transform: true + validation: + root_dir: /path/to/ImageNet + size: 512 + transform: true + +lightning: + strategy: + target: pytorch_lightning.strategies.DDPStrategy + params: + find_unused_parameters: True + + modelcheckpoint: + params: + every_n_epochs: 1 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 50000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + precision: bf16 + devices: 0, 1, 2, 3 + limit_val_batches: 50 + benchmark: True + accumulate_grad_batches: 1 + check_val_every_n_epoch: 1 diff --git a/configs/dc-ae-f32c32-in-1.0_phase3.yaml b/configs/dc-ae-f32c32-in-1.0_phase3.yaml new file mode 100644 index 0000000..785957e --- /dev/null +++ b/configs/dc-ae-f32c32-in-1.0_phase3.yaml @@ -0,0 +1,91 @@ +model: + base_learning_rate: 5.4e-5 + target: sgm.models.autoencoder.AutoencoderDC + params: + input_key: jpg + monitor: "val/loss/rec" + disc_start_iter: 1 + trainable_ae_params: + - ["decoder.project_out"] + + encoder_config: + target: sgm.modules.efficientvitmodules.model.EncoderConfig + params: + in_channels: 3 + latent_channels: 32 + block_type: [ResBlock, ResBlock, ResBlock, EViT_GLU, EViT_GLU, EViT_GLU] + width_list: [128, 256, 512, 512, 1024, 1024] + depth_list: [0, 4, 8, 2, 2, 2] + + decoder_config: + target: sgm.modules.efficientvitmodules.model.DecoderConfig + params: + in_channels: 3 + latent_channels: 32 + block_type: [ResBlock, ResBlock, ResBlock, EViT_GLU, EViT_GLU, EViT_GLU] + width_list: [128, 256, 512, 512, 1024, 1024] + depth_list: [0, 5, 10, 2, 2, 2] + norm: [bn2d, bn2d, bn2d, trms2d, trms2d, trms2d] + act: [relu, relu, relu, silu, silu, silu] + + loss_config: + target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator + params: + perceptual_weight: 0.25 + disc_start: 1 + disc_weight: 0.5 + learn_logvar: false + pixel_loss: "l1" + + optimizer_config: + target: torch.optim.AdamW + params: + betas: [0.5, 0.9] + +data: + target: sgm.data.imagenet.ImageNetLoader + params: + batch_size: 24 + num_workers: 4 + prefetch_factor: 2 + shuffle: true + + train: + root_dir: /path/to/ImageNet + size: 256 + transform: true + validation: + root_dir: /path/to/ImageNet + size: 256 + 
transform: true + +lightning: + strategy: + target: pytorch_lightning.strategies.DDPStrategy + params: + find_unused_parameters: True + + modelcheckpoint: + params: + every_n_epochs: 1 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 50000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + precision: bf16 + devices: 0, 1, 2, 3 + limit_val_batches: 50 + benchmark: True + accumulate_grad_batches: 1 + check_val_every_n_epoch: 1 diff --git a/main.py b/main.py index 79a1fd1..d6e2ddd 100644 --- a/main.py +++ b/main.py @@ -132,7 +132,7 @@ def str2bool(v): parser.add_argument( "--projectname", type=str, - default="autoencoderkl", + default="dc-ae", ) parser.add_argument( "-l", diff --git a/sgm/data/coco.py b/sgm/data/coco.py index 399adf9..6599f9f 100644 --- a/sgm/data/coco.py +++ b/sgm/data/coco.py @@ -78,7 +78,7 @@ def __init__( if train.get("transform", None): size = train.get("size", 256) transform = transforms.Compose([ - transforms.Resize(size), + transforms.Resize(size, interpolation=Image.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), diff --git a/sgm/data/imagenet.py b/sgm/data/imagenet.py index 5be43d5..4fc86fb 100644 --- a/sgm/data/imagenet.py +++ b/sgm/data/imagenet.py @@ -3,6 +3,7 @@ import pytorch_lightning as pl from omegaconf import DictConfig +from PIL import Image from torch.utils.data import DataLoader, Dataset from torchvision import transforms from torchvision.datasets import ImageFolder @@ -49,7 +50,7 @@ def __init__( if train.get("transform", None): size = train.get("size", 256) transform = transforms.Compose([ - transforms.Resize(size), + transforms.Resize(size, interpolation=Image.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py index 4114c22..26403b4 100644 --- a/sgm/models/autoencoder.py +++ b/sgm/models/autoencoder.py @@ -13,6 +13,7 @@ from safetensors.torch import load_file as load_safetensors from ..modules.autoencoding.regularizers import AbstractRegularizer +from ..modules.efficientvitmodules.model import Encoder, Decoder from ..modules.ema import LitEma from ..util import (default, get_nested_attribute, get_obj_from_str, instantiate_from_config) @@ -527,3 +528,202 @@ def __init__(self, **kwargs): }, **kwargs, ) + + +class AutoencoderDC(AutoencodingEngine): + """ + https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + loss_config: Dict, + optimizer_config: Union[Dict, None] = None, + lr_g_factor: float = 1.0, + trainable_ae_params: Optional[List[List[str]]] = None, + ae_optimizer_args: Optional[List[dict]] = None, + trainable_disc_params: Optional[List[List[str]]] = None, + disc_optimizer_args: Optional[List[dict]] = None, + disc_start_iter: int = 0, + diff_boost_factor: float = 3.0, + ckpt_engine: Union[None, str, dict] = None, + ckpt_path: Optional[str] = None, + **kwargs, + ): + AbstractAutoencoder.__init__(self, *args, **kwargs) + self.automatic_optimization = False # pytorch lightning + + encoder_config = instantiate_from_config(encoder_config) + decoder_config = instantiate_from_config(decoder_config) + self.encoder: torch.nn.Module = Encoder(encoder_config) + self.decoder: torch.nn.Module = 
Decoder(decoder_config) + self.loss: torch.nn.Module = instantiate_from_config(loss_config) + self.optimizer_config = default( + optimizer_config, {"target": "torch.optim.AdamW"} + ) + self.diff_boost_factor = diff_boost_factor + self.disc_start_iter = disc_start_iter + self.lr_g_factor = lr_g_factor + self.trainable_ae_params = trainable_ae_params + if self.trainable_ae_params is not None: + self.ae_optimizer_args = default( + ae_optimizer_args, + [{} for _ in range(len(self.trainable_ae_params))], + ) + assert len(self.ae_optimizer_args) == len(self.trainable_ae_params) + else: + self.ae_optimizer_args = [{}] # makes type consitent + + self.trainable_disc_params = trainable_disc_params + if self.trainable_disc_params is not None: + self.disc_optimizer_args = default( + disc_optimizer_args, + [{} for _ in range(len(self.trainable_disc_params))], + ) + assert len(self.disc_optimizer_args) == len(self.trainable_disc_params) + else: + self.disc_optimizer_args = [{}] # makes type consitent + + if ckpt_path is not None: + assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path" + logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead") + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + + @property + def spatial_compression_ratio(self) -> int: + return 2 ** (self.decoder.num_stages - 1) + + def get_autoencoder_params(self) -> list: + params = [] + if hasattr(self.loss, "get_trainable_autoencoder_parameters"): + params += list(self.loss.get_trainable_autoencoder_parameters()) + params = params + list(self.encoder.parameters()) + params = params + list(self.decoder.parameters()) + return params + + def get_last_layer(self): + for name, module in self.decoder.project_out.named_modules(): + if isinstance(module, nn.Conv2d): + return module.weight + raise ValueError("No last layer found") + + def encode( + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = True, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + assert unregularized, "AutoencoderDC does not support regularization" + z = self.encoder(x) + if unregularized: + if return_reg_log: + return z, dict() + return z + + def decode(self, z: torch.Tensor) -> torch.Tensor: + x = self.decoder(z) + return x + + def forward( + self, x: torch.Tensor, global_step: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True) + dec = self.decode(z) + return z, dec, reg_log + + def inner_training_step( + self, batch: dict, batch_idx: int, optimizer_idx: int = 0 + ) -> torch.Tensor: + x = self.get_input(batch) + z, xrec, regularization_log = self(x, self.global_step) + if hasattr(self.loss, "forward_keys"): + extra_info = { + "z": z, + "optimizer_idx": optimizer_idx, + "global_step": self.global_step, + "last_layer": self.get_last_layer(), + "split": "train", + "regularization_log": regularization_log, + "autoencoder": self, + } + extra_info = {k: extra_info[k] for k in self.loss.forward_keys} + else: + extra_info = dict() + + if optimizer_idx == 0: + # autoencode + out_loss = self.loss(x, xrec, **extra_info) + if isinstance(out_loss, tuple): + aeloss, log_dict_ae = out_loss + else: + # simple loss function + aeloss = out_loss + log_dict_ae = {"train/loss/rec": aeloss.detach()} + + self.log_dict( + log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=True, + sync_dist=False, + ) + self.log( + "loss", + aeloss.mean().detach(), + prog_bar=True, + logger=True, + on_epoch=False, + on_step=True, + ) + return aeloss + elif 
optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(x, xrec, **extra_info) + # -> discriminator always needs to return a tuple + self.log_dict( + log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True + ) + return discloss + else: + raise NotImplementedError(f"Unknown optimizer {optimizer_idx}") + + @torch.no_grad() + def log_images( + self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs + ) -> dict: + log = dict() + x = self.get_input(batch) + + _, xrec, _ = self(x) + log["inputs"] = x + log["reconstructions"] = xrec + diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x) + diff.clamp_(0, 1.0) + log["diff"] = 2.0 * diff - 1.0 + # diff_boost shows location of small errors, by boosting their + # brightness. + log["diff_boost"] = ( + 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1 + ) + if hasattr(self.loss, "log_images"): + log.update(self.loss.log_images(x, xrec)) + with self.ema_scope(): + _, xrec_ema, _ = self(x) + log["reconstructions_ema"] = xrec_ema + diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) + diff_ema.clamp_(0, 1.0) + log["diff_ema"] = 2.0 * diff_ema - 1.0 + log["diff_boost_ema"] = ( + 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 + ) + if additional_log_kwargs: + _, xrec_add, _ = self(x) + log_str = "reconstructions-" + "-".join( + [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs] + ) + log[log_str] = xrec_add + return log diff --git a/sgm/modules/efficientvitmodules/__init__.py b/sgm/modules/efficientvitmodules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sgm/modules/efficientvitmodules/model.py b/sgm/modules/efficientvitmodules/model.py new file mode 100644 index 0000000..c79d809 --- /dev/null +++ b/sgm/modules/efficientvitmodules/model.py @@ -0,0 +1,385 @@ +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +import torch.nn as nn +from omegaconf import MISSING, OmegaConf + +from .nn.act import build_act +from .nn.norm import build_norm +from .nn.ops import ( + ChannelDuplicatingPixelUnshuffleUpSampleLayer, + ConvLayer, + ConvPixelShuffleUpSampleLayer, + ConvPixelUnshuffleDownSampleLayer, + EfficientViTBlock, + IdentityLayer, + InterpolateConvUpSampleLayer, + OpSequential, + PixelUnshuffleChannelAveragingDownSampleLayer, + ResBlock, + ResidualBlock, +) + +__all__ = ["DCAE", "dc_ae_f32c32", "dc_ae_f64c128", "dc_ae_f128c512"] + + +@dataclass +class EncoderConfig: + in_channels: int = MISSING + latent_channels: int = MISSING + width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024) + depth_list: tuple[int, ...] = (2, 2, 2, 2, 2, 2) + block_type: Any = "ResBlock" + norm: str = "trms2d" + act: str = "silu" + downsample_block_type: str = "ConvPixelUnshuffle" + downsample_match_channel: bool = True + downsample_shortcut: Optional[str] = "averaging" + out_norm: Optional[str] = None + out_act: Optional[str] = None + out_shortcut: Optional[str] = "averaging" + double_latent: bool = False + + +@dataclass +class DecoderConfig: + in_channels: int = MISSING + latent_channels: int = MISSING + in_shortcut: Optional[str] = "duplicating" + width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024) + depth_list: tuple[int, ...] 
= (2, 2, 2, 2, 2, 2) + block_type: Any = "ResBlock" + norm: Any = "trms2d" + act: Any = "silu" + upsample_block_type: str = "ConvPixelShuffle" + upsample_match_channel: bool = True + upsample_shortcut: str = "duplicating" + out_norm: str = "trms2d" + out_act: str = "relu" + + +def build_block( + block_type: str, in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str] +) -> nn.Module: + if block_type == "ResBlock": + assert in_channels == out_channels + main_block = ResBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + use_bias=(True, False), + norm=(None, norm), + act_func=(act, None), + ) + block = ResidualBlock(main_block, IdentityLayer()) + elif block_type == "EViT_GLU": + assert in_channels == out_channels + block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=()) + elif block_type == "EViTS5_GLU": + assert in_channels == out_channels + block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=(5,)) + else: + raise ValueError(f"block_type {block_type} is not supported") + return block + + +def build_stage_main( + width: int, depth: int, block_type: str | list[str], norm: str, act: str, input_width: int +) -> list[nn.Module]: + assert isinstance(block_type, str) or (isinstance(block_type, list) and depth == len(block_type)) + stage = [] + for d in range(depth): + current_block_type = block_type[d] if isinstance(block_type, list) else block_type + block = build_block( + block_type=current_block_type, + in_channels=width if d > 0 else input_width, + out_channels=width, + norm=norm, + act=act, + ) + stage.append(block) + return stage + + +def build_downsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module: + if block_type == "Conv": + block = ConvLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + use_bias=True, + norm=None, + act_func=None, + ) + elif block_type == "ConvPixelUnshuffle": + block = ConvPixelUnshuffleDownSampleLayer( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2 + ) + else: + raise ValueError(f"block_type {block_type} is not supported for downsampling") + if shortcut is None: + pass + elif shortcut == "averaging": + shortcut_block = PixelUnshuffleChannelAveragingDownSampleLayer( + in_channels=in_channels, out_channels=out_channels, factor=2 + ) + block = ResidualBlock(block, shortcut_block) + else: + raise ValueError(f"shortcut {shortcut} is not supported for downsample") + return block + + +def build_upsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module: + if block_type == "ConvPixelShuffle": + block = ConvPixelShuffleUpSampleLayer( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2 + ) + elif block_type == "InterpolateConv": + block = InterpolateConvUpSampleLayer( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2 + ) + else: + raise ValueError(f"block_type {block_type} is not supported for upsampling") + if shortcut is None: + pass + elif shortcut == "duplicating": + shortcut_block = ChannelDuplicatingPixelUnshuffleUpSampleLayer( + in_channels=in_channels, out_channels=out_channels, factor=2 + ) + block = ResidualBlock(block, shortcut_block) + else: + raise ValueError(f"shortcut {shortcut} is not supported for upsample") + return block + + +def build_encoder_project_in_block(in_channels: int, 
out_channels: int, factor: int, downsample_block_type: str): + if factor == 1: + block = ConvLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + use_bias=True, + norm=None, + act_func=None, + ) + elif factor == 2: + block = build_downsample_block( + block_type=downsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None + ) + else: + raise ValueError(f"downsample factor {factor} is not supported for encoder project in") + return block + + +def build_encoder_project_out_block( + in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str], shortcut: Optional[str] +): + block = OpSequential( + [ + build_norm(norm), + build_act(act), + ConvLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + use_bias=True, + norm=None, + act_func=None, + ), + ] + ) + if shortcut is None: + pass + elif shortcut == "averaging": + shortcut_block = PixelUnshuffleChannelAveragingDownSampleLayer( + in_channels=in_channels, out_channels=out_channels, factor=1 + ) + block = ResidualBlock(block, shortcut_block) + else: + raise ValueError(f"shortcut {shortcut} is not supported for encoder project out") + return block + + +def build_decoder_project_in_block(in_channels: int, out_channels: int, shortcut: Optional[str]): + block = ConvLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + use_bias=True, + norm=None, + act_func=None, + ) + if shortcut is None: + pass + elif shortcut == "duplicating": + shortcut_block = ChannelDuplicatingPixelUnshuffleUpSampleLayer( + in_channels=in_channels, out_channels=out_channels, factor=1 + ) + block = ResidualBlock(block, shortcut_block) + else: + raise ValueError(f"shortcut {shortcut} is not supported for decoder project in") + return block + + +def build_decoder_project_out_block( + in_channels: int, out_channels: int, factor: int, upsample_block_type: str, norm: Optional[str], act: Optional[str] +): + layers: list[nn.Module] = [ + build_norm(norm, in_channels), + build_act(act), + ] + if factor == 1: + layers.append( + ConvLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + use_bias=True, + norm=None, + act_func=None, + ) + ) + elif factor == 2: + layers.append( + build_upsample_block( + block_type=upsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None + ) + ) + else: + raise ValueError(f"upsample factor {factor} is not supported for decoder project out") + return OpSequential(layers) + + +class Encoder(nn.Module): + def __init__(self, cfg: EncoderConfig): + super().__init__() + self.cfg = cfg + num_stages = len(cfg.width_list) + self.num_stages = num_stages + assert len(cfg.depth_list) == num_stages + assert len(cfg.width_list) == num_stages + cfg.block_type = list(cfg.block_type) + assert isinstance(cfg.block_type, str) or ( + isinstance(cfg.block_type, list) and len(cfg.block_type) == num_stages + ) + + self.project_in = build_encoder_project_in_block( + in_channels=cfg.in_channels, + out_channels=cfg.width_list[0] if cfg.depth_list[0] > 0 else cfg.width_list[1], + factor=1 if cfg.depth_list[0] > 0 else 2, + downsample_block_type=cfg.downsample_block_type, + ) + + self.stages: list[OpSequential] = [] + for stage_id, (width, depth) in enumerate(zip(cfg.width_list, cfg.depth_list)): + block_type = cfg.block_type[stage_id] if isinstance(cfg.block_type, list) else cfg.block_type + stage = build_stage_main( + width=width, depth=depth, 
block_type=block_type, norm=cfg.norm, act=cfg.act, input_width=width + ) + + if stage_id < num_stages - 1 and depth > 0: + downsample_block = build_downsample_block( + block_type=cfg.downsample_block_type, + in_channels=width, + out_channels=cfg.width_list[stage_id + 1] if cfg.downsample_match_channel else width, + shortcut=cfg.downsample_shortcut, + ) + stage.append(downsample_block) + self.stages.append(OpSequential(stage)) + self.stages = nn.ModuleList(self.stages) + + self.project_out = build_encoder_project_out_block( + in_channels=cfg.width_list[-1], + out_channels=2 * cfg.latent_channels if cfg.double_latent else cfg.latent_channels, + norm=cfg.out_norm, + act=cfg.out_act, + shortcut=cfg.out_shortcut, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.project_in(x) + for stage in self.stages: + if len(stage.op_list) == 0: + continue + x = stage(x) + x = self.project_out(x) + return x + + +class Decoder(nn.Module): + def __init__(self, cfg: DecoderConfig): + super().__init__() + self.cfg = cfg + num_stages = len(cfg.width_list) + self.num_stages = num_stages + assert len(cfg.depth_list) == num_stages + assert len(cfg.width_list) == num_stages + cfg.block_type = list(cfg.block_type) + assert isinstance(cfg.block_type, str) or ( + isinstance(cfg.block_type, list) and len(cfg.block_type) == num_stages + ) + cfg.norm = list(cfg.norm) + cfg.act = list(cfg.act) + assert isinstance(cfg.norm, str) or (isinstance(cfg.norm, list) and len(cfg.norm) == num_stages) + assert isinstance(cfg.act, str) or (isinstance(cfg.act, list) and len(cfg.act) == num_stages) + + self.project_in = build_decoder_project_in_block( + in_channels=cfg.latent_channels, + out_channels=cfg.width_list[-1], + shortcut=cfg.in_shortcut, + ) + + self.stages: list[OpSequential] = [] + for stage_id, (width, depth) in reversed(list(enumerate(zip(cfg.width_list, cfg.depth_list)))): + stage = [] + if stage_id < num_stages - 1 and depth > 0: + upsample_block = build_upsample_block( + block_type=cfg.upsample_block_type, + in_channels=cfg.width_list[stage_id + 1], + out_channels=width if cfg.upsample_match_channel else cfg.width_list[stage_id + 1], + shortcut=cfg.upsample_shortcut, + ) + stage.append(upsample_block) + + block_type = cfg.block_type[stage_id] if isinstance(cfg.block_type, list) else cfg.block_type + norm = cfg.norm[stage_id] if isinstance(cfg.norm, list) else cfg.norm + act = cfg.act[stage_id] if isinstance(cfg.act, list) else cfg.act + stage.extend( + build_stage_main( + width=width, + depth=depth, + block_type=block_type, + norm=norm, + act=act, + input_width=( + width if cfg.upsample_match_channel else cfg.width_list[min(stage_id + 1, num_stages - 1)] + ), + ) + ) + self.stages.insert(0, OpSequential(stage)) + self.stages = nn.ModuleList(self.stages) + + self.project_out = build_decoder_project_out_block( + in_channels=cfg.width_list[0] if cfg.depth_list[0] > 0 else cfg.width_list[1], + out_channels=cfg.in_channels, + factor=1 if cfg.depth_list[0] > 0 else 2, + upsample_block_type=cfg.upsample_block_type, + norm=cfg.out_norm, + act=cfg.out_act, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.project_in(x) + for stage in reversed(self.stages): + if len(stage.op_list) == 0: + continue + x = stage(x) + x = self.project_out(x) + return x diff --git a/sgm/modules/efficientvitmodules/nn/__init__.py b/sgm/modules/efficientvitmodules/nn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sgm/modules/efficientvitmodules/nn/act.py 
b/sgm/modules/efficientvitmodules/nn/act.py new file mode 100644 index 0000000..36f1143 --- /dev/null +++ b/sgm/modules/efficientvitmodules/nn/act.py @@ -0,0 +1,27 @@ +from functools import partial +from typing import Optional + +import torch.nn as nn + +from ..utils import build_kwargs_from_config + +__all__ = ["build_act"] + + +# register activation function here +REGISTERED_ACT_DICT: dict[str, type] = { + "relu": nn.ReLU, + "relu6": nn.ReLU6, + "hswish": nn.Hardswish, + "silu": nn.SiLU, + "gelu": partial(nn.GELU, approximate="tanh"), +} + + +def build_act(name: str, **kwargs) -> Optional[nn.Module]: + if name in REGISTERED_ACT_DICT: + act_cls = REGISTERED_ACT_DICT[name] + args = build_kwargs_from_config(kwargs, act_cls) + return act_cls(**args) + else: + return None diff --git a/sgm/modules/efficientvitmodules/nn/norm.py b/sgm/modules/efficientvitmodules/nn/norm.py new file mode 100644 index 0000000..eb91601 --- /dev/null +++ b/sgm/modules/efficientvitmodules/nn/norm.py @@ -0,0 +1,227 @@ +from typing import Optional + +import torch +import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm + +from .triton_rms_norm import TritonRMSNorm2dFunc +from ..utils import build_kwargs_from_config + +__all__ = ["LayerNorm2d", "TritonRMSNorm2d", "build_norm", "reset_bn", "set_norm_eps"] + + +class LayerNorm2d(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = x - torch.mean(x, dim=1, keepdim=True) + out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps) + if self.elementwise_affine: + out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) + return out + + +class TritonRMSNorm2d(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return TritonRMSNorm2dFunc.apply(x, self.weight, self.bias, self.eps) + + +# register normalization function here +REGISTERED_NORM_DICT: dict[str, type] = { + "bn2d": nn.BatchNorm2d, + "ln": nn.LayerNorm, + "ln2d": LayerNorm2d, + "trms2d": TritonRMSNorm2d, +} + + +def build_norm(name="bn2d", num_features=None, **kwargs) -> Optional[nn.Module]: + if name in ["ln", "ln2d", "trms2d"]: + kwargs["normalized_shape"] = num_features + else: + kwargs["num_features"] = num_features + if name in REGISTERED_NORM_DICT: + norm_cls = REGISTERED_NORM_DICT[name] + args = build_kwargs_from_config(kwargs, norm_cls) + return norm_cls(**args) + else: + return None + + +def reset_bn( + model: nn.Module, + data_loader: list, + sync=True, + progress_bar=False, +) -> None: + import copy + + import torch.nn.functional as F + from tqdm import tqdm + + from ..utils import get_device, list_join + + bn_mean = {} + bn_var = {} + + tmp_model = copy.deepcopy(model) + for name, m in tmp_model.named_modules(): + if isinstance(m, _BatchNorm): + bn_mean[name] = AverageMeter(is_distributed=False) + bn_var[name] = AverageMeter(is_distributed=False) + + def new_forward(bn, mean_est, var_est): + def lambda_forward(x): + x = x.contiguous() + if sync: + batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1 + batch_mean = sync_tensor(batch_mean, reduce="cat") + batch_mean = torch.mean(batch_mean, dim=0, keepdim=True) + + batch_var = (x - batch_mean) * (x - batch_mean) + batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + batch_var = sync_tensor(batch_var, reduce="cat") + batch_var = torch.mean(batch_var, dim=0, keepdim=True) + else: + batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1 + batch_var = 
(x - batch_mean) * (x - batch_mean) + batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + + batch_mean = torch.squeeze(batch_mean) + batch_var = torch.squeeze(batch_var) + + mean_est.update(batch_mean.data, x.size(0)) + var_est.update(batch_var.data, x.size(0)) + + # bn forward using calculated mean & var + _feature_dim = batch_mean.shape[0] + return F.batch_norm( + x, + batch_mean, + batch_var, + bn.weight[:_feature_dim], + bn.bias[:_feature_dim], + False, + 0.0, + bn.eps, + ) + + return lambda_forward + + m.forward = new_forward(m, bn_mean[name], bn_var[name]) + + # skip if there is no batch normalization layers in the network + if len(bn_mean) == 0: + return + + tmp_model.eval() + with torch.no_grad(): + with tqdm(total=len(data_loader), desc="reset bn", disable=not progress_bar or not is_master()) as t: + for images in data_loader: + images = images.to(get_device(tmp_model)) + tmp_model(images) + t.set_postfix( + { + "bs": images.size(0), + "res": list_join(images.shape[-2:], "x"), + } + ) + t.update() + + for name, m in model.named_modules(): + if name in bn_mean and bn_mean[name].count > 0: + feature_dim = bn_mean[name].avg.size(0) + assert isinstance(m, _BatchNorm) + m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg) + m.running_var.data[:feature_dim].copy_(bn_var[name].avg) + + +def set_norm_eps(model: nn.Module, eps: Optional[float] = None) -> None: + for m in model.modules(): + if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)): + if eps is not None: + m.eps = eps + + +## utils +import os +from ..utils import list_mean, list_sum + + +def dist_init() -> None: + if is_dist_initialized(): + return + try: + torch.distributed.init_process_group(backend="nccl") + assert torch.distributed.is_initialized() + except Exception: + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["LOCAL_RANK"] = "0" + print("warning: dist not init") + + +def is_dist_initialized() -> bool: + return torch.distributed.is_initialized() + + +def get_dist_rank() -> int: + return int(os.environ["RANK"]) + + +def get_dist_size() -> int: + return int(os.environ["WORLD_SIZE"]) + + +def is_master() -> bool: + return get_dist_rank() == 0 + + +def dist_barrier() -> None: + if is_dist_initialized(): + torch.distributed.barrier() + + +def get_dist_local_rank() -> int: + return int(os.environ["LOCAL_RANK"]) + + +def sync_tensor(tensor: torch.Tensor | float, reduce="mean") -> torch.Tensor | list[torch.Tensor]: + if not is_dist_initialized(): + return tensor + if not isinstance(tensor, torch.Tensor): + tensor = torch.Tensor(1).fill_(tensor).cuda() + tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())] + torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False) + if reduce == "mean": + return list_mean(tensor_list) + elif reduce == "sum": + return list_sum(tensor_list) + elif reduce == "cat": + return torch.cat(tensor_list, dim=0) + elif reduce == "root": + return tensor_list[0] + else: + return tensor_list + + +class AverageMeter: + """Computes and stores the average and current value.""" + + def __init__(self, is_distributed=True): + self.is_distributed = is_distributed + self.sum = 0 + self.count = 0 + + def _sync(self, val: torch.Tensor | int | float) -> torch.Tensor | int | float: + return sync_tensor(val, reduce="sum") if self.is_distributed else val + + def update(self, val: torch.Tensor | int | float, delta_n=1): + self.count += self._sync(delta_n) + self.sum += self._sync(val * delta_n) + + def 
get_count(self) -> torch.Tensor | int | float: + return self.count.item() if isinstance(self.count, torch.Tensor) and self.count.numel() == 1 else self.count + + @property + def avg(self): + avg = -1 if self.count == 0 else self.sum / self.count + return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg diff --git a/sgm/modules/efficientvitmodules/nn/ops.py b/sgm/modules/efficientvitmodules/nn/ops.py new file mode 100644 index 0000000..1f08d17 --- /dev/null +++ b/sgm/modules/efficientvitmodules/nn/ops.py @@ -0,0 +1,819 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .act import build_act +from .norm import build_norm +from ..utils import get_same_padding, list_sum, resize, val2list, val2tuple + +__all__ = [ + "ConvLayer", + "UpSampleLayer", + "ConvPixelUnshuffleDownSampleLayer", + "PixelUnshuffleChannelAveragingDownSampleLayer", + "ConvPixelShuffleUpSampleLayer", + "ChannelDuplicatingPixelUnshuffleUpSampleLayer", + "LinearLayer", + "IdentityLayer", + "DSConv", + "MBConv", + "FusedMBConv", + "ResBlock", + "LiteMLA", + "EfficientViTBlock", + "ResidualBlock", + "DAGBlock", + "OpSequential", +] + + +################################################################################# +# Basic Layers # +################################################################################# + + +class ConvLayer(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + dilation=1, + groups=1, + use_bias=False, + dropout=0, + norm="bn2d", + act_func="relu", + ): + super(ConvLayer, self).__init__() + + padding = get_same_padding(kernel_size) + padding *= dilation + + self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=(kernel_size, kernel_size), + stride=(stride, stride), + padding=padding, + dilation=(dilation, dilation), + groups=groups, + bias=use_bias, + ) + self.norm = build_norm(norm, num_features=out_channels) + self.act = build_act(act_func) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dropout is not None: + x = self.dropout(x) + x = self.conv(x) + if self.norm: + x = self.norm(x) + if self.act: + x = self.act(x) + return x + + +class UpSampleLayer(nn.Module): + def __init__( + self, + mode="bicubic", + size: Optional[int | tuple[int, int] | list[int]] = None, + factor=2, + align_corners=False, + ): + super(UpSampleLayer, self).__init__() + self.mode = mode + self.size = val2list(size, 2) if size is not None else None + self.factor = None if self.size is not None else factor + self.align_corners = align_corners + + @torch.autocast(device_type="cuda", enabled=False) + def forward(self, x: torch.Tensor) -> torch.Tensor: + if (self.size is not None and tuple(x.shape[-2:]) == self.size) or self.factor == 1: + return x + if x.dtype in [torch.float16, torch.bfloat16]: + x = x.float() + return resize(x, self.size, self.factor, self.mode, self.align_corners) + + +class ConvPixelUnshuffleDownSampleLayer(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + factor: int, + ): + super().__init__() + self.factor = factor + out_ratio = factor**2 + assert out_channels % out_ratio == 0 + self.conv = ConvLayer( + in_channels=in_channels, + out_channels=out_channels // out_ratio, + kernel_size=kernel_size, + use_bias=True, + norm=None, + act_func=None, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = 
self.conv(x) + x = F.pixel_unshuffle(x, self.factor) + return x + + +class PixelUnshuffleChannelAveragingDownSampleLayer(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor: int, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor = factor + assert in_channels * factor**2 % out_channels == 0 + self.group_size = in_channels * factor**2 // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pixel_unshuffle(x, self.factor) + B, C, H, W = x.shape + x = x.view(B, self.out_channels, self.group_size, H, W) + x = x.mean(dim=2) + return x + + +class ConvPixelShuffleUpSampleLayer(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + factor: int, + ): + super().__init__() + self.factor = factor + out_ratio = factor**2 + self.conv = ConvLayer( + in_channels=in_channels, + out_channels=out_channels * out_ratio, + kernel_size=kernel_size, + use_bias=True, + norm=None, + act_func=None, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = F.pixel_shuffle(x, self.factor) + return x + + +class InterpolateConvUpSampleLayer(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + factor: int, + mode: str = "nearest", + ) -> None: + super().__init__() + self.factor = factor + self.mode = mode + self.conv = ConvLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + use_bias=True, + norm=None, + act_func=None, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode) + x = self.conv(x) + return x + + +class ChannelDuplicatingPixelUnshuffleUpSampleLayer(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor: int, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor = factor + assert out_channels * factor**2 % in_channels == 0 + self.repeats = out_channels * factor**2 // in_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = F.pixel_shuffle(x, self.factor) + return x + + +class LinearLayer(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + use_bias=True, + dropout=0, + norm=None, + act_func=None, + ): + super(LinearLayer, self).__init__() + + self.dropout = nn.Dropout(dropout, inplace=False) if dropout > 0 else None + self.linear = nn.Linear(in_features, out_features, use_bias) + self.norm = build_norm(norm, num_features=out_features) + self.act = build_act(act_func) + + def _try_squeeze(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() > 2: + x = torch.flatten(x, start_dim=1) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._try_squeeze(x) + if self.dropout: + x = self.dropout(x) + x = self.linear(x) + if self.norm: + x = self.norm(x) + if self.act: + x = self.act(x) + return x + + +class IdentityLayer(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +################################################################################# +# Basic Blocks # +################################################################################# + + +class DSConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + use_bias=False, + norm=("bn2d", "bn2d"), + act_func=("relu6", None), + 
): + super(DSConv, self).__init__() + + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + self.depth_conv = ConvLayer( + in_channels, + in_channels, + kernel_size, + stride, + groups=in_channels, + norm=norm[0], + act_func=act_func[0], + use_bias=use_bias[0], + ) + self.point_conv = ConvLayer( + in_channels, + out_channels, + 1, + norm=norm[1], + act_func=act_func[1], + use_bias=use_bias[1], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class MBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + use_bias=False, + norm=("bn2d", "bn2d", "bn2d"), + act_func=("relu6", "relu6", None), + ): + super(MBConv, self).__init__() + + use_bias = val2tuple(use_bias, 3) + norm = val2tuple(norm, 3) + act_func = val2tuple(act_func, 3) + mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels + + self.inverted_conv = ConvLayer( + in_channels, + mid_channels, + 1, + stride=1, + norm=norm[0], + act_func=act_func[0], + use_bias=use_bias[0], + ) + self.depth_conv = ConvLayer( + mid_channels, + mid_channels, + kernel_size, + stride=stride, + groups=mid_channels, + norm=norm[1], + act_func=act_func[1], + use_bias=use_bias[1], + ) + self.point_conv = ConvLayer( + mid_channels, + out_channels, + 1, + norm=norm[2], + act_func=act_func[2], + use_bias=use_bias[2], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.inverted_conv(x) + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class FusedMBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + groups=1, + use_bias=False, + norm=("bn2d", "bn2d"), + act_func=("relu6", None), + ): + super().__init__() + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels + + self.spatial_conv = ConvLayer( + in_channels, + mid_channels, + kernel_size, + stride, + groups=groups, + use_bias=use_bias[0], + norm=norm[0], + act_func=act_func[0], + ) + self.point_conv = ConvLayer( + mid_channels, + out_channels, + 1, + use_bias=use_bias[1], + norm=norm[1], + act_func=act_func[1], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.spatial_conv(x) + x = self.point_conv(x) + return x + + +class GLUMBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + use_bias=False, + norm=(None, None, "ln2d"), + act_func=("silu", "silu", None), + ): + super().__init__() + use_bias = val2tuple(use_bias, 3) + norm = val2tuple(norm, 3) + act_func = val2tuple(act_func, 3) + + mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels + + self.glu_act = build_act(act_func[1], inplace=False) + self.inverted_conv = ConvLayer( + in_channels, + mid_channels * 2, + 1, + use_bias=use_bias[0], + norm=norm[0], + act_func=act_func[0], + ) + self.depth_conv = ConvLayer( + mid_channels * 2, + mid_channels * 2, + kernel_size, + stride=stride, + groups=mid_channels * 2, + use_bias=use_bias[1], + norm=norm[1], + act_func=None, + ) + self.point_conv = ConvLayer( + mid_channels, + out_channels, + 1, + use_bias=use_bias[2], + norm=norm[2], + 
act_func=act_func[2], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.inverted_conv(x) + x = self.depth_conv(x) + + x, gate = torch.chunk(x, 2, dim=1) + gate = self.glu_act(gate) + x = x * gate + + x = self.point_conv(x) + return x + + +class ResBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=1, + use_bias=False, + norm=("bn2d", "bn2d"), + act_func=("relu6", None), + ): + super().__init__() + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels + + self.conv1 = ConvLayer( + in_channels, + mid_channels, + kernel_size, + stride, + use_bias=use_bias[0], + norm=norm[0], + act_func=act_func[0], + ) + self.conv2 = ConvLayer( + mid_channels, + out_channels, + kernel_size, + 1, + use_bias=use_bias[1], + norm=norm[1], + act_func=act_func[1], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.conv2(x) + return x + + +class LiteMLA(nn.Module): + r"""Lightweight multi-scale linear attention""" + + def __init__( + self, + in_channels: int, + out_channels: int, + heads: Optional[int] = None, + heads_ratio: float = 1.0, + dim=8, + use_bias=False, + norm=(None, "bn2d"), + act_func=(None, None), + kernel_func="relu", + scales: tuple[int, ...] = (5,), + eps=1.0e-15, + ): + super(LiteMLA, self).__init__() + self.eps = eps + heads = int(in_channels // dim * heads_ratio) if heads is None else heads + + total_dim = heads * dim + + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + self.dim = dim + self.qkv = ConvLayer( + in_channels, + 3 * total_dim, + 1, + use_bias=use_bias[0], + norm=norm[0], + act_func=act_func[0], + ) + self.aggreg = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d( + 3 * total_dim, + 3 * total_dim, + scale, + padding=get_same_padding(scale), + groups=3 * total_dim, + bias=use_bias[0], + ), + nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]), + ) + for scale in scales + ] + ) + self.kernel_func = build_act(kernel_func, inplace=False) + + self.proj = ConvLayer( + total_dim * (1 + len(scales)), + out_channels, + 1, + use_bias=use_bias[1], + norm=norm[1], + act_func=act_func[1], + ) + + @torch.autocast(device_type="cuda", enabled=False) + def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor: + B, _, H, W = list(qkv.size()) + + if qkv.dtype == torch.float16: + qkv = qkv.float() + + qkv = torch.reshape( + qkv, + ( + B, + -1, + 3 * self.dim, + H * W, + ), + ) + q, k, v = ( + qkv[:, :, 0 : self.dim], + qkv[:, :, self.dim : 2 * self.dim], + qkv[:, :, 2 * self.dim :], + ) + + # lightweight linear attention + q = self.kernel_func(q) + k = self.kernel_func(k) + + # linear matmul + trans_k = k.transpose(-1, -2) + + v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1) + vk = torch.matmul(v, trans_k) + out = torch.matmul(vk, q) + if out.dtype == torch.bfloat16: + out = out.float() + out = out[:, :, :-1] / (out[:, :, -1:] + self.eps) + + out = torch.reshape(out, (B, -1, H, W)) + return out + + @torch.autocast(device_type="cuda", enabled=False) + def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor: + B, _, H, W = list(qkv.size()) + + qkv = torch.reshape( + qkv, + ( + B, + -1, + 3 * self.dim, + H * W, + ), + ) + q, k, v = ( + qkv[:, :, 0 : self.dim], + qkv[:, :, self.dim : 2 * self.dim], + qkv[:, :, 2 * self.dim 
:], + ) + + q = self.kernel_func(q) + k = self.kernel_func(k) + + att_map = torch.matmul(k.transpose(-1, -2), q) # b h n n + original_dtype = att_map.dtype + if original_dtype in [torch.float16, torch.bfloat16]: + att_map = att_map.float() + att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps) # b h n n + att_map = att_map.to(original_dtype) + out = torch.matmul(v, att_map) # b h d n + + out = torch.reshape(out, (B, -1, H, W)) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # generate multi-scale q, k, v + qkv = self.qkv(x) + multi_scale_qkv = [qkv] + for op in self.aggreg: + multi_scale_qkv.append(op(qkv)) + qkv = torch.cat(multi_scale_qkv, dim=1) + + H, W = list(qkv.size())[-2:] + if H * W > self.dim: + out = self.relu_linear_att(qkv).to(qkv.dtype) + else: + out = self.relu_quadratic_att(qkv) + out = self.proj(out) + + return out + + +class EfficientViTBlock(nn.Module): + def __init__( + self, + in_channels: int, + heads_ratio: float = 1.0, + dim=32, + expand_ratio: float = 4, + scales: tuple[int, ...] = (5,), + norm: str = "bn2d", + act_func: str = "hswish", + context_module: str = "LiteMLA", + local_module: str = "MBConv", + ): + super(EfficientViTBlock, self).__init__() + if context_module == "LiteMLA": + self.context_module = ResidualBlock( + LiteMLA( + in_channels=in_channels, + out_channels=in_channels, + heads_ratio=heads_ratio, + dim=dim, + norm=(None, norm), + scales=scales, + ), + IdentityLayer(), + ) + else: + raise ValueError(f"context_module {context_module} is not supported") + if local_module == "MBConv": + self.local_module = ResidualBlock( + MBConv( + in_channels=in_channels, + out_channels=in_channels, + expand_ratio=expand_ratio, + use_bias=(True, True, False), + norm=(None, None, norm), + act_func=(act_func, act_func, None), + ), + IdentityLayer(), + ) + elif local_module == "GLUMBConv": + self.local_module = ResidualBlock( + GLUMBConv( + in_channels=in_channels, + out_channels=in_channels, + expand_ratio=expand_ratio, + use_bias=(True, True, False), + norm=(None, None, norm), + act_func=(act_func, act_func, None), + ), + IdentityLayer(), + ) + else: + raise NotImplementedError(f"local_module {local_module} is not supported") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.context_module(x) + x = self.local_module(x) + return x + + +################################################################################# +# Functional Blocks # +################################################################################# + + +class ResidualBlock(nn.Module): + def __init__( + self, + main: Optional[nn.Module], + shortcut: Optional[nn.Module], + post_act=None, + pre_norm: Optional[nn.Module] = None, + ): + super(ResidualBlock, self).__init__() + + self.pre_norm = pre_norm + self.main = main + self.shortcut = shortcut + self.post_act = build_act(post_act) + + def forward_main(self, x: torch.Tensor) -> torch.Tensor: + if self.pre_norm is None: + return self.main(x) + else: + return self.main(self.pre_norm(x)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.main is None: + res = x + elif self.shortcut is None: + res = self.forward_main(x) + else: + res = self.forward_main(x) + self.shortcut(x) + if self.post_act: + res = self.post_act(res) + return res + + +class DAGBlock(nn.Module): + def __init__( + self, + inputs: dict[str, nn.Module], + merge: str, + post_input: Optional[nn.Module], + middle: nn.Module, + outputs: dict[str, nn.Module], + ): + super(DAGBlock, self).__init__() + + self.input_keys 
= list(inputs.keys()) + self.input_ops = nn.ModuleList(list(inputs.values())) + self.merge = merge + self.post_input = post_input + + self.middle = middle + + self.output_keys = list(outputs.keys()) + self.output_ops = nn.ModuleList(list(outputs.values())) + + def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + feat = [op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)] + if self.merge == "add": + feat = list_sum(feat) + elif self.merge == "cat": + feat = torch.concat(feat, dim=1) + else: + raise NotImplementedError + if self.post_input is not None: + feat = self.post_input(feat) + feat = self.middle(feat) + for key, op in zip(self.output_keys, self.output_ops): + feature_dict[key] = op(feat) + return feature_dict + + +class OpSequential(nn.Module): + def __init__(self, op_list: list[Optional[nn.Module]]): + super(OpSequential, self).__init__() + valid_op_list = [] + for op in op_list: + if op is not None: + valid_op_list.append(op) + self.op_list = nn.ModuleList(valid_op_list) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for op in self.op_list: + x = op(x) + return x diff --git a/sgm/modules/efficientvitmodules/nn/triton_rms_norm.py b/sgm/modules/efficientvitmodules/nn/triton_rms_norm.py new file mode 100644 index 0000000..e6a3ca7 --- /dev/null +++ b/sgm/modules/efficientvitmodules/nn/triton_rms_norm.py @@ -0,0 +1,191 @@ +import torch +import triton +import triton.language as tl + +__all__ = ["TritonRMSNorm2dFunc"] + + +@triton.jit +def _rms_norm_2d_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Rrms, # pointer to the 1/rms + M, + C, + N, + num_blocks, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + m_n = tl.program_id(0) + m, n = m_n // num_blocks, m_n % num_blocks + + Y += m * C * N + X += m * C * N + # Compute mean + + cols = n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = cols < N + + x_sum_square = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, C): + x = tl.load(X + off * N + cols, mask=mask, other=0.0).to(tl.float32) + x_sum_square += x * x + mean_square = x_sum_square / C + rrms = 1 / tl.sqrt(mean_square + eps) + # Write rstd + tl.store(Rrms + m * N + cols, rrms, mask=mask) + # Normalize and apply linear transformation + for off in range(0, C): + pos = off * N + cols + w = tl.load(W + off) + b = tl.load(B + off) + x = tl.load(X + pos, mask=mask, other=0.0).to(tl.float32) + x_hat = x * rrms + y = x_hat * w + b + # Write output + tl.store(Y + pos, y, mask=mask) + + +@triton.jit +def _rms_norm_2d_bwd_dx_fused( + DX, # pointer to the input gradient + DY, # pointer to the output gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Rrms, # pointer to the 1/rms + M, + C, + N, # number of columns in X + num_blocks, + eps, # epsilon to avoid division by zero + GROUP_SIZE_M: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE_C: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. 
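+    # Two passes over the channel dimension: the first pass accumulates
+    # c1 = mean_C(x_hat * w * dy) together with this block's partial dW/dB sums,
+    # the second pass re-reads X/DY and writes dX = (w * dy - x_hat * c1) * (1/rms).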
+ m_n = tl.program_id(0) + m, n = m_n // num_blocks, m_n % num_blocks + X += m * C * N + DY += m * C * N + DX += m * C * N + Rrms += m * N + + cols = n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = cols < N + # Offset locks and weights/biases gradient pointer for parallel reduction + DW = DW + m_n * C + DB = DB + m_n * C + rrms = tl.load(Rrms + cols, mask=mask, other=1) + # Load data to SRAM + c1 = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, C): + pos = off * N + cols + x = tl.load(X + pos, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + pos, mask=mask, other=0).to(tl.float32) + w = tl.load(W + off).to(tl.float32) + # Compute dx + xhat = x * rrms + wdy = w * dy + xhat = tl.where(mask, xhat, 0.0) + wdy = tl.where(mask, wdy, 0.0) + c1 += xhat * wdy + # Accumulate partial sums for dw/db + tl.store(DW + off, tl.sum((dy * xhat).to(w.dtype), axis=0)) + tl.store(DB + off, tl.sum(dy.to(w.dtype), axis=0)) + + c1 /= C + for off in range(0, C): + pos = off * N + cols + x = tl.load(X + pos, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + pos, mask=mask, other=0).to(tl.float32) + w = tl.load(W + off).to(tl.float32) + xhat = x * rrms + wdy = w * dy + dx = (wdy - (xhat * c1)) * rrms + # Write dx + tl.store(DX + pos, dx, mask=mask) + + +class TritonRMSNorm2dFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(x.shape[0], x.shape[1], -1) + M, C, N = x_arg.shape + rrms = torch.empty((M, N), dtype=torch.float32, device="cuda") + # Less than 64KB per feature: enqueue fused kernel + BLOCK_SIZE = 256 + num_blocks = triton.cdiv(N, BLOCK_SIZE) + num_warps = 8 + # enqueue kernel + _rms_norm_2d_fwd_fused[(M * num_blocks,)]( # + x_arg, + y, + weight, + bias, + rrms, # + M, + C, + N, + num_blocks, + eps, # + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_ctas=1, + ) + ctx.save_for_backward(x, weight, bias, rrms) + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_blocks = num_blocks + ctx.num_warps = num_warps + ctx.eps = eps + return y + + @staticmethod + def backward(ctx, dy): + x, w, b, rrms = ctx.saved_tensors + num_blocks = ctx.num_blocks + + x_arg = x.reshape(x.shape[0], x.shape[1], -1) + M, C, N = x_arg.shape + # GROUP_SIZE_M = 64 + GROUP_SIZE_M = M * num_blocks + # allocate output + _dw = torch.empty((GROUP_SIZE_M, C), dtype=x.dtype, device=w.device) + _db = torch.empty((GROUP_SIZE_M, C), dtype=x.dtype, device=w.device) + dw = torch.empty((C,), dtype=w.dtype, device=w.device) + db = torch.empty((C,), dtype=w.dtype, device=w.device) + dx = torch.empty_like(dy) + # enqueue kernel using forward pass heuristics + # also compute partial sums for DW and DB + # print(f"M={M}, num_blocks={num_blocks}, dx={dx.shape}, dy={dy.shape}, _dw={_dw.shape}, _db={_db.shape}, x={x.shape}, w={w.shape}, b={b.shape}, m={m.shape}, v={v.shape}, M={M}, C={C}, N={N}") + _rms_norm_2d_bwd_dx_fused[(M * num_blocks,)]( # + dx, + dy, + _dw, + _db, + x, + w, + b, + rrms, # + M, + C, + N, + num_blocks, + ctx.eps, # + BLOCK_SIZE=ctx.BLOCK_SIZE, + GROUP_SIZE_M=GROUP_SIZE_M, # + BLOCK_SIZE_C=triton.next_power_of_2(C), + num_warps=ctx.num_warps, + ) + dw = _dw.sum(dim=0) + db = _db.sum(dim=0) + return dx, dw, db, None diff --git a/sgm/modules/efficientvitmodules/utils/__init__.py b/sgm/modules/efficientvitmodules/utils/__init__.py new file mode 100644 index 0000000..4155f95 --- /dev/null +++ b/sgm/modules/efficientvitmodules/utils/__init__.py @@ -0,0 +1,3 @@ +from .list import * 
+from .network import * +from .random import * diff --git a/sgm/modules/efficientvitmodules/utils/list.py b/sgm/modules/efficientvitmodules/utils/list.py new file mode 100644 index 0000000..71fa5f7 --- /dev/null +++ b/sgm/modules/efficientvitmodules/utils/list.py @@ -0,0 +1,51 @@ +from typing import Any, Optional + +__all__ = [ + "list_sum", + "list_mean", + "weighted_list_sum", + "list_join", + "val2list", + "val2tuple", + "squeeze_list", +] + + +def list_sum(x: list) -> Any: + return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) + + +def list_mean(x: list) -> Any: + return list_sum(x) / len(x) + + +def weighted_list_sum(x: list, weights: list) -> Any: + assert len(x) == len(weights) + return x[0] * weights[0] if len(x) == 1 else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:]) + + +def list_join(x: list, sep="\t", format_str="%s") -> str: + return sep.join([format_str % val for val in x]) + + +def val2list(x: list | tuple | Any, repeat_time=1) -> list: + if isinstance(x, (list, tuple)): + return list(x) + return [x for _ in range(repeat_time)] + + +def val2tuple(x: list | tuple | Any, min_len: int = 1, idx_repeat: int = -1) -> tuple: + x = val2list(x) + + # repeat elements if necessary + if len(x) > 0: + x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] + + return tuple(x) + + +def squeeze_list(x: Optional[list]) -> list | Any: + if x is not None and len(x) == 1: + return x[0] + else: + return x diff --git a/sgm/modules/efficientvitmodules/utils/network.py b/sgm/modules/efficientvitmodules/utils/network.py new file mode 100644 index 0000000..fa054e7 --- /dev/null +++ b/sgm/modules/efficientvitmodules/utils/network.py @@ -0,0 +1,95 @@ +import collections +import os +from inspect import signature +from typing import Any, Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = [ + "is_parallel", + "get_device", + "get_same_padding", + "resize", + "build_kwargs_from_config", + "load_state_dict_from_file", + "get_submodule_weights", +] + + +def is_parallel(model: nn.Module) -> bool: + return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)) + + +def get_device(model: nn.Module) -> torch.device: + return model.parameters().__next__().device + + +def get_dtype(model: nn.Module) -> torch.dtype: + return model.parameters().__next__().dtype + + +def get_same_padding(kernel_size: int | tuple[int, ...]) -> int | tuple[int, ...]: + if isinstance(kernel_size, tuple): + return tuple([get_same_padding(ks) for ks in kernel_size]) + else: + assert kernel_size % 2 > 0, "kernel size should be odd number" + return kernel_size // 2 + + +def resize( + x: torch.Tensor, + size: Optional[Any] = None, + scale_factor: Optional[list[float]] = None, + mode: str = "bicubic", + align_corners: Optional[bool] = False, +) -> torch.Tensor: + if mode in {"bilinear", "bicubic"}: + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + ) + elif mode in {"nearest", "area"}: + return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode) + else: + raise NotImplementedError(f"resize(mode={mode}) not implemented.") + + +def build_kwargs_from_config(config: dict, target_func: Callable) -> dict[str, Any]: + valid_keys = list(signature(target_func).parameters) + kwargs = {} + for key in config: + if key in valid_keys: + kwargs[key] = config[key] + return kwargs + + +def load_state_dict_from_file(file: str, only_state_dict=True) -> dict[str, 
torch.Tensor]: + file = os.path.realpath(os.path.expanduser(file)) + checkpoint = torch.load(file, map_location="cpu", weights_only=True) + if only_state_dict and "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + return checkpoint + + +def get_submodule_weights(weights: collections.OrderedDict, prefix: str): + submodule_weights = collections.OrderedDict() + len_prefix = len(prefix) + for key, weight in weights.items(): + if key.startswith(prefix): + submodule_weights[key[len_prefix:]] = weight + return submodule_weights + + +def get_dtype_from_str(dtype: str) -> torch.dtype: + if dtype == "fp32": + return torch.float32 + if dtype == "fp16": + return torch.float16 + if dtype == "bf16": + return torch.bfloat16 + raise NotImplementedError(f"dtype {dtype} is not supported") diff --git a/sgm/modules/efficientvitmodules/utils/random.py b/sgm/modules/efficientvitmodules/utils/random.py new file mode 100644 index 0000000..9287bdc --- /dev/null +++ b/sgm/modules/efficientvitmodules/utils/random.py @@ -0,0 +1,63 @@ +from typing import Any, Optional + +import numpy as np +import torch + +__all__ = [ + "torch_randint", + "torch_random", + "torch_shuffle", + "torch_uniform", + "torch_random_choices", +] + + +def torch_randint(low: int, high: int, generator: Optional[torch.Generator] = None) -> int: + """uniform: [low, high)""" + if low == high: + return low + else: + assert low < high + return int(torch.randint(low=low, high=high, generator=generator, size=(1,))) + + +def torch_random(generator: Optional[torch.Generator] = None) -> float: + """uniform distribution on the interval [0, 1)""" + return float(torch.rand(1, generator=generator)) + + +def torch_shuffle(src_list: list[Any], generator: Optional[torch.Generator] = None) -> list[Any]: + rand_indexes = torch.randperm(len(src_list), generator=generator).tolist() + return [src_list[i] for i in rand_indexes] + + +def torch_uniform(low: float, high: float, generator: Optional[torch.Generator] = None) -> float: + """uniform distribution on the interval [low, high)""" + rand_val = torch_random(generator) + return (high - low) * rand_val + low + + +def torch_random_choices( + src_list: list[Any], + generator: Optional[torch.Generator] = None, + k=1, + weight_list: Optional[list[float]] = None, +) -> Any | list: + if weight_list is None: + rand_idx = torch.randint(low=0, high=len(src_list), generator=generator, size=(k,)) + out_list = [src_list[i] for i in rand_idx] + else: + assert len(weight_list) == len(src_list) + accumulate_weight_list = np.cumsum(weight_list) + + out_list = [] + for _ in range(k): + val = torch_uniform(0, accumulate_weight_list[-1], generator) + active_id = 0 + for i, weight_val in enumerate(accumulate_weight_list): + active_id = i + if weight_val > val: + break + out_list.append(src_list[active_id]) + + return out_list[0] if k == 1 else out_list
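For reference, a minimal usage sketch of the fused RMS norm added above (not part of the patch). It assumes a CUDA device with Triton installed and that the module path shown in the diff is importable; the `RMSNorm2d` wrapper name below is illustrative only, not an API provided by this repository. `x` is an NCHW activation tensor and `weight`/`bias` are per-channel parameters, matching the shapes the forward kernel expects.

```python
import torch
import torch.nn as nn

from sgm.modules.efficientvitmodules.nn.triton_rms_norm import TritonRMSNorm2dFunc


class RMSNorm2d(nn.Module):
    """Illustrative wrapper: per-channel RMS norm over NCHW tensors via the Triton kernel."""

    def __init__(self, num_channels: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The autograd.Function normalizes each spatial position over the channel dim,
        # then applies the per-channel affine transform.
        return TritonRMSNorm2dFunc.apply(x, self.weight, self.bias, self.eps)


if __name__ == "__main__":
    x = torch.randn(2, 32, 16, 16, device="cuda", requires_grad=True)
    norm = RMSNorm2d(32).cuda()
    y = norm(x)
    y.mean().backward()  # exercises the fused backward kernel as well
    print(y.shape, x.grad.shape)
```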