From 0ef394cd0c0f41b55fb073ab9abbb95acc13104e Mon Sep 17 00:00:00 2001 From: wkcn Date: Wed, 31 Jan 2024 19:58:51 +0800 Subject: [PATCH 1/6] [TinyCLIP] inference for auto weight inheritance --- TinyCLIP/inference.py | 6 +++- TinyCLIP/src/open_clip/factory.py | 18 +++++++++++- TinyCLIP/src/open_clip/model.py | 28 +++++++++++++++++++ .../TinyCLIP-auto-ViT-63M-32-Text-31M.json | 19 +++++++++++++ TinyCLIP/src/open_clip/pretrained.py | 13 +++++++++ 5 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json diff --git a/TinyCLIP/inference.py b/TinyCLIP/inference.py index 7612015b..4057cf4f 100644 --- a/TinyCLIP/inference.py +++ b/TinyCLIP/inference.py @@ -18,7 +18,11 @@ # arch = 'TinyCLIP-ViT-61M-32-Text-29M' # model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M') -arch = 'TinyCLIP-ViT-40M-32-Text-19M' +# arch = 'TinyCLIP-ViT-40M-32-Text-19M' +# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M') + +# auto inheritance +arch = 'TinyCLIP-auto-ViT-63M-32-Text-31M' model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M') tokenizer = open_clip.get_tokenizer(arch) diff --git a/TinyCLIP/src/open_clip/factory.py b/TinyCLIP/src/open_clip/factory.py index ed9af2a2..c67b0c87 100644 --- a/TinyCLIP/src/open_clip/factory.py +++ b/TinyCLIP/src/open_clip/factory.py @@ -11,6 +11,7 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .model import CLIP, convert_weights_to_fp16, resize_pos_embed +from .model import load_pruned_model, prune_model from .openai import load_openai_model from .pretrained import get_pretrained_cfg, download_pretrained from .transform import image_transform @@ -86,6 +87,14 @@ def load_checkpoint(model, checkpoint_path, strict=True): return incompatible_keys +def load_pruned_checkpoint(model, checkpoint_path, strict=True): + state_dict = load_state_dict(checkpoint_path) + resize_pos_embed(state_dict, model) + load_pruned_model(model, state_dict) + incompatible_keys = dict() + return incompatible_keys + + def create_model( model_name: str, pretrained: str = '', @@ -138,6 +147,9 @@ def create_model( f'model sparsity varies from {model_cfg["start_sparsity"]} to {model_cfg["sparsity"]}, sparsity warmup steps: {model_cfg["sparsity_warmup"]}') logging.info(str(model_cfg)) + auto_weight_inheritance = model_cfg.get('mask_image', False) or \ + model_cfg.get('mask_text', False) + model = CLIP(**model_cfg) pretrained_cfg = {} @@ -153,7 +165,11 @@ def create_model( if checkpoint_path: logging.info( f'Loading pretrained {model_name} weights ({pretrained}).') - load_checkpoint(model, checkpoint_path) + if not auto_weight_inheritance: + load_checkpoint(model, checkpoint_path) + else: + load_pruned_checkpoint(model, checkpoint_path) + model = prune_model(model) else: logging.warning( f'Pretrained weights ({pretrained}) not found for model {model_name}.') diff --git a/TinyCLIP/src/open_clip/model.py b/TinyCLIP/src/open_clip/model.py index 0004deef..49db9348 100644 --- a/TinyCLIP/src/open_clip/model.py +++ b/TinyCLIP/src/open_clip/model.py @@ -1315,6 +1315,9 @@ def _copy_to_full_weight(dst, src): slices = [slice(0, d) for d in dims] dst[slices].copy_(src) + for _ in range(2): + pruned_state_dict = {k.replace('module.', ''): v for k, v in pruned_state_dict.items()} + lambda_init_value = 10.0 model_state_dict = model.state_dict() head_dim = model.transformer.head_dim @@ -1406,3 +1409,28 @@ def _get_layer_id(name): :].fill_(-lambda_init_value) model.load_state_dict(model_state_dict, strict=True) + + +def prune_model(model): + device = next(model.parameters()).device + + with torch.no_grad(): + model.image_encoder_without_ddp.eval() + image = torch.randn((1, 3, 224, 224), device=device) + model.image_encoder_without_ddp(image) + model.image_encoder_without_ddp = model.image_encoder_without_ddp.prune() + + assert hasattr( + model.image_encoder_without_ddp, 'l0_module') + model.image_encoder_without_ddp.l0_module = None + + with torch.no_grad(): + model.text_encoder_without_ddp.eval() + text = torch.randint(0, 100, (1, 77), device=device) + model.text_encoder_without_ddp(text) + model.text_encoder_without_ddp = model.text_encoder_without_ddp.prune() + assert hasattr(model.text_encoder_without_ddp, 'l0_module') + + model.text_encoder_without_ddp.l0_module = None + + return model diff --git a/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json new file mode 100644 index 00000000..bb568bb9 --- /dev/null +++ b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json @@ -0,0 +1,19 @@ +{ + "mask_image": true, + "mask_text": true, + + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/TinyCLIP/src/open_clip/pretrained.py b/TinyCLIP/src/open_clip/pretrained.py index b0ed7216..c92efd7b 100644 --- a/TinyCLIP/src/open_clip/pretrained.py +++ b/TinyCLIP/src/open_clip/pretrained.py @@ -146,6 +146,9 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): ) # TinyCLIP + +## manual weight inheritance + _TINYCLIP_VIT_39M_16_TEXT_19M = { "YFCC15M": _pcfg( "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-39M-16-Text-19M-YFCC15M.pt", @@ -182,6 +185,14 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): ), } +## auto weight inheritance + +_TINYCLIP_AUTO_VIT_63M_32_TEXT_31M = { + "LAION400M": _pcfg( + "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAION400M.pt", + ), +} + _PRETRAINED = { "RN50": _RN50, "RN50-quickgelu": _RN50_quickgelu, @@ -205,6 +216,8 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): "TinyCLIP-ResNet-19M-Text-19M": _TINYCLIP_RESNET_19M_TEXT_19M, "TinyCLIP-ViT-61M-32-Text-29M": _TINYCLIP_VIT_61M_32_TEXT_29M, "TinyCLIP-ViT-40M-32-Text-19M": _TINYCLIP_VIT_40M_32_TEXT_19M, + + "TinyCLIP-auto-ViT-63M-32-Text-31M": _TINYCLIP_AUTO_VIT_63M_32_TEXT_31M, } From 6717bb0ab5fbf8ccd9f2fd21e9594c396cbf9cc7 Mon Sep 17 00:00:00 2001 From: wkcn Date: Wed, 31 Jan 2024 20:12:20 +0800 Subject: [PATCH 2/6] [TinyCLIP] inference for auto weight inheritance --- TinyCLIP/inference.py | 17 +++++++++++-- TinyCLIP/src/open_clip/model.py | 9 ++++--- .../TinyCLIP-auto-ViT-22M-32-Text-10M.json | 19 +++++++++++++++ .../TinyCLIP-auto-ViT-45M-32-Text-18M.json | 19 +++++++++++++++ TinyCLIP/src/open_clip/pretrained.py | 24 +++++++++++++++++-- 5 files changed, 81 insertions(+), 7 deletions(-) create mode 100644 TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-22M-32-Text-10M.json create mode 100644 TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json diff --git a/TinyCLIP/inference.py b/TinyCLIP/inference.py index 4057cf4f..30160206 100644 --- a/TinyCLIP/inference.py +++ b/TinyCLIP/inference.py @@ -22,8 +22,21 @@ # model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M') # auto inheritance -arch = 'TinyCLIP-auto-ViT-63M-32-Text-31M' -model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M') +# arch = 'TinyCLIP-auto-ViT-63M-32-Text-31M' +# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M') + +# arch = 'TinyCLIP-auto-ViT-45M-32-Text-18M' +# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M') + +# arch = 'TinyCLIP-auto-ViT-22M-32-Text-10M' +# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M') + +# arch = 'TinyCLIP-auto-ViT-63M-32-Text-31M' +# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAIONYFCC400M') + +arch = 'TinyCLIP-auto-ViT-45M-32-Text-18M' +model, _, preprocess = open_clip.create_model_and_transforms( + arch, pretrained='LAIONYFCC400M') tokenizer = open_clip.get_tokenizer(arch) diff --git a/TinyCLIP/src/open_clip/model.py b/TinyCLIP/src/open_clip/model.py index 49db9348..f0f0a34c 100644 --- a/TinyCLIP/src/open_clip/model.py +++ b/TinyCLIP/src/open_clip/model.py @@ -1316,7 +1316,8 @@ def _copy_to_full_weight(dst, src): dst[slices].copy_(src) for _ in range(2): - pruned_state_dict = {k.replace('module.', ''): v for k, v in pruned_state_dict.items()} + pruned_state_dict = { + k.replace('module.', ''): v for k, v in pruned_state_dict.items()} lambda_init_value = 10.0 model_state_dict = model.state_dict() @@ -1416,7 +1417,8 @@ def prune_model(model): with torch.no_grad(): model.image_encoder_without_ddp.eval() - image = torch.randn((1, 3, 224, 224), device=device) + image_size = (1, 3) + model.image_encoder_without_ddp.visual.image_size + image = torch.randn(image_size, device=device) model.image_encoder_without_ddp(image) model.image_encoder_without_ddp = model.image_encoder_without_ddp.prune() @@ -1426,7 +1428,8 @@ def prune_model(model): with torch.no_grad(): model.text_encoder_without_ddp.eval() - text = torch.randint(0, 100, (1, 77), device=device) + context_length = model.text_encoder_without_ddp.context_length + text = torch.zeros((1, context_length), dtype=torch.long, device=device) model.text_encoder_without_ddp(text) model.text_encoder_without_ddp = model.text_encoder_without_ddp.prune() assert hasattr(model.text_encoder_without_ddp, 'l0_module') diff --git a/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-22M-32-Text-10M.json b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-22M-32-Text-10M.json new file mode 100644 index 00000000..bb568bb9 --- /dev/null +++ b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-22M-32-Text-10M.json @@ -0,0 +1,19 @@ +{ + "mask_image": true, + "mask_text": true, + + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json new file mode 100644 index 00000000..bb568bb9 --- /dev/null +++ b/TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json @@ -0,0 +1,19 @@ +{ + "mask_image": true, + "mask_text": true, + + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/TinyCLIP/src/open_clip/pretrained.py b/TinyCLIP/src/open_clip/pretrained.py index c92efd7b..4a890695 100644 --- a/TinyCLIP/src/open_clip/pretrained.py +++ b/TinyCLIP/src/open_clip/pretrained.py @@ -147,7 +147,7 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): # TinyCLIP -## manual weight inheritance +# manual weight inheritance _TINYCLIP_VIT_39M_16_TEXT_19M = { "YFCC15M": _pcfg( @@ -185,12 +185,30 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): ), } -## auto weight inheritance +# auto weight inheritance _TINYCLIP_AUTO_VIT_63M_32_TEXT_31M = { "LAION400M": _pcfg( "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAION400M.pt", ), + "LAIONYFCC400M": _pcfg( + "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAIONYFCC400M.pt", + ), +} + +_TINYCLIP_AUTO_VIT_45M_32_TEXT_18M = { + "LAION400M": _pcfg( + "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAION400M.pt", + ), + "LAIONYFCC400M": _pcfg( + "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAIONYFCC400M.pt", + ), +} + +_TINYCLIP_AUTO_VIT_22M_32_TEXT_10M = { + "LAION400M": _pcfg( + "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-22M-32-Text-10M-LAION400M.pt", + ), } _PRETRAINED = { @@ -218,6 +236,8 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): "TinyCLIP-ViT-40M-32-Text-19M": _TINYCLIP_VIT_40M_32_TEXT_19M, "TinyCLIP-auto-ViT-63M-32-Text-31M": _TINYCLIP_AUTO_VIT_63M_32_TEXT_31M, + "TinyCLIP-auto-ViT-45M-32-Text-18M": _TINYCLIP_AUTO_VIT_45M_32_TEXT_18M, + "TinyCLIP-auto-ViT-22M-32-Text-10M": _TINYCLIP_AUTO_VIT_22M_32_TEXT_10M, } From 6fdd57f73abdb32e79b82042ae001e18aa37c8dc Mon Sep 17 00:00:00 2001 From: wkcn Date: Wed, 31 Jan 2024 20:14:14 +0800 Subject: [PATCH 3/6] one line --- TinyCLIP/inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/TinyCLIP/inference.py b/TinyCLIP/inference.py index 30160206..6c4b19cc 100644 --- a/TinyCLIP/inference.py +++ b/TinyCLIP/inference.py @@ -35,8 +35,7 @@ # model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAIONYFCC400M') arch = 'TinyCLIP-auto-ViT-45M-32-Text-18M' -model, _, preprocess = open_clip.create_model_and_transforms( - arch, pretrained='LAIONYFCC400M') +model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAIONYFCC400M') tokenizer = open_clip.get_tokenizer(arch) From fea33c157068fdfa354348de1fb23a20b76cf66a Mon Sep 17 00:00:00 2001 From: wkcn Date: Wed, 31 Jan 2024 20:19:48 +0800 Subject: [PATCH 4/6] format --- TinyCLIP/src/open_clip/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/TinyCLIP/src/open_clip/model.py b/TinyCLIP/src/open_clip/model.py index f0f0a34c..f8dbd7bc 100644 --- a/TinyCLIP/src/open_clip/model.py +++ b/TinyCLIP/src/open_clip/model.py @@ -1422,8 +1422,7 @@ def prune_model(model): model.image_encoder_without_ddp(image) model.image_encoder_without_ddp = model.image_encoder_without_ddp.prune() - assert hasattr( - model.image_encoder_without_ddp, 'l0_module') + assert hasattr(model.image_encoder_without_ddp, 'l0_module') model.image_encoder_without_ddp.l0_module = None with torch.no_grad(): @@ -1432,8 +1431,8 @@ def prune_model(model): text = torch.zeros((1, context_length), dtype=torch.long, device=device) model.text_encoder_without_ddp(text) model.text_encoder_without_ddp = model.text_encoder_without_ddp.prune() - assert hasattr(model.text_encoder_without_ddp, 'l0_module') + assert hasattr(model.text_encoder_without_ddp, 'l0_module') model.text_encoder_without_ddp.l0_module = None return model From f2b0f3a8e2ae27a3fb583f199980e20c8292416c Mon Sep 17 00:00:00 2001 From: wkcn Date: Wed, 31 Jan 2024 20:24:27 +0800 Subject: [PATCH 5/6] config link --- TinyCLIP/README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/TinyCLIP/README.md b/TinyCLIP/README.md index 797b89d8..5105bc5b 100644 --- a/TinyCLIP/README.md +++ b/TinyCLIP/README.md @@ -31,11 +31,11 @@ [TinyCLIP ResNet-19M Text-19M](./src/open_clip/model_configs/TinyCLIP-ResNet-19M-Text-19M.json) | manual | LAION-400M | 56.4 | 4.4 | 3,024| [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ResNet-19M-Text-19M-LAION400M.pt) [TinyCLIP ViT-61M/32 Text-29M](./src/open_clip/model_configs/TinyCLIP-ViT-61M-32-Text-29M.json) | manual | LAION-400M | 62.4 | 5.3 | 3,191|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-61M-32-Text-29M-LAION400M.pt) [TinyCLIP ViT-40M/32 Text-19M](./src/open_clip/model_configs/TinyCLIP-ViT-40M-32-Text-19M.json) | manual | LAION-400M | 59.8 | 3.5 | 4,641|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-40M-32-Text-19M-LAION400M.pt) -TinyCLIP ViT-63M/32 Text-31M | auto | LAION-400M | 63.9 | 5.6 | 2,905|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAION400M.pt) -TinyCLIP ViT-45M/32 Text-18M | auto | LAION-400M | 61.4 | 3.7 | 3,682|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAION400M.pt) -TinyCLIP ViT-22M/32 Text-10M | auto | LAION-400M | 53.7 | 1.9 | 5,504|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-22M-32-Text-10M-LAION400M.pt) -TinyCLIP ViT-63M/32 Text-31M | auto | LAION+YFCC-400M | 64.5 | 5.6| 2,909 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAIONYFCC400M.pt) -TinyCLIP ViT-45M/32 Text-18M | auto | LAION+YFCC-400M | 62.7 | 1.9 | 3,685 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAIONYFCC400M.pt) +[TinyCLIP ViT-63M/32 Text-31M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json) | auto | LAION-400M | 63.9 | 5.6 | 2,905|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAION400M.pt) +[TinyCLIP ViT-45M/32 Text-18M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json) | auto | LAION-400M | 61.4 | 3.7 | 3,682|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAION400M.pt) +[TinyCLIP ViT-22M/32 Text-10M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-22M-32-Text-10M.json) | auto | LAION-400M | 53.7 | 1.9 | 5,504|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-22M-32-Text-10M-LAION400M.pt) +[TinyCLIP ViT-63M/32 Text-31M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json) | auto | LAION+YFCC-400M | 64.5 | 5.6| 2,909 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAIONYFCC400M.pt) +[TinyCLIP ViT-45M/32 Text-18M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json) | auto | LAION+YFCC-400M | 62.7 | 1.9 | 3,685 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAIONYFCC400M.pt) Note: The configs of models with auto inheritance are generated automatically. From 2e57e19cbb58446e9b386e9cda6bfb258699f02a Mon Sep 17 00:00:00 2001 From: wkcn Date: Wed, 31 Jan 2024 20:35:04 +0800 Subject: [PATCH 6/6] incompatible_keys --- TinyCLIP/src/open_clip/factory.py | 3 +-- TinyCLIP/src/open_clip/model.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/TinyCLIP/src/open_clip/factory.py b/TinyCLIP/src/open_clip/factory.py index c67b0c87..4d1690eb 100644 --- a/TinyCLIP/src/open_clip/factory.py +++ b/TinyCLIP/src/open_clip/factory.py @@ -90,8 +90,7 @@ def load_checkpoint(model, checkpoint_path, strict=True): def load_pruned_checkpoint(model, checkpoint_path, strict=True): state_dict = load_state_dict(checkpoint_path) resize_pos_embed(state_dict, model) - load_pruned_model(model, state_dict) - incompatible_keys = dict() + incompatible_keys = load_pruned_model(model, state_dict, strict=strict) return incompatible_keys diff --git a/TinyCLIP/src/open_clip/model.py b/TinyCLIP/src/open_clip/model.py index f8dbd7bc..a61a19bb 100644 --- a/TinyCLIP/src/open_clip/model.py +++ b/TinyCLIP/src/open_clip/model.py @@ -1297,7 +1297,7 @@ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim= @torch.no_grad() -def load_pruned_model(model, pruned_state_dict): +def load_pruned_model(model, pruned_state_dict, strict=True): ''' A full model loads the pruned state dict. @@ -1409,7 +1409,7 @@ def _get_layer_id(name): model_state_dict[f'{ename}.l0_module.intermediate_loga'][d, :].fill_(-lambda_init_value) - model.load_state_dict(model_state_dict, strict=True) + return model.load_state_dict(model_state_dict, strict=strict) def prune_model(model):