diff --git a/llmc/compression/token_reduction/dycoke.py b/llmc/compression/token_reduction/dycoke.py index 925d12d1..b4691658 100644 --- a/llmc/compression/token_reduction/dycoke.py +++ b/llmc/compression/token_reduction/dycoke.py @@ -7,8 +7,8 @@ try: from llava.model.llava_arch import LlavaMetaForCausalLM -except ModuleNotFoundError: - logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.') +except ImportError: + pass from transformers.cache_utils import Cache, DynamicCache from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY diff --git a/llmc/compression/token_reduction/fastvid.py b/llmc/compression/token_reduction/fastvid.py index 86c3d11a..96e60fc9 100644 --- a/llmc/compression/token_reduction/fastvid.py +++ b/llmc/compression/token_reduction/fastvid.py @@ -13,8 +13,8 @@ from llava.model.multimodal_encoder.siglip_encoder import ( SigLipVisionConfig, SigLipVisionModel) from llava.utils import rank0_print -except ModuleNotFoundError: - logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.') +except ImportError: + pass from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY diff --git a/llmc/compression/token_reduction/holitom.py b/llmc/compression/token_reduction/holitom.py index 5b48691f..27f601c5 100644 --- a/llmc/compression/token_reduction/holitom.py +++ b/llmc/compression/token_reduction/holitom.py @@ -20,8 +20,8 @@ from llava.utils import rank0_print from transformers.modeling_outputs import (BaseModelOutput, BaseModelOutputWithPooling) -except ModuleNotFoundError: - logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.') +except ImportError: + pass from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY diff --git a/llmc/compression/token_reduction/prunevid.py b/llmc/compression/token_reduction/prunevid.py index ac5e742b..290822dc 100644 --- a/llmc/compression/token_reduction/prunevid.py +++ b/llmc/compression/token_reduction/prunevid.py @@ -8,8 +8,8 @@ try: from llava.model.llava_arch import LlavaMetaForCausalLM -except ModuleNotFoundError: - logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.') +except ImportError: + pass from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY diff --git a/llmc/compression/token_reduction/random.py b/llmc/compression/token_reduction/random.py index d6dfde1d..e889df78 100644 --- a/llmc/compression/token_reduction/random.py +++ b/llmc/compression/token_reduction/random.py @@ -1,9 +1,7 @@ import functools -from functools import wraps from types import MethodType import torch -from loguru import logger from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY @@ -21,61 +19,21 @@ def __init__(self, config, model, blocks): def add_sparse_config(self): self.pruning_loc = self.special_config['pruning_loc'] - self.special_config['image_token_length'] = self.model.pruning_config[ - 'image_token_length' - ] - self.pruning_paras = self.special_config def register_reduction_modules(self): - def input_hook_llava(fn, pruning_paras): - @wraps(fn) - def wrapper(self, *args, **kwargs): - if len(args) == 0: - return fn(*args, **kwargs) - input_args = args[0] - if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1: - return fn(*args, **kwargs) - - input_ids = args[0] - attention_mask = args[2] - token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX - pruning_paras['image_token_start_index'] = torch.where(token_indices)[ - 0 - ][0].item() - - outputs = fn(*args, **kwargs) - return 
outputs - - return wrapper - - @prefill_wrapper - def input_hook(module, input_args, pruning_paras): - input_ids = input_args[0] - image_token_idxs = ( - input_ids[0] == pruning_paras['vision_token_index'] - ).nonzero(as_tuple=True)[0] - pruning_paras['image_token_start_index'] = image_token_idxs[0].item() - - return input_args - @prefill_wrapper def random_pruning_hook(module, args, kwargs, pruning_paras): - logger.info(' ========random_pruning_hook======== ') - - rate = pruning_paras['rate'] - image_token_start_index = pruning_paras['image_token_start_index'] - image_token_length = pruning_paras['image_token_length'] + rate = pruning_paras['prune_ratio'] + image_token_start_index = pruning_paras['vision_token_start_index'] + image_token_length = pruning_paras['vision_token_length'] hidden_states = args[0] causal_mask = kwargs['attention_mask'] - logger.info(f'before hidden_states : {hidden_states.shape}') - device = hidden_states.device - vision_indexes = torch.arange( image_token_start_index, image_token_start_index + image_token_length, @@ -130,25 +88,169 @@ def random_pruning_hook(module, args, kwargs, pruning_paras): position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0) position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1) - logger.info(f'after hidden_states : {hidden_states.shape}') return (hidden_states,), kwargs - if self.model.__class__.__name__ == 'LlavaHf': - self.model.embed_tokens.register_forward_pre_hook( - functools.partial(input_hook, pruning_paras=self.pruning_paras) + @prefill_wrapper + def holitom_merge_hook(module, args, kwargs, pruning_paras): + + rate = pruning_paras['prune_ratio'] + image_token_start_index = pruning_paras['vision_token_start_index'] + image_token_length = pruning_paras['vision_token_length'] + + hidden_states = args[0] + causal_mask = kwargs['attention_mask'] + + device = hidden_states.device + last_layer_attention = pruning_paras['attn_scores'] + # compute average attention over different head + last_layer_attention_avg = torch.mean( + last_layer_attention, dim=1 + )[0] + # generate new attention mask based on the average attention, + # sample the top ATTENTION_RANK tokens with highest attention + last_layer_attention_avg_last_tok = ( + last_layer_attention_avg[-1] + ) + # get the attention in image token + last_layer_attention_avg_last_tok_image = \ + last_layer_attention_avg_last_tok[ + image_token_start_index: + image_token_start_index + image_token_length + ] + # get the indexes of the top ATTENTION_RANK tokens + top_attention_rank_index = ( + last_layer_attention_avg_last_tok_image.topk( + round( + image_token_length * (1 - rate) + ) + ).indices + + image_token_start_index ) - elif self.model.__class__.__name__ == 'Llava': - from llava.constants import IMAGE_TOKEN_INDEX - hook_fn = input_hook_llava( - self.model.vlm_model.prepare_inputs_labels_for_multimodal, - self.pruning_paras, + all_indices = torch.arange( + image_token_length, device=device ) - self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType( - hook_fn, self.model.vlm_model + non_topk_mask = ~torch.isin( + all_indices, + top_attention_rank_index + - image_token_start_index, + ) + non_topk_indices = ( + all_indices[non_topk_mask] + + image_token_start_index + ) + non_topk_states = hidden_states[ + :, non_topk_indices, : + ] # [batch_size, len(non_topk), hidden_size] + topk_states = hidden_states[ + :, top_attention_rank_index, : + ] # [batch_size, len(topk), hidden_size] + non_topk_norm = torch.norm( + non_topk_states, dim=-1, keepdim=True + ) # [batch_size, 
len(non_topk), 1]
+            topk_norm = torch.norm(
+                topk_states, dim=-1, keepdim=True
+            )  # [batch_size, len(topk), 1]
+            dot_product = torch.bmm(
+                non_topk_states, topk_states.transpose(1, 2)
+            )  # [batch_size, len(non_topk), len(topk)]
+            sim_matrix = dot_product / (
+                non_topk_norm * topk_norm.transpose(1, 2)
+            )
+            sim_max, sim_max_index = torch.max(sim_matrix, dim=-1)
+
+            batch_size = hidden_states.size(0)
+            num_topk = len(top_attention_rank_index)
+            num_non_topk = len(non_topk_indices)
+            topk_counter = torch.ones((batch_size, num_topk, 1), device=hidden_states.device)
+
+            for b in range(batch_size):
+                for i in range(num_non_topk):
+                    topk_rel_idx = sim_max_index[b, i].item()  # relative index within the topk set
+                    topk_abs_idx = top_attention_rank_index[topk_rel_idx]  # absolute index
+                    non_topk_abs_idx = non_topk_indices[i]
+
+                    # accumulate the non-topk token onto its topk token (in place)
+                    hidden_states[b, topk_abs_idx, :] += hidden_states[b, non_topk_abs_idx, :]
+                    # increment the merge counter
+                    topk_counter[b, topk_rel_idx] += 1
+
+            # average every topk token (itself plus all tokens merged into it)
+            for b in range(batch_size):
+                for i in range(num_topk):
+                    topk_abs_idx = top_attention_rank_index[i]
+                    hidden_states[b, topk_abs_idx, :] /= topk_counter[b, i]
+
+            keep_indexs = torch.cat(
+                (
+                    torch.arange(
+                        image_token_start_index,
+                        device=device,
+                    ),
+                    top_attention_rank_index,
+                    torch.arange(
+                        image_token_start_index
+                        + image_token_length,
+                        hidden_states.shape[1],
+                        device=device,
+                    ),
+                )
             )
+            # sort index
+            keep_indexs = keep_indexs.sort().values
+            # filter hidden states &
+            hidden_states = hidden_states[:, keep_indexs, :]
+            # update position ids
+            position_ids = keep_indexs.unsqueeze(0)
+            # update attention mask
+            if causal_mask is not None:
+                causal_mask = causal_mask[:, :, :hidden_states.shape[1], :hidden_states.shape[1]]
+            kwargs['attention_mask'].resize_as_(causal_mask).copy_(causal_mask.clone())
+            kwargs['cache_position'].resize_as_(position_ids.squeeze(0)).copy_(
+                position_ids.squeeze(0).clone())
+            kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())
+
+            position_embeddings = kwargs['position_embeddings']
+            index_dim = 1 if position_embeddings[0].dim() == 3 else 2
+            new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
+            new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
+            position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
+            position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)
+
+            return (hidden_states,), kwargs
+
+        def update_output_attentions_hook(module, args, kwargs):
+            kwargs['output_attentions'] = True
+            return args, kwargs
+
+        def store_attention_hook(m, x, layer_outputs, pruning_paras):
+            layer_attention = layer_outputs[1]
+            pruning_paras['attn_scores'] = layer_attention
+
+        if self.special_config['vision_token_length'] is None:
+            if self.model.__class__.__name__ == 'Llava':
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                    self.vtoken_length_for_llava_hook(
+                        self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                        self.pruning_paras
+                    ), self.model.vlm_model
+                )
+
-        self.blocks[self.pruning_loc].register_forward_pre_hook(
-            functools.partial(random_pruning_hook, pruning_paras=self.pruning_paras),
-            with_kwargs=True,
-        )
+        if self.special_config['metric'] == 'random':
+            self.blocks[self.pruning_loc].register_forward_pre_hook(
+                functools.partial(random_pruning_hook, pruning_paras=self.pruning_paras),
+                with_kwargs=True
+            )
+        elif self.special_config['metric'] == 'holitom_merge':
+            self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
+                
update_output_attentions_hook, + with_kwargs=True + ) + self.blocks[self.pruning_loc - 1].register_forward_hook( + functools.partial(store_attention_hook, pruning_paras=self.pruning_paras), + ) + self.blocks[self.pruning_loc].register_forward_pre_hook( + functools.partial(holitom_merge_hook, pruning_paras=self.pruning_paras), + with_kwargs=True + ) diff --git a/llmc/compression/token_reduction/tome.py b/llmc/compression/token_reduction/tome.py index 55c9e051..1bd660e0 100644 --- a/llmc/compression/token_reduction/tome.py +++ b/llmc/compression/token_reduction/tome.py @@ -3,8 +3,6 @@ from typing import Callable, Optional, Tuple import torch -import torch.nn.functional as F -from loguru import logger from transformers.models.clip.modeling_clip import CLIPEncoderLayer from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY @@ -20,8 +18,7 @@ def __init__(self, config, model, blocks): self.patch_layer() def add_sparse_config(self): - special_config = self.config.get('special', {}) - r_param = special_config.get('r', 0) + r_param = self.special_config.get('r', 0) if isinstance(r_param, int) or isinstance(r_param, float): self.r = [max(int(r_param), 0)] * len(self.blocks) elif isinstance(r_param, (tuple, list)): @@ -36,19 +33,17 @@ def add_sparse_config(self): else: raise ValueError('Invalid r format. Expected int or (start, step) tuple.') - self.pruning_paras = special_config + self.pruning_paras = self.special_config def patch_layer(self): for idx, block in enumerate(self.blocks): if self.r[idx] > 0: block.r = self.r[idx] if isinstance(block, CLIPEncoderLayer): # llava - block.self_attn.original_forward = block.self_attn.forward block.self_attn.forward = types.MethodType( tome_CLIPSdpaAttention_forward, block.self_attn ) - block.original_forward = block.forward block.forward = types.MethodType( tome_CLIPEncoderLayer_forward, block diff --git a/llmc/compression/token_reduction/utils.py b/llmc/compression/token_reduction/utils.py index 7a69d3eb..100dd567 100755 --- a/llmc/compression/token_reduction/utils.py +++ b/llmc/compression/token_reduction/utils.py @@ -9,8 +9,8 @@ try: from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX -except Exception as e: - logger.debug('LLaVA is not installed. 
Please install LLaVA to use this model.\nError: %s' % e) +except ImportError: + pass import random diff --git a/llmc/compression/token_reduction/visionzip.py b/llmc/compression/token_reduction/visionzip.py index ca23dbfd..97988e9e 100755 --- a/llmc/compression/token_reduction/visionzip.py +++ b/llmc/compression/token_reduction/visionzip.py @@ -585,19 +585,8 @@ def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras): st_idx = torch.nonzero(img_mask, as_tuple=True)[0] if st_idx.numel() > 0: - discontinuities = torch.where(st_idx[1:] - st_idx[:-1] != 1)[0] - if discontinuities.numel() > 0: - raise ValueError('Visual tokens are not contiguous in input_ids!') - segment_starts = [st_idx[0].item()] + [st_idx[i + 1].item() for i in discontinuities.tolist()] # noqa - segment_ends = [st_idx[i].item() for i in discontinuities.tolist()] + [st_idx[-1].item()] # noqa - offset = 0 - for first, last in zip(segment_starts, segment_ends): - length = last - first + 1 - # [15 1502] [1505 3289] - img_mask[first: last + 1] = ~select_mask[offset: offset + length] - else: - first, last = st_idx[0].item(), st_idx[-1].item() - img_mask[first: last + 1] = ~select_mask + first, last = st_idx[0].item(), st_idx[-1].item() + img_mask[first: last + 1] = ~select_mask img_mask = ~img_mask contextual_input_idx = false_pos[target_indices] + first diff --git a/llmc/compression/token_reduction/vispruner.py b/llmc/compression/token_reduction/vispruner.py index ddcabfa6..afe63fe1 100644 --- a/llmc/compression/token_reduction/vispruner.py +++ b/llmc/compression/token_reduction/vispruner.py @@ -80,12 +80,7 @@ def get_index_masks_hook(module, args, pruning_paras): B, N, C = image_features.shape device = image_features.device index_masks = torch.ones(B, N, dtype=torch.bool, device=device) - - visual_token_num = round( - self.special_config['vision_token_length'] * ( - 1 - self.special_config['prune_ratio'] - ) - ) # T + visual_token_num = round(N * (1 - self.special_config['prune_ratio'])) # T important_ratio = self.pruning_paras['important_ratio'] # r important_token_num = int(visual_token_num * important_ratio) # T_imp = T * r diverse_token_num = visual_token_num - important_token_num # T_div = T * (1 - r) @@ -143,7 +138,7 @@ def prune_hook(module, inputs, outputs, pruning_paras, model_config): index_masks = torch.split(index_masks, split_sizes, dim=0) # 'spatial_unpad', 'anyres' mm_patch_merge_type = getattr(model_config, 'mm_patch_merge_type', 'flat') - mm_patch_merge_type = mm_patch_merge_type.replace('_unpad', '') + # mm_patch_merge_type = mm_patch_merge_type.replace('_unpad', '') image_aspect_ratio = getattr(model_config, 'image_aspect_ratio', 'square') if mm_patch_merge_type == 'flat': @@ -200,7 +195,6 @@ def prune_hook(module, inputs, outputs, pruning_paras, model_config): ).to(image_feature.device) ), dim=-1) image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) index_mask = index_mask.permute(0, 2, 1, 3).contiguous().unsqueeze(0) index_mask = index_mask.flatten(1, 2).flatten(2, 3) index_mask = unpad_image(index_mask, image_sizes[image_idx]) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 66c36792..5869fa8d 100755 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -4,7 +4,7 @@ timm pillow loguru transformers>=4.45.2 -lmms-eval +lmms-eval==0.3.0 huggingface-hub sentencepiece protobuf