diff --git a/configs/sparsification/methods/VisionZip/visionzip.yml b/configs/sparsification/methods/VisionZip/visionzip.yml
index 76d1fec8..19bf9d9e 100644
--- a/configs/sparsification/methods/VisionZip/visionzip.yml
+++ b/configs/sparsification/methods/VisionZip/visionzip.yml
@@ -17,8 +17,10 @@ sparse:
     method: TokenReduction
     special:
         method: VisionZip # retain
-        dominant: 191 # visual_tokens = dominan_tokens + 1(cls_token)
-        contextual: 30
+        dominant: 162 # visual_tokens = dominant_tokens + contextual
+        contextual: 30 # llava: 162+30, 108+20, 54+10; llava_next: 108+20, 54+10, 27+5
+        prune_only: False
+        merge_only: False
 save:
     save_trans: False
     save_fake: False
diff --git a/llmc/compression/token_reduction/utils.py b/llmc/compression/token_reduction/utils.py
index d3451ffb..7a69d3eb 100755
--- a/llmc/compression/token_reduction/utils.py
+++ b/llmc/compression/token_reduction/utils.py
@@ -296,7 +296,7 @@ def prepare_inputs_labels_for_multimodal_with_index_masks(
                     if 'maxpool2x2' in mm_patch_merge_type:
                         raise NotImplementedError
                     elif 'unpad' in mm_patch_merge_type and 'anyres_max' in image_aspect_ratio:
-                        NotImplementedError
+                        raise NotImplementedError
                     elif 'unpad' in mm_patch_merge_type:
                         image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                         image_feature = image_feature.flatten(1, 2).flatten(2, 3)
@@ -446,7 +446,6 @@ def prepare_inputs_labels_for_multimodal_with_index_masks(
 
         cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
 
-        # import pdb; pdb.set_trace()
         cur_new_input_embeds = torch.cat(cur_new_input_embeds)
         cur_new_labels = torch.cat(cur_new_labels)
 
@@ -554,7 +553,6 @@ def prepare_inputs_labels_for_multimodal_with_index_masks(
         right_add = random.randint(left_add, self.config.pos_skipping_range)
         position_ids[:, :split_position] += left_add
         position_ids[:, split_position:] += right_add
-    # import pdb; pdb.set_trace()
    # rank0_print("Finish preparing")
    # print(vtoken_length)
    return None, position_ids, attention_mask, past_key_values, \
diff --git a/llmc/compression/token_reduction/visionzip.py b/llmc/compression/token_reduction/visionzip.py
index 3bc3eb06..ca23dbfd 100755
--- a/llmc/compression/token_reduction/visionzip.py
+++ b/llmc/compression/token_reduction/visionzip.py
@@ -12,7 +12,8 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
 
 from .token_reduction_module import TokenReductionModule
-from .utils import apply_info, prefill_wrapper
+from .utils import (apply_info, prefill_wrapper,
+                    prepare_inputs_labels_for_multimodal_with_index_masks)
 
 
 def visionzip_forward(
@@ -286,15 +287,19 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()
 
     def add_sparse_config(self):
-        special_config = self.config.get('special', {})
-        self.dominant = special_config['dominant']
-        self.contextual = special_config['contextual']
+        self.dominant = self.special_config['dominant']
+        self.contextual = self.special_config['contextual']
 
-        self.pruning_paras = special_config
+        self.pruning_paras = self.special_config
+        prune_only = self.special_config.get('prune_only', False)
+        merge_only = self.special_config.get('merge_only', False)
+        assert not (prune_only and merge_only), 'prune_only and merge_only cannot both be True'
+        self.pruning_paras['prune_only'] = prune_only
+        self.pruning_paras['merge_only'] = merge_only
 
     def register_reduction_modules(self):
 
-        def visionzip_hook(m, images, image_forward_outs):
+        def visionzip_hook(m, images, image_forward_outs, pruning_paras, llava_next):
             attn_weights = image_forward_outs.attentions[-2]
             hidden_states = image_forward_outs.hidden_states[-2]
             metric = self.blocks[-2].self_attn.k_proj.metric
@@ -306,17 +311,22 @@ def visionzip_hook(m, images, image_forward_outs):
             cls_attention = attn_weights[:, :, cls_idx, cls_idx + 1:]
             cls_attention_sum = cls_attention.sum(dim=1)
             topk_indices = cls_attention_sum.topk(dominant_num, dim=1).indices + 1
-            all_indices = torch.cat(
-                [
-                    torch.zeros(
-                        (hidden_states.shape[0], 1),
-                        dtype=topk_indices.dtype,
-                        device=topk_indices.device,
-                    ),
-                    topk_indices,
-                ],
-                dim=1,
-            )
+            if pruning_paras['merge_only']:
+                all_indices = torch.zeros(
+                    (hidden_states.shape[0], 1),
+                    dtype=topk_indices.dtype, device=topk_indices.device
+                )
+                dominant_num = 0
+            else:
+                all_indices = torch.cat(
+                    [
+                        torch.zeros(
+                            (hidden_states.shape[0], 1),
+                            dtype=topk_indices.dtype, device=topk_indices.device,
+                        ),
+                        topk_indices,
+                    ], dim=1,
+                )
 
             mask = torch.ones_like(
                 hidden_states[:, :, 0], dtype=torch.bool, device=metric.device
@@ -355,6 +365,15 @@ def visionzip_hook(m, images, image_forward_outs):
             target_indices = torch.arange(
                 0, metric_normalized.shape[1], step, device=metric_normalized.device
             )[:contextual_num]
+
+            # keep_idxs
+            index_masks = ~mask
+            if not pruning_paras['prune_only']:
+                pruned_indices = mask.nonzero(as_tuple=False)[:, 1].view(hidden_states.shape[0], -1)
+                target_index = pruned_indices[:, target_indices]
+                index_masks.scatter_(1, target_index, True)
+            pruning_paras['index_masks'] = index_masks[:, 1:]
+
             target_tokens = metric_normalized[:, target_indices, :]
 
             tokens_to_merge = metric_normalized[
@@ -401,9 +420,15 @@ def visionzip_hook(m, images, image_forward_outs):
             ).to(images[0].dtype)
 
             res = list(image_forward_outs.hidden_states)
-            res[-2] = hidden_states_save.contiguous()
+            if not llava_next:
+                if pruning_paras['prune_only']:
+                    res[-2] = dominant_tokens.contiguous().to(images[0].dtype)
+                else:
+                    res[-2] = hidden_states_save.contiguous()
             image_forward_outs.hidden_states = tuple(res)
 
+            return image_forward_outs
+
         def store_key_hook(m, x, outputs):
             bsz = x[0].shape[0]
             raw_outputs = (
@@ -418,10 +443,13 @@ def update_output_attentions_hook(module, args, kwargs):
             kwargs['output_attentions'] = True
             return args, kwargs
 
+        def update_index_masks_hook(module, inps, outs, pruning_paras):
+            module.index_masks = pruning_paras['index_masks']
+
         if self.model.__class__.__name__ == 'LlavaHf':
             vision_tower = self.model.vlm_model.vision_tower
         elif self.model.__class__.__name__ == 'Llava':
-            vision_tower = self.model.vlm_model.model.vision_tower.vision_tower
+            vision_tower = self.model.vision_model.vision_tower
 
         if self.model.__class__.__name__ in ('LlavaHf', 'Llava'):
             apply_info(
@@ -444,7 +472,25 @@ def update_output_attentions_hook(module, args, kwargs):
                 block.self_attn.k_proj.head_dim = block.self_attn.head_dim
                 block.self_attn.k_proj.register_forward_hook(store_key_hook)
 
-        vision_tower.register_forward_hook(visionzip_hook)
+        vision_tower.register_forward_hook(
+            functools.partial(
+                visionzip_hook,
+                pruning_paras=self.pruning_paras,
+                llava_next=self.special_config['vision_token_length'] is None
+            )
+        )
+
+        # llava_next
+        if self.special_config['vision_token_length'] is None:
+
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                prepare_inputs_labels_for_multimodal_with_index_masks,
+                self.model.vlm_model
+            )
+
+            self.model.vision_model.register_forward_hook(
+                functools.partial(update_index_masks_hook, pruning_paras=self.pruning_paras),
+            )
 
         def get_metric(fn, pruning_paras):
             @wraps(fn)
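Note (not part of the patch): a minimal, self-contained sketch of the hook-registration pattern the patch relies on, in which functools.partial binds extra keyword arguments (here a pruning_paras dict and a llava_next flag, mirroring the names used above) to a PyTorch forward hook so that one hook can publish state, such as index_masks, for a later consumer to read. The example_hook function and the nn.Linear module below are illustrative assumptions for the sketch, not llmc code.

import functools

import torch
import torch.nn as nn


def example_hook(module, inputs, output, pruning_paras, llava_next):
    # pruning_paras and llava_next are bound via functools.partial below,
    # mirroring how visionzip_hook receives them in the patch.
    if llava_next:
        # Publish a boolean mask into the shared dict, analogous to
        # pruning_paras['index_masks'] in visionzip_hook.
        pruning_paras['index_masks'] = torch.ones(output.shape[:-1], dtype=torch.bool)
    return output


pruning_paras = {'prune_only': False, 'merge_only': False}
layer = nn.Linear(8, 4)
layer.register_forward_hook(
    functools.partial(example_hook, pruning_paras=pruning_paras, llava_next=True)
)

_ = layer(torch.randn(2, 8))
print(pruning_paras['index_masks'].shape)  # torch.Size([2])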