diff --git a/TinyCLIP/README.md b/TinyCLIP/README.md
index 797b89d..5105bc5 100644
--- a/TinyCLIP/README.md
+++ b/TinyCLIP/README.md
@@ -31,11 +31,11 @@
 [TinyCLIP ResNet-19M Text-19M](./src/open_clip/model_configs/TinyCLIP-ResNet-19M-Text-19M.json) | manual | LAION-400M | 56.4 | 4.4 | 3,024| [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ResNet-19M-Text-19M-LAION400M.pt)
 [TinyCLIP ViT-61M/32 Text-29M](./src/open_clip/model_configs/TinyCLIP-ViT-61M-32-Text-29M.json) | manual | LAION-400M | 62.4 | 5.3 | 3,191|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-61M-32-Text-29M-LAION400M.pt)
 [TinyCLIP ViT-40M/32 Text-19M](./src/open_clip/model_configs/TinyCLIP-ViT-40M-32-Text-19M.json) | manual | LAION-400M | 59.8 | 3.5 | 4,641|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-40M-32-Text-19M-LAION400M.pt)
-TinyCLIP ViT-63M/32 Text-31M | auto | LAION-400M | 63.9 | 5.6 | 2,905|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAION400M.pt)
-TinyCLIP ViT-45M/32 Text-18M | auto | LAION-400M | 61.4 | 3.7 | 3,682|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAION400M.pt)
-TinyCLIP ViT-22M/32 Text-10M | auto | LAION-400M | 53.7 | 1.9 | 5,504|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-22M-32-Text-10M-LAION400M.pt)
-TinyCLIP ViT-63M/32 Text-31M | auto | LAION+YFCC-400M | 64.5 | 5.6| 2,909 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAIONYFCC400M.pt)
-TinyCLIP ViT-45M/32 Text-18M | auto | LAION+YFCC-400M | 62.7 | 1.9 | 3,685 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAIONYFCC400M.pt)
+[TinyCLIP ViT-63M/32 Text-31M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json) | auto | LAION-400M | 63.9 | 5.6 | 2,905|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAION400M.pt)
+[TinyCLIP ViT-45M/32 Text-18M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json) | auto | LAION-400M | 61.4 | 3.7 | 3,682|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAION400M.pt)
+[TinyCLIP ViT-22M/32 Text-10M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-22M-32-Text-10M.json) | auto | LAION-400M | 53.7 | 1.9 | 5,504|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-22M-32-Text-10M-LAION400M.pt)
+[TinyCLIP ViT-63M/32 Text-31M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json) | auto | LAION+YFCC-400M | 64.5 | 5.6| 2,909 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAIONYFCC400M.pt)
+[TinyCLIP ViT-45M/32 Text-18M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json) | auto | LAION+YFCC-400M | 62.7 | 1.9 | 3,685 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAIONYFCC400M.pt)
 
 Note: The configs of models with auto inheritance are generated automatically.
 
diff --git a/TinyCLIP/inference.py b/TinyCLIP/inference.py
index 7612015..6c4b19c 100644
--- a/TinyCLIP/inference.py
+++ b/TinyCLIP/inference.py
@@ -18,8 +18,24 @@
 # arch = 'TinyCLIP-ViT-61M-32-Text-29M'
 # model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')
 
-arch = 'TinyCLIP-ViT-40M-32-Text-19M'
-model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')
+# arch = 'TinyCLIP-ViT-40M-32-Text-19M'
+# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')
+
+# auto inheritance
+# arch = 'TinyCLIP-auto-ViT-63M-32-Text-31M'
+# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')
+
+# arch = 'TinyCLIP-auto-ViT-45M-32-Text-18M'
+# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')
+
+# arch = 'TinyCLIP-auto-ViT-22M-32-Text-10M'
+# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')
+
+# arch = 'TinyCLIP-auto-ViT-63M-32-Text-31M'
+# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAIONYFCC400M')
+
+arch = 'TinyCLIP-auto-ViT-45M-32-Text-18M'
+model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAIONYFCC400M')
 
 tokenizer = open_clip.get_tokenizer(arch)
 
diff --git a/TinyCLIP/src/open_clip/factory.py b/TinyCLIP/src/open_clip/factory.py
index ed9af2a..4d1690e 100644
--- a/TinyCLIP/src/open_clip/factory.py
+++ b/TinyCLIP/src/open_clip/factory.py
@@ -11,6 +11,7 @@
 
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from .model import CLIP, convert_weights_to_fp16, resize_pos_embed
+from .model import load_pruned_model, prune_model
 from .openai import load_openai_model
 from .pretrained import get_pretrained_cfg, download_pretrained
 from .transform import image_transform
@@ -86,6 +87,13 @@ def load_checkpoint(model, checkpoint_path, strict=True):
     return incompatible_keys
 
 
+def load_pruned_checkpoint(model, checkpoint_path, strict=True):
+    state_dict = load_state_dict(checkpoint_path)
+    resize_pos_embed(state_dict, model)
+    incompatible_keys = load_pruned_model(model, state_dict, strict=strict)
+    return incompatible_keys
+
+
 def create_model(
         model_name: str,
         pretrained: str = '',
@@ -138,6 +146,9 @@ def create_model(
                 f'model sparsity varies from {model_cfg["start_sparsity"]} to {model_cfg["sparsity"]}, sparsity warmup steps: {model_cfg["sparsity_warmup"]}')
         logging.info(str(model_cfg))
 
+        auto_weight_inheritance = model_cfg.get('mask_image', False) or \
+            model_cfg.get('mask_text', False)
+
         model = CLIP(**model_cfg)
 
         pretrained_cfg = {}
@@ -153,7 +164,11 @@ def create_model(
             if checkpoint_path:
                 logging.info(
                     f'Loading pretrained {model_name} weights ({pretrained}).')
-                load_checkpoint(model, checkpoint_path)
+                if not auto_weight_inheritance:
+                    load_checkpoint(model, checkpoint_path)
+                else:
+                    load_pruned_checkpoint(model, checkpoint_path)
+                    model = prune_model(model)
             else:
                 logging.warning(
                     f'Pretrained weights ({pretrained}) not found for model {model_name}.')
diff --git a/TinyCLIP/src/open_clip/model.py b/TinyCLIP/src/open_clip/model.py
index 0004dee..a61a19b 100644
--- a/TinyCLIP/src/open_clip/model.py
+++ b/TinyCLIP/src/open_clip/model.py
@@ -1297,7 +1297,7 @@ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=
 
 
 @torch.no_grad()
-def load_pruned_model(model, pruned_state_dict):
+def load_pruned_model(model, pruned_state_dict, strict=True):
     '''
     A full model loads the pruned state dict.
 
@@ -1315,6 +1315,10 @@ def _copy_to_full_weight(dst, src):
         slices = [slice(0, d) for d in dims]
         dst[slices].copy_(src)
 
+    for _ in range(2):
+        pruned_state_dict = {
+            k.replace('module.', ''): v for k, v in pruned_state_dict.items()}
+
     lambda_init_value = 10.0
     model_state_dict = model.state_dict()
     head_dim = model.transformer.head_dim
@@ -1405,4 +1409,30 @@ def _get_layer_id(name):
                 model_state_dict[f'{ename}.l0_module.intermediate_loga'][d,
                                                                          :].fill_(-lambda_init_value)
 
-    model.load_state_dict(model_state_dict, strict=True)
+    return model.load_state_dict(model_state_dict, strict=strict)
+
+
+def prune_model(model):
+    device = next(model.parameters()).device
+
+    with torch.no_grad():
+        model.image_encoder_without_ddp.eval()
+        image_size = (1, 3) + model.image_encoder_without_ddp.visual.image_size
+        image = torch.randn(image_size, device=device)
+        model.image_encoder_without_ddp(image)
+        model.image_encoder_without_ddp = model.image_encoder_without_ddp.prune()
+
+    assert hasattr(model.image_encoder_without_ddp, 'l0_module')
+    model.image_encoder_without_ddp.l0_module = None
+
+    with torch.no_grad():
+        model.text_encoder_without_ddp.eval()
+        context_length = model.text_encoder_without_ddp.context_length
+        text = torch.zeros((1, context_length), dtype=torch.long, device=device)
+        model.text_encoder_without_ddp(text)
+        model.text_encoder_without_ddp = model.text_encoder_without_ddp.prune()
+
+    assert hasattr(model.text_encoder_without_ddp, 'l0_module')
+    model.text_encoder_without_ddp.l0_module = None
+
+    return model
diff --git a/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-22M-32-Text-10M.json b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-22M-32-Text-10M.json
new file mode 100644
index 0000000..bb568bb
--- /dev/null
+++ b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-22M-32-Text-10M.json
@@ -0,0 +1,19 @@
+{
+    "mask_image": true,
+    "mask_text": true,
+
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
diff --git a/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json
new file mode 100644
index 0000000..bb568bb
--- /dev/null
+++ b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json
@@ -0,0 +1,19 @@
+{
+    "mask_image": true,
+    "mask_text": true,
+
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
diff --git a/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json
new file mode 100644
index 0000000..bb568bb
--- /dev/null
+++ b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json
@@ -0,0 +1,19 @@
+{
+    "mask_image": true,
+    "mask_text": true,
+
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
diff --git a/TinyCLIP/src/open_clip/pretrained.py b/TinyCLIP/src/open_clip/pretrained.py
index b0ed721..4a89069 100644
--- a/TinyCLIP/src/open_clip/pretrained.py
+++ b/TinyCLIP/src/open_clip/pretrained.py
@@ -146,6 +146,9 @@ def _pcfg(url='', hf_hub='', mean=None, std=None):
 )
 
 # TinyCLIP
+
+# manual weight inheritance
+
 _TINYCLIP_VIT_39M_16_TEXT_19M = {
     "YFCC15M": _pcfg(
         "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-39M-16-Text-19M-YFCC15M.pt",
@@ -182,6 +185,32 @@ def _pcfg(url='', hf_hub='', mean=None, std=None):
     ),
 }
 
+# auto weight inheritance
+
+_TINYCLIP_AUTO_VIT_63M_32_TEXT_31M = {
+    "LAION400M": _pcfg(
+        "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAION400M.pt",
+    ),
+    "LAIONYFCC400M": _pcfg(
+        "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAIONYFCC400M.pt",
+    ),
+}
+
+_TINYCLIP_AUTO_VIT_45M_32_TEXT_18M = {
+    "LAION400M": _pcfg(
+        "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAION400M.pt",
+    ),
+    "LAIONYFCC400M": _pcfg(
+        "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAIONYFCC400M.pt",
+    ),
+}
+
+_TINYCLIP_AUTO_VIT_22M_32_TEXT_10M = {
+    "LAION400M": _pcfg(
+        "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-22M-32-Text-10M-LAION400M.pt",
+    ),
+}
+
 _PRETRAINED = {
     "RN50": _RN50,
     "RN50-quickgelu": _RN50_quickgelu,
@@ -205,6 +234,10 @@ def _pcfg(url='', hf_hub='', mean=None, std=None):
     "TinyCLIP-ResNet-19M-Text-19M": _TINYCLIP_RESNET_19M_TEXT_19M,
     "TinyCLIP-ViT-61M-32-Text-29M": _TINYCLIP_VIT_61M_32_TEXT_29M,
     "TinyCLIP-ViT-40M-32-Text-19M": _TINYCLIP_VIT_40M_32_TEXT_19M,
+
+    "TinyCLIP-auto-ViT-63M-32-Text-31M": _TINYCLIP_AUTO_VIT_63M_32_TEXT_31M,
+    "TinyCLIP-auto-ViT-45M-32-Text-18M": _TINYCLIP_AUTO_VIT_45M_32_TEXT_18M,
+    "TinyCLIP-auto-ViT-22M-32-Text-10M": _TINYCLIP_AUTO_VIT_22M_32_TEXT_10M,
 }
 
 