fix vispruner bugs and update holitom_merge #431
@@ -1,9 +1,7 @@

```python
import functools
from functools import wraps
from types import MethodType

import torch
from loguru import logger

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
```
@@ -21,61 +19,21 @@ def __init__(self, config, model, blocks):

```python
    def add_sparse_config(self):
        self.pruning_loc = self.special_config['pruning_loc']
        self.special_config['image_token_length'] = self.model.pruning_config[
            'image_token_length'
        ]

        self.pruning_paras = self.special_config

    def register_reduction_modules(self):

        def input_hook_llava(fn, pruning_paras):
            @wraps(fn)
            def wrapper(self, *args, **kwargs):
                if len(args) == 0:
                    return fn(*args, **kwargs)
                input_args = args[0]
                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
                    return fn(*args, **kwargs)

                input_ids = args[0]
                attention_mask = args[2]
                token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
                pruning_paras['image_token_start_index'] = torch.where(token_indices)[
                    0
                ][0].item()

                outputs = fn(*args, **kwargs)
                return outputs

            return wrapper

        @prefill_wrapper
        def input_hook(module, input_args, pruning_paras):
            input_ids = input_args[0]
            image_token_idxs = (
                input_ids[0] == pruning_paras['vision_token_index']
            ).nonzero(as_tuple=True)[0]
            pruning_paras['image_token_start_index'] = image_token_idxs[0].item()

            return input_args

        @prefill_wrapper
        def random_pruning_hook(module, args, kwargs, pruning_paras):

            logger.info(' ========random_pruning_hook======== ')

            rate = pruning_paras['rate']
            image_token_start_index = pruning_paras['image_token_start_index']
            image_token_length = pruning_paras['image_token_length']
            rate = pruning_paras['prune_ratio']
            image_token_start_index = pruning_paras['vision_token_start_index']
            image_token_length = pruning_paras['vision_token_length']

            hidden_states = args[0]
            causal_mask = kwargs['attention_mask']

            logger.info(f'before hidden_states : {hidden_states.shape}')

            device = hidden_states.device

            vision_indexes = torch.arange(
                image_token_start_index,
                image_token_start_index + image_token_length,
```
@@ -130,25 +88,169 @@ def random_pruning_hook(module, args, kwargs, pruning_paras):

```python
            position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
            position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

            logger.info(f'after hidden_states : {hidden_states.shape}')
            return (hidden_states,), kwargs

        if self.model.__class__.__name__ == 'LlavaHf':
            self.model.embed_tokens.register_forward_pre_hook(
                functools.partial(input_hook, pruning_paras=self.pruning_paras)
        @prefill_wrapper
        def holitom_merge_hook(module, args, kwargs, pruning_paras):

            rate = pruning_paras['prune_ratio']
            image_token_start_index = pruning_paras['vision_token_start_index']
            image_token_length = pruning_paras['vision_token_length']

            hidden_states = args[0]
            causal_mask = kwargs['attention_mask']

            device = hidden_states.device
            last_layer_attention = pruning_paras['attn_scores']
            # compute average attention over different head
            last_layer_attention_avg = torch.mean(
                last_layer_attention, dim=1
            )[0]
            # generate new attention mask based on the average attention,
            # sample the top ATTENTION_RANK tokens with highest attention
            last_layer_attention_avg_last_tok = (
                last_layer_attention_avg[-1]
            )
            # get the attention in image token
            last_layer_attention_avg_last_tok_image = \
                last_layer_attention_avg_last_tok[
                    image_token_start_index:
                    image_token_start_index + image_token_length
                ]
            # get the indexes of the top ATTENTION_RANK tokens
            top_attention_rank_index = (
                last_layer_attention_avg_last_tok_image.topk(
                    round(
                        image_token_length * (1 - rate)
                    )
                ).indices
                + image_token_start_index
            )
        elif self.model.__class__.__name__ == 'Llava':
            from llava.constants import IMAGE_TOKEN_INDEX

            hook_fn = input_hook_llava(
                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
                self.pruning_paras,
            all_indices = torch.arange(
                image_token_length, device=device
            )
            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
                hook_fn, self.model.vlm_model
            non_topk_mask = ~torch.isin(
                all_indices,
                top_attention_rank_index
                - image_token_start_index,
            )
            non_topk_indices = (
                all_indices[non_topk_mask]
                + image_token_start_index
            )
            non_topk_states = hidden_states[
                :, non_topk_indices, :
            ]  # [batch_size, len(non_topk), hidden_size]
            topk_states = hidden_states[
                :, top_attention_rank_index, :
            ]  # [batch_size, len(topk), hidden_size]
            non_topk_norm = torch.norm(
                non_topk_states, dim=-1, keepdim=True
            )  # [batch_size, len(non_topk), 1]
            topk_norm = torch.norm(
                topk_states, dim=-1, keepdim=True
            )  # [batch_size, len(topk), 1]
            dot_product = torch.bmm(
                non_topk_states, topk_states.transpose(1, 2)
            )  # [batch_size, len(non_topk), len(topk)]
            sim_matrix = dot_product / (
                non_topk_norm * topk_norm.transpose(1, 2)
            )
            sim_max, sim_max_index = torch.max(sim_matrix, dim=-1)

            batch_size = hidden_states.size(0)
            num_topk = len(top_attention_rank_index)
            num_non_topk = len(non_topk_indices)
            topk_counter = torch.ones((batch_size, num_topk, 1), device=hidden_states.device)

            for b in range(batch_size):
                for i in range(num_non_topk):
                    topk_rel_idx = sim_max_index[b, i].item()  # relative index within topk
                    topk_abs_idx = top_attention_rank_index[topk_rel_idx]  # absolute index
                    non_topk_abs_idx = non_topk_indices[i]

                    # accumulate the non-topk token into the topk token (in place)
                    hidden_states[b, topk_abs_idx, :] += hidden_states[b, non_topk_abs_idx, :]
                    # increment the merge count
                    topk_counter[b, topk_rel_idx] += 1
```
Comment on lines +167 to +177:

The nested Python loops over the batch and the non-top-k tokens merge one token at a time, which will be slow for long vision sequences. Consider replacing the loop with vectorized PyTorch operations. You can use `scatter_add_`, for example, to accumulate all non-top-k states into their matched top-k tokens in a single call.
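As a self-contained illustration of the kind of call being suggested here (toy shapes only, none of the PR's variables), `scatter_add_` folds many source rows into shared destination rows in one pass, summing wherever several sources point at the same destination:

```python
import torch

# Toy example: batch of 1, sequence of 6 tokens, hidden size 4.
hidden_states = torch.arange(24, dtype=torch.float32).reshape(1, 6, 4)

# Pretend tokens 1 and 4 are kept and tokens 2, 3, 5 get merged: 2 -> 1, 3 -> 1, 5 -> 4.
dest = torch.tensor([[1, 1, 4]])                      # [batch, num_merged]
src = hidden_states[:, [2, 3, 5], :].clone()          # [batch, num_merged, hidden]

# One call accumulates every source row into its destination row.
hidden_states.scatter_add_(1, dest.unsqueeze(-1).expand_as(src), src)

# Track how many rows each kept token absorbed (1 for itself), then average.
counts = torch.ones(1, 6, 1)
counts.scatter_add_(1, dest.unsqueeze(-1), torch.ones_like(dest, dtype=counts.dtype).unsqueeze(-1))
hidden_states[:, [1, 4], :] /= counts[:, [1, 4], :]

print(hidden_states[0, 1])  # mean of original tokens 1, 2 and 3
print(hidden_states[0, 4])  # mean of original tokens 4 and 5
```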
```python
            # average all topk tokens (including the token itself and everything merged into it)
```
Comment on lines +169 to +178:

This code contains comments in Chinese. To improve code clarity and maintainability for a broader audience, please translate these comments into English.
```python
            for b in range(batch_size):
                for i in range(num_topk):
                    topk_abs_idx = top_attention_rank_index[i]
                    hidden_states[b, topk_abs_idx, :] /= topk_counter[b, i]
```
Comment on lines +167 to +182:

This part of the code has a couple of areas for improvement: the nested Python loops over the batch and non-top-k tokens will be slow for long vision sequences, and the inline comments are written in Chinese.

Here is a suggested vectorized implementation with English comments:

```python
# Vectorized token merging for performance.
# Determine which top-k token each non-top-k token will merge into.
topk_abs_indices_to_add_to = top_attention_rank_index[sim_max_index]

# Add non-top-k hidden states to their corresponding top-k hidden states.
source_states = hidden_states[:, non_topk_indices, :]
index_for_scatter = topk_abs_indices_to_add_to.unsqueeze(-1).expand_as(source_states)
hidden_states.scatter_add_(1, index_for_scatter, source_states)

# Update counters for averaging.
topk_counter.scatter_add_(
    1, sim_max_index.unsqueeze(-1),
    torch.ones_like(sim_max_index, dtype=topk_counter.dtype).unsqueeze(-1)
)

# Average the hidden states of the merged tokens.
hidden_states[:, top_attention_rank_index, :] /= topk_counter
```
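A quick way to gain confidence in a rewrite like the one suggested above is an equivalence check against the original loop on random tensors. The sketch below uses stand-in shapes and variable names, not the PR's code:

```python
import torch

torch.manual_seed(0)
B, N, H, K = 2, 32, 16, 8                                   # stand-in sizes
hidden = torch.randn(B, N, H)
topk_idx = torch.randperm(N)[:K]                            # kept token positions
non_topk_idx = torch.tensor([i for i in range(N) if i not in set(topk_idx.tolist())])
match = torch.randint(0, K, (B, non_topk_idx.numel()))      # kept token each one merges into

# Reference: explicit per-token loop, as in the hook.
ref, counter = hidden.clone(), torch.ones(B, K, 1)
for b in range(B):
    for i in range(non_topk_idx.numel()):
        k = match[b, i].item()
        ref[b, topk_idx[k], :] += ref[b, non_topk_idx[i], :]
        counter[b, k] += 1
for b in range(B):
    for k in range(K):
        ref[b, topk_idx[k], :] /= counter[b, k]

# Vectorized: scatter_add_ the merged rows, then average.
vec, vcount = hidden.clone(), torch.ones(B, K, 1)
dest = topk_idx[match]                                      # [B, num_non_topk]
src = vec[:, non_topk_idx, :]
vec.scatter_add_(1, dest.unsqueeze(-1).expand_as(src), src)
vcount.scatter_add_(1, match.unsqueeze(-1), torch.ones_like(match, dtype=vcount.dtype).unsqueeze(-1))
vec[:, topk_idx, :] /= vcount

print(torch.allclose(ref, vec, atol=1e-5))                  # expect True
```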
```python
            keep_indexs = torch.cat(
                (
                    torch.arange(
                        image_token_start_index,
                        device=device,
                    ),
                    top_attention_rank_index,
                    torch.arange(
                        image_token_start_index
                        + image_token_length,
                        hidden_states.shape[1],
                        device=device,
                    ),
                )
            )

        self.blocks[self.pruning_loc].register_forward_pre_hook(
            functools.partial(random_pruning_hook, pruning_paras=self.pruning_paras),
            with_kwargs=True,
        )
            # sort index
            keep_indexs = keep_indexs.sort().values
            # filter hidden states
            hidden_states = hidden_states[:, keep_indexs, :]
            # update position ids
            position_ids = keep_indexs.unsqueeze(0)
            # update attention mask
            if causal_mask is not None:
                causal_mask = causal_mask[:, :, :hidden_states.shape[1], :hidden_states.shape[1]]
                kwargs['attention_mask'].resize_as_(causal_mask).copy_(causal_mask.clone())
            kwargs['cache_position'].resize_as_(position_ids.squeeze(0)).copy_(
                position_ids.squeeze(0).clone())
            kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())

            position_embeddings = kwargs['position_embeddings']
            index_dim = 1 if position_embeddings[0].dim() == 3 else 2
            new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
            new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
            position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
            position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

            return (hidden_states,), kwargs

        def update_output_attentions_hook(module, args, kwargs):
            kwargs['output_attentions'] = True
            return args, kwargs

        def store_attention_hook(m, x, layer_outputs, pruning_paras):
            layer_attention = layer_outputs[1]
            pruning_paras['attn_scores'] = layer_attention

        if self.special_config['vision_token_length'] is None:
            if self.model.__class__.__name__ == 'Llava':
                self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
                    self.vtoken_length_for_llava_hook(
                        self.model.vlm_model.prepare_inputs_labels_for_multimodal,
                        self.pruning_paras
                    ), self.model.vlm_model
                )

        if self.special_config['metric'] == 'random':
            self.blocks[self.pruning_loc].register_forward_pre_hook(
                functools.partial(random_pruning_hook, pruning_paras=self.pruning_paras),
                with_kwargs=True
            )
        elif self.special_config['metric'] == 'holitom_merge':
            self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
                update_output_attentions_hook,
                with_kwargs=True
            )
            self.blocks[self.pruning_loc - 1].register_forward_hook(
                functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
            )
            self.blocks[self.pruning_loc].register_forward_pre_hook(
                functools.partial(holitom_merge_hook, pruning_paras=self.pruning_paras),
                with_kwargs=True
            )
```
@@ -3,8 +3,6 @@

```python
from typing import Callable, Optional, Tuple

import torch
import torch.nn.functional as F
from loguru import logger
from transformers.models.clip.modeling_clip import CLIPEncoderLayer

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
```

@@ -20,8 +18,7 @@ def __init__(self, config, model, blocks):

```python
        self.patch_layer()

    def add_sparse_config(self):
        special_config = self.config.get('special', {})
        r_param = special_config.get('r', 0)
        r_param = self.special_config.get('r', 0)
        if isinstance(r_param, int) or isinstance(r_param, float):
            self.r = [max(int(r_param), 0)] * len(self.blocks)
        elif isinstance(r_param, (tuple, list)):
```

@@ -36,19 +33,17 @@ def add_sparse_config(self):

```python
        else:
            raise ValueError('Invalid r format. Expected int or (start, step) tuple.')

        self.pruning_paras = special_config
        self.pruning_paras = self.special_config

    def patch_layer(self):
        for idx, block in enumerate(self.blocks):
            if self.r[idx] > 0:
                block.r = self.r[idx]
                if isinstance(block, CLIPEncoderLayer):  # llava
                    block.self_attn.original_forward = block.self_attn.forward
                    block.self_attn.forward = types.MethodType(
                        tome_CLIPSdpaAttention_forward,
                        block.self_attn
                    )
```
Comment on lines 43 to 46:

You've removed the lines that back up the original `forward` methods. It is generally a good practice to store the original method before overwriting it. If there's no strong reason to remove them, I'd recommend restoring these backups.
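For reference, a minimal sketch of the back-up-and-restore pattern this comment is asking to keep, with made-up function names and no claim about the PR's actual classes:

```python
import types

import torch
import torch.nn as nn


def patched_forward(self, hidden_states):
    # Modified behaviour would go here; fall back to the saved original.
    return self.original_forward(hidden_states)


def patch_block(block: nn.Module) -> None:
    # Keep a handle to the bound original method before overwriting it ...
    block.original_forward = block.forward
    block.forward = types.MethodType(patched_forward, block)


def unpatch_block(block: nn.Module) -> None:
    # ... so the patch can be undone later if needed.
    if hasattr(block, 'original_forward'):
        block.forward = block.original_forward


layer = nn.Linear(4, 4)
patch_block(layer)
out = layer(torch.randn(2, 4))  # runs through patched_forward
unpatch_block(layer)
```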
```python
                    block.original_forward = block.forward
                    block.forward = types.MethodType(
                        tome_CLIPEncoderLayer_forward,
                        block
```
@@ -585,19 +585,8 @@ def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):

```python
            st_idx = torch.nonzero(img_mask, as_tuple=True)[0]

            if st_idx.numel() > 0:
                discontinuities = torch.where(st_idx[1:] - st_idx[:-1] != 1)[0]
                if discontinuities.numel() > 0:
                    raise ValueError('Visual tokens are not contiguous in input_ids!')
                    segment_starts = [st_idx[0].item()] + [st_idx[i + 1].item() for i in discontinuities.tolist()]  # noqa
                    segment_ends = [st_idx[i].item() for i in discontinuities.tolist()] + [st_idx[-1].item()]  # noqa
                    offset = 0
                    for first, last in zip(segment_starts, segment_ends):
                        length = last - first + 1
                        # [15 1502] [1505 3289]
                        img_mask[first: last + 1] = ~select_mask[offset: offset + length]
                else:
                    first, last = st_idx[0].item(), st_idx[-1].item()
                    img_mask[first: last + 1] = ~select_mask
            first, last = st_idx[0].item(), st_idx[-1].item()
            img_mask[first: last + 1] = ~select_mask
```
Comment on lines +588 to +589:

The previous implementation checked for discontinuous visual tokens and raised an error. This check has been removed. The new logic writes `~select_mask` across the span from the first to the last visual token, which is only correct when those tokens are contiguous.

The original check was safer: if discontinuous tokens are not supported, it's better to fail loudly with a `ValueError` than to silently apply the mask to the wrong positions.

Please consider restoring the check for contiguity or implementing logic to handle multiple segments of visual tokens.
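A minimal sketch of what a restored guard could look like, reusing the tensor names visible in the diff (`img_mask`, `st_idx`, `select_mask`) but otherwise only illustrative:

```python
import torch


def apply_select_mask_contiguous(img_mask: torch.Tensor, select_mask: torch.Tensor) -> torch.Tensor:
    # Positions of the visual tokens inside the sequence.
    st_idx = torch.nonzero(img_mask, as_tuple=True)[0]
    if st_idx.numel() == 0:
        return img_mask
    # Fail loudly instead of silently writing the mask to the wrong span.
    if ((st_idx[1:] - st_idx[:-1]) != 1).any():
        raise ValueError('Visual tokens are not contiguous in input_ids!')
    first, last = st_idx[0].item(), st_idx[-1].item()
    img_mask[first: last + 1] = ~select_mask
    return img_mask
```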
```python
            img_mask = ~img_mask
```
Comment on lines +588 to 590:

The logic to detect and handle discontinuous visual tokens has been removed. The previous code explicitly checked for this and would raise a `ValueError`. If it's possible for visual tokens to be non-contiguous, this change could lead to incorrect behavior and silent errors by applying the `select_mask` to the wrong positions. Could you please confirm that visual tokens are guaranteed to be contiguous for this model?
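If non-contiguous visual tokens do need to be supported, a segment-wise version along the lines of the removed code is one option. Again a sketch with the diff's names, making no claim about the surrounding hook; note the per-segment `offset` bookkeeping:

```python
import torch


def apply_select_mask_segments(img_mask: torch.Tensor, select_mask: torch.Tensor) -> torch.Tensor:
    st_idx = torch.nonzero(img_mask, as_tuple=True)[0]
    if st_idx.numel() == 0:
        return img_mask
    # Split the visual-token positions wherever the index jumps by more than 1.
    breaks = torch.where(st_idx[1:] - st_idx[:-1] != 1)[0]
    starts = [st_idx[0].item()] + [st_idx[i + 1].item() for i in breaks.tolist()]
    ends = [st_idx[i].item() for i in breaks.tolist()] + [st_idx[-1].item()]
    offset = 0
    for first, last in zip(starts, ends):
        length = last - first + 1
        # Consume the matching slice of select_mask for this segment.
        img_mask[first: last + 1] = ~select_mask[offset: offset + length]
        offset += length
    return img_mask
```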
```python
            contextual_input_idx = false_pos[target_indices] + first
```
The backslash `\` for line continuation here is unnecessary. Python implicitly continues lines inside parentheses, brackets, and braces, which is the case here with the square brackets for slicing. Removing the backslash and adjusting the indentation will make the code cleaner and more compliant with PEP 8.
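As a concrete illustration of that suggestion, the assignment can drop the backslash because the slice brackets already keep the statement open. The snippet below is purely a style rewrite, with dummy values standing in for the real tensors:

```python
import torch

# Dummy stand-ins so the snippet runs on its own.
last_layer_attention_avg_last_tok = torch.arange(12)
image_token_start_index, image_token_length = 3, 5

# Before: explicit continuation with a backslash.
last_layer_attention_avg_last_tok_image = \
    last_layer_attention_avg_last_tok[
        image_token_start_index:
        image_token_start_index + image_token_length
    ]

# After: the open bracket already continues the statement, so the backslash can go.
last_layer_attention_avg_last_tok_image = last_layer_attention_avg_last_tok[
    image_token_start_index:
    image_token_start_index + image_token_length
]
```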