6 changes: 4 additions & 2 deletions configs/sparsification/methods/VisionZip/visionzip.yml
@@ -17,8 +17,10 @@ sparse:
method: TokenReduction
special:
method: VisionZip # retain
dominant: 191 # visual_tokens = dominan_tokens + 1(cls_token)
contextual: 30
dominant: 162 # visual_tokens = dominan_tokens + contextual
contextual: 30 # llava: 162+30,108+20,54+10 llava_next: 108+20,54+10,27+5
prune_only: False
merge_only: False
save:
save_trans: False
save_fake: False
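
For reference, each dominant/contextual pair listed in the comment sums to a fixed retained visual-token budget, matching the old accounting of 191 dominant tokens + 1 cls token = 192. A quick arithmetic check (a standalone sketch, not part of the config):

# Each (dominant, contextual) pair keeps the same total of retained visual tokens.
llava_pairs = [(162, 30), (108, 20), (54, 10)]        # totals: 192, 128, 64
llava_next_pairs = [(108, 20), (54, 10), (27, 5)]     # totals: 128, 64, 32
for dominant, contextual in llava_pairs + llava_next_pairs:
    print(dominant + contextual)
# The previous config reached the same 192 via 191 dominant tokens + 1 cls token.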
4 changes: 1 addition & 3 deletions llmc/compression/token_reduction/utils.py
@@ -296,7 +296,7 @@ def prepare_inputs_labels_for_multimodal_with_index_masks(
if 'maxpool2x2' in mm_patch_merge_type:
raise NotImplementedError
elif 'unpad' in mm_patch_merge_type and 'anyres_max' in image_aspect_ratio:
NotImplementedError
raise NotImplementedError

critical

This is a critical bug fix. The statement NotImplementedError by itself is a no-op and does not raise an exception. By adding raise, you ensure that an exception is thrown, preventing the program from continuing with an unsupported configuration, which could lead to silent failures or incorrect behavior.
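
A minimal standalone sketch (not from this repo) showing why the bare statement is harmless while raise stops execution:

# Bare exception class as a statement: the name is evaluated and discarded, nothing is raised.
def unsupported_silent():
    NotImplementedError  # no-op
    return 'keeps running'

# With raise, control leaves the function and callers must handle the error.
def unsupported_raises():
    raise NotImplementedError('unpad + anyres_max is not supported')

print(unsupported_silent())          # prints 'keeps running'
try:
    unsupported_raises()
except NotImplementedError as err:
    print('raised:', err)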

elif 'unpad' in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
@@ -446,7 +446,6 @@ def prepare_inputs_labels_for_multimodal_with_index_masks(

cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

# import pdb; pdb.set_trace()
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)

@@ -554,7 +553,6 @@ def prepare_inputs_labels_for_multimodal_with_index_masks(
right_add = random.randint(left_add, self.config.pos_skipping_range)
position_ids[:, :split_position] += left_add
position_ids[:, split_position:] += right_add
# import pdb; pdb.set_trace()
# rank0_print("Finish preparing")
# print(vtoken_length)
return None, position_ids, attention_mask, past_key_values, \
86 changes: 66 additions & 20 deletions llmc/compression/token_reduction/visionzip.py
@@ -12,7 +12,8 @@
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule
from .utils import apply_info, prefill_wrapper
from .utils import (apply_info, prefill_wrapper,
prepare_inputs_labels_for_multimodal_with_index_masks)


def visionzip_forward(
@@ -286,15 +287,19 @@ def __init__(self, config, model, blocks):
self.register_reduction_modules()

def add_sparse_config(self):
special_config = self.config.get('special', {})
self.dominant = special_config['dominant']
self.contextual = special_config['contextual']
self.dominant = self.special_config['dominant']
self.contextual = self.special_config['contextual']

self.pruning_paras = special_config
self.pruning_paras = self.special_config
prune_only = self.special_config.get('prune_only', False)
merge_only = self.special_config.get('merge_only', False)
assert not (prune_only and merge_only), 'prune_only and merge_only cannot both be True'
self.pruning_paras['prune_only'] = prune_only
self.pruning_paras['merge_only'] = merge_only

def register_reduction_modules(self):

def visionzip_hook(m, images, image_forward_outs):
def visionzip_hook(m, images, image_forward_outs, pruning_paras, llava_next):
attn_weights = image_forward_outs.attentions[-2]
hidden_states = image_forward_outs.hidden_states[-2]
metric = self.blocks[-2].self_attn.k_proj.metric
@@ -306,17 +311,22 @@ def visionzip_hook(m, images, image_forward_outs):
cls_attention = attn_weights[:, :, cls_idx, cls_idx + 1:]
cls_attention_sum = cls_attention.sum(dim=1)
topk_indices = cls_attention_sum.topk(dominant_num, dim=1).indices + 1
all_indices = torch.cat(
[
torch.zeros(
(hidden_states.shape[0], 1),
dtype=topk_indices.dtype,
device=topk_indices.device,
),
topk_indices,
],
dim=1,
)
if pruning_paras['merge_only']:
all_indices = torch.zeros(
(hidden_states.shape[0], 1),
dtype=topk_indices.dtype, device=topk_indices.device
)
dominant_num = 0
else:
all_indices = torch.cat(
[
torch.zeros(
(hidden_states.shape[0], 1),
dtype=topk_indices.dtype, device=topk_indices.device,
),
topk_indices,
], dim=1,
)
Comment on lines 313 to +329

medium

The topk operation can be computationally expensive. In the current implementation, it's called on line 313 regardless of whether pruning_paras['merge_only'] is true. When merge_only is true, the result of topk is not fully utilized (only its dtype and device attributes are used), leading to unnecessary computation.

To improve efficiency, consider moving the topk call into the else block, so it's only executed when its result is actually needed.

Suggested change
topk_indices = cls_attention_sum.topk(dominant_num, dim=1).indices + 1
all_indices = torch.cat(
[
torch.zeros(
(hidden_states.shape[0], 1),
dtype=topk_indices.dtype,
device=topk_indices.device,
),
topk_indices,
],
dim=1,
)
if pruning_paras['merge_only']:
all_indices = torch.zeros(
(hidden_states.shape[0], 1),
dtype=topk_indices.dtype, device=topk_indices.device
)
dominant_num = 0
else:
all_indices = torch.cat(
[
torch.zeros(
(hidden_states.shape[0], 1),
dtype=topk_indices.dtype, device=topk_indices.device,
),
topk_indices,
], dim=1,
)
if pruning_paras['merge_only']:
all_indices = torch.zeros(
(hidden_states.shape[0], 1),
dtype=torch.long, device=hidden_states.device
)
dominant_num = 0
else:
topk_indices = cls_attention_sum.topk(dominant_num, dim=1).indices + 1
all_indices = torch.cat(
[
torch.zeros(
(hidden_states.shape[0], 1),
dtype=topk_indices.dtype, device=topk_indices.device,
),
topk_indices,
], dim=1,
)


mask = torch.ones_like(
hidden_states[:, :, 0], dtype=torch.bool, device=metric.device
@@ -355,6 +365,15 @@ def visionzip_hook(m, images, image_forward_outs):
target_indices = torch.arange(
0, metric_normalized.shape[1], step, device=metric_normalized.device
)[:contextual_num]

# keep_idxs
index_masks = ~mask
if not pruning_paras['prune_only']:
pruned_indices = mask.nonzero(as_tuple=False)[:, 1].view(hidden_states.shape[0], -1)
target_index = pruned_indices[:, target_indices]
index_masks.scatter_(1, target_index, True)
pruning_paras['index_masks'] = index_masks[:, 1:]

target_tokens = metric_normalized[:, target_indices, :]

tokens_to_merge = metric_normalized[
@@ -401,9 +420,15 @@ def visionzip_hook(m, images, image_forward_outs):
).to(images[0].dtype)

res = list(image_forward_outs.hidden_states)
res[-2] = hidden_states_save.contiguous()
if not llava_next:
if pruning_paras['prune_only']:
res[-2] = dominant_tokens.contiguous().to(images[0].dtype)
else:
res[-2] = hidden_states_save.contiguous()
image_forward_outs.hidden_states = tuple(res)

return image_forward_outs

def store_key_hook(m, x, outputs):
bsz = x[0].shape[0]
raw_outputs = (
@@ -418,10 +443,13 @@ def update_output_attentions_hook(module, args, kwargs):
kwargs['output_attentions'] = True
return args, kwargs

def update_index_masks_hook(module, inps, outs, pruning_paras):
module.index_masks = pruning_paras['index_masks']

if self.model.__class__.__name__ == 'LlavaHf':
vision_tower = self.model.vlm_model.vision_tower
elif self.model.__class__.__name__ == 'Llava':
vision_tower = self.model.vlm_model.model.vision_tower.vision_tower
vision_tower = self.model.vision_model.vision_tower

if self.model.__class__.__name__ in ('LlavaHf', 'Llava'):
apply_info(
@@ -444,7 +472,25 @@ def update_output_attentions_hook(module, args, kwargs):
block.self_attn.k_proj.head_dim = block.self_attn.head_dim
block.self_attn.k_proj.register_forward_hook(store_key_hook)

vision_tower.register_forward_hook(visionzip_hook)
vision_tower.register_forward_hook(
functools.partial(
visionzip_hook,
pruning_paras=self.pruning_paras,
llava_next=self.special_config['vision_token_length'] is None
)
)

# llava_next
if self.special_config['vision_token_length'] is None:

self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
prepare_inputs_labels_for_multimodal_with_index_masks,
self.model.vlm_model
)

self.model.vision_model.register_forward_hook(
functools.partial(update_index_masks_hook, pruning_paras=self.pruning_paras),
)

def get_metric(fn, pruning_paras):
@wraps(fn)