4 changes: 2 additions & 2 deletions llmc/compression/token_reduction/dycoke.py
@@ -7,8 +7,8 @@

try:
from llava.model.llava_arch import LlavaMetaForCausalLM
except ModuleNotFoundError:
logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.')
except ImportError:
pass
from transformers.cache_utils import Cache, DynamicCache

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
4 changes: 2 additions & 2 deletions llmc/compression/token_reduction/fastvid.py
@@ -13,8 +13,8 @@
from llava.model.multimodal_encoder.siglip_encoder import (
SigLipVisionConfig, SigLipVisionModel)
from llava.utils import rank0_print
except ModuleNotFoundError:
logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.')
except ImportError:
pass

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

4 changes: 2 additions & 2 deletions llmc/compression/token_reduction/holitom.py
@@ -20,8 +20,8 @@
from llava.utils import rank0_print
from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling)
except ModuleNotFoundError:
logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.')
except ImportError:
pass

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

4 changes: 2 additions & 2 deletions llmc/compression/token_reduction/prunevid.py
@@ -8,8 +8,8 @@

try:
from llava.model.llava_arch import LlavaMetaForCausalLM
except ModuleNotFoundError:
logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.')
except ImportError:
pass

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

222 changes: 162 additions & 60 deletions llmc/compression/token_reduction/random.py
@@ -1,9 +1,7 @@
import functools
from functools import wraps
from types import MethodType

import torch
from loguru import logger

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

@@ -21,61 +19,21 @@ def __init__(self, config, model, blocks):
def add_sparse_config(self):

self.pruning_loc = self.special_config['pruning_loc']
self.special_config['image_token_length'] = self.model.pruning_config[
'image_token_length'
]

self.pruning_paras = self.special_config

def register_reduction_modules(self):

def input_hook_llava(fn, pruning_paras):
@wraps(fn)
def wrapper(self, *args, **kwargs):
if len(args) == 0:
return fn(*args, **kwargs)
input_args = args[0]
if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
return fn(*args, **kwargs)

input_ids = args[0]
attention_mask = args[2]
token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
pruning_paras['image_token_start_index'] = torch.where(token_indices)[
0
][0].item()

outputs = fn(*args, **kwargs)
return outputs

return wrapper

@prefill_wrapper
def input_hook(module, input_args, pruning_paras):
input_ids = input_args[0]
image_token_idxs = (
input_ids[0] == pruning_paras['vision_token_index']
).nonzero(as_tuple=True)[0]
pruning_paras['image_token_start_index'] = image_token_idxs[0].item()

return input_args

@prefill_wrapper
def random_pruning_hook(module, args, kwargs, pruning_paras):

logger.info(' ========random_pruning_hook======== ')

rate = pruning_paras['rate']
image_token_start_index = pruning_paras['image_token_start_index']
image_token_length = pruning_paras['image_token_length']
rate = pruning_paras['prune_ratio']
image_token_start_index = pruning_paras['vision_token_start_index']
image_token_length = pruning_paras['vision_token_length']

hidden_states = args[0]
causal_mask = kwargs['attention_mask']

logger.info(f'before hidden_states : {hidden_states.shape}')

device = hidden_states.device

vision_indexes = torch.arange(
image_token_start_index,
image_token_start_index + image_token_length,
@@ -130,25 +88,169 @@ def random_pruning_hook(module, args, kwargs, pruning_paras):
position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

logger.info(f'after hidden_states : {hidden_states.shape}')
return (hidden_states,), kwargs

if self.model.__class__.__name__ == 'LlavaHf':
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(input_hook, pruning_paras=self.pruning_paras)
@prefill_wrapper
def holitom_merge_hook(module, args, kwargs, pruning_paras):

rate = pruning_paras['prune_ratio']
image_token_start_index = pruning_paras['vision_token_start_index']
image_token_length = pruning_paras['vision_token_length']

hidden_states = args[0]
causal_mask = kwargs['attention_mask']

device = hidden_states.device
last_layer_attention = pruning_paras['attn_scores']
# compute average attention over different head
last_layer_attention_avg = torch.mean(
last_layer_attention, dim=1
)[0]
# generate new attention mask based on the average attention,
# sample the top ATTENTION_RANK tokens with highest attention
last_layer_attention_avg_last_tok = (
last_layer_attention_avg[-1]
)
# get the attention in image token
last_layer_attention_avg_last_tok_image = \
last_layer_attention_avg_last_tok[
image_token_start_index:
image_token_start_index + image_token_length
]
Comment on lines +115 to +119

Severity: medium

The backslash \ for line continuation here is unnecessary. Python implicitly continues lines inside parentheses, brackets, and braces, which is the case here with the square brackets for slicing. Removing the backslash and adjusting the indentation will make the code cleaner and more compliant with PEP 8.

            last_layer_attention_avg_last_tok_image = last_layer_attention_avg_last_tok[
                image_token_start_index:
                image_token_start_index + image_token_length
            ]

# get the indexes of the top ATTENTION_RANK tokens
top_attention_rank_index = (
last_layer_attention_avg_last_tok_image.topk(
round(
image_token_length * (1 - rate)
)
).indices
+ image_token_start_index
)
elif self.model.__class__.__name__ == 'Llava':
from llava.constants import IMAGE_TOKEN_INDEX

hook_fn = input_hook_llava(
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
self.pruning_paras,
all_indices = torch.arange(
image_token_length, device=device
)
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
hook_fn, self.model.vlm_model
non_topk_mask = ~torch.isin(
all_indices,
top_attention_rank_index
- image_token_start_index,
)
non_topk_indices = (
all_indices[non_topk_mask]
+ image_token_start_index
)
non_topk_states = hidden_states[
:, non_topk_indices, :
] # [batch_size, len(non_topk), hidden_size]
topk_states = hidden_states[
:, top_attention_rank_index, :
] # [batch_size, len(topk), hidden_size]
non_topk_norm = torch.norm(
non_topk_states, dim=-1, keepdim=True
) # [batch_size, len(non_topk), 1]
topk_norm = torch.norm(
topk_states, dim=-1, keepdim=True
) # [batch_size, len(topk), 1]
dot_product = torch.bmm(
non_topk_states, topk_states.transpose(1, 2)
) # [batch_size, len(non_topk), len(topk)]
sim_matrix = dot_product / (
non_topk_norm * topk_norm.transpose(1, 2)
)
sim_max, sim_max_index = torch.max(sim_matrix, dim=-1)

batch_size = hidden_states.size(0)
num_topk = len(top_attention_rank_index)
num_non_topk = len(non_topk_indices)
topk_counter = torch.ones((batch_size, num_topk, 1), device=hidden_states.device)

for b in range(batch_size):
for i in range(num_non_topk):
topk_rel_idx = sim_max_index[b, i].item() # 这是 topk 中的相对索引
topk_abs_idx = top_attention_rank_index[topk_rel_idx] # 得到绝对索引
non_topk_abs_idx = non_topk_indices[i]

# 累加non-topk到topk token上(就地)
hidden_states[b, topk_abs_idx, :] += hidden_states[b, non_topk_abs_idx, :]
# 增加计数
topk_counter[b, topk_rel_idx] += 1

Comment on lines +167 to +177

Severity: high

The nested Python loops over batch_size and num_non_topk that merge token states can be very inefficient, especially for larger batch sizes, because they prevent vectorized execution on the GPU. This part of the code can be a significant performance bottleneck.

Consider replacing the loop with vectorized PyTorch operations. You can use torch.scatter_add_ to perform the summation of non-topk states into topk states in a single, efficient operation. This will also require vectorizing the topk_counter update.
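
For reference, a minimal self-contained sketch of the scatter_add_-based merge described here (toy tensor sizes and a randomly generated sim_max_index, purely for illustration; only the variable names mirror the hook):

    import torch

    # Toy sizes for illustration only; in the hook these come from the pruning paras.
    batch, num_topk, num_non_topk, hidden = 2, 3, 5, 8
    hidden_states = torch.randn(batch, num_topk + num_non_topk, hidden)
    top_attention_rank_index = torch.tensor([0, 2, 4])    # absolute positions of kept (topk) tokens
    non_topk_indices = torch.tensor([1, 3, 5, 6, 7])      # absolute positions of tokens to merge away
    # For each non-topk token, the index (into top_attention_rank_index) of its most similar topk token.
    sim_max_index = torch.randint(0, num_topk, (batch, num_non_topk))

    # 1) Sum every non-topk state onto its matched topk state in a single call.
    dest_abs = top_attention_rank_index[sim_max_index]    # [batch, num_non_topk]
    src = hidden_states[:, non_topk_indices, :]           # [batch, num_non_topk, hidden]
    hidden_states.scatter_add_(1, dest_abs.unsqueeze(-1).expand_as(src), src)

    # 2) Count how many states each topk token absorbed (itself included), then average.
    topk_counter = torch.ones(batch, num_topk, 1)
    topk_counter.scatter_add_(1, sim_max_index.unsqueeze(-1), torch.ones(batch, num_non_topk, 1))
    hidden_states[:, top_attention_rank_index, :] /= topk_counter

The two scatter_add_ calls replace both Python loops, and the final division broadcasts topk_counter over the hidden dimension.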

# 平均化所有topk token(包含自己和所有被合并的)
Comment on lines +169 to +178

Severity: medium

This code contains comments in Chinese. To improve code clarity and maintainability for a broader audience, please translate these comments into English.

For example:

  • # 这是 topk 中的相对索引 -> # This is the relative index within topk
  • # 得到绝对索引 -> # Get the absolute index
  • # 累加non-topk到topk token上(就地) -> # Accumulate non-topk token to topk token (in-place)

for b in range(batch_size):
for i in range(num_topk):
topk_abs_idx = top_attention_rank_index[i]
hidden_states[b, topk_abs_idx, :] /= topk_counter[b, i]
Comment on lines +179 to +182

Severity: high

This loop for averaging the token states can be vectorized for better performance. A simple element-wise division on the sliced tensor would be much more efficient than iterating through the batch and tokens in Python.

hidden_states[:, top_attention_rank_index, :] /= topk_counter

Comment on lines +167 to +182

Severity: medium

This part of the code has a couple of areas for improvement:

  1. Performance: The nested Python loops iterating over batch_size and token numbers can be a performance bottleneck. These operations can be vectorized using torch.scatter_add_ for significant speedup.
  2. Maintainability: There are comments in Chinese. For better code maintainability and to make it accessible to a wider range of contributors, they should be translated to English.

Here is a suggested vectorized implementation with English comments:

            # Vectorized token merging for performance.
            # Determine which top-k token each non-top-k token will merge into.
            topk_abs_indices_to_add_to = top_attention_rank_index[sim_max_index]

            # Add non-top-k hidden states to their corresponding top-k hidden states.
            source_states = hidden_states[:, non_topk_indices, :]
            index_for_scatter = topk_abs_indices_to_add_to.unsqueeze(-1).expand_as(source_states)
            hidden_states.scatter_add_(1, index_for_scatter, source_states)

            # Update counters for averaging.
            topk_counter.scatter_add_(
                1,
                sim_max_index.unsqueeze(-1),
                torch.ones_like(sim_max_index, dtype=topk_counter.dtype).unsqueeze(-1),
            )

            # Average the hidden states of the merged tokens.
            hidden_states[:, top_attention_rank_index, :] /= topk_counter


keep_indexs = torch.cat(
(
torch.arange(
image_token_start_index,
device=device,
),
top_attention_rank_index,
torch.arange(
image_token_start_index
+ image_token_length,
hidden_states.shape[1],
device=device,
),
)
)

self.blocks[self.pruning_loc].register_forward_pre_hook(
functools.partial(random_pruning_hook, pruning_paras=self.pruning_paras),
with_kwargs=True,
)
# sort index
keep_indexs = keep_indexs.sort().values
# filter hidden states &
hidden_states = hidden_states[:, keep_indexs, :]
# update position ids
position_ids = keep_indexs.unsqueeze(0)
# update attention mask
if causal_mask is not None:
causal_mask = causal_mask[:, :, :hidden_states.shape[1], :hidden_states.shape[1]]
kwargs['attention_mask'].resize_as_(causal_mask).copy_(causal_mask.clone())
kwargs['cache_position'].resize_as_(position_ids.squeeze(0)).copy_(
position_ids.squeeze(0).clone())
kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())

position_embeddings = kwargs['position_embeddings']
index_dim = 1 if position_embeddings[0].dim() == 3 else 2
new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

return (hidden_states,), kwargs

def update_output_attentions_hook(module, args, kwargs):
kwargs['output_attentions'] = True
return args, kwargs

def store_attention_hook(m, x, layer_outputs, pruning_paras):
layer_attention = layer_outputs[1]
pruning_paras['attn_scores'] = layer_attention

if self.special_config['vision_token_length'] is None:
if self.model.__class__.__name__ == 'Llava':
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
self.vtoken_length_for_llava_hook(
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
self.pruning_paras
), self.model.vlm_model
)

if self.special_config['metric'] == 'random':
self.blocks[self.pruning_loc].register_forward_pre_hook(
functools.partial(random_pruning_hook, pruning_paras=self.pruning_paras),
with_kwargs=True
)
elif self.special_config['metric'] == 'holitom_merge':
self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
update_output_attentions_hook,
with_kwargs=True
)
self.blocks[self.pruning_loc - 1].register_forward_hook(
functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
)
self.blocks[self.pruning_loc].register_forward_pre_hook(
functools.partial(holitom_merge_hook, pruning_paras=self.pruning_paras),
with_kwargs=True
)
9 changes: 2 additions & 7 deletions llmc/compression/token_reduction/tome.py
@@ -3,8 +3,6 @@
from typing import Callable, Optional, Tuple

import torch
import torch.nn.functional as F
from loguru import logger
from transformers.models.clip.modeling_clip import CLIPEncoderLayer

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
@@ -20,8 +18,7 @@ def __init__(self, config, model, blocks):
self.patch_layer()

def add_sparse_config(self):
special_config = self.config.get('special', {})
r_param = special_config.get('r', 0)
r_param = self.special_config.get('r', 0)
if isinstance(r_param, int) or isinstance(r_param, float):
self.r = [max(int(r_param), 0)] * len(self.blocks)
elif isinstance(r_param, (tuple, list)):
@@ -36,19 +33,17 @@ def add_sparse_config(self):
else:
raise ValueError('Invalid r format. Expected int or (start, step) tuple.')

self.pruning_paras = special_config
self.pruning_paras = self.special_config

def patch_layer(self):
for idx, block in enumerate(self.blocks):
if self.r[idx] > 0:
block.r = self.r[idx]
if isinstance(block, CLIPEncoderLayer): # llava
block.self_attn.original_forward = block.self_attn.forward
block.self_attn.forward = types.MethodType(
tome_CLIPSdpaAttention_forward,
block.self_attn
)
Comment on lines 43 to 46

Severity: medium

You've removed the lines that back up the original forward methods before patching them. While this simplifies the code, it's a risky change as it removes the ability to revert the patch or to call the original methods from the patched versions. This can make debugging more difficult and reduces the modularity of the patch.

It is generally a good practice to store the original method before overwriting it. If there's no strong reason to remove them, I'd recommend restoring these backups.
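
For illustration, a minimal sketch of the backup-and-patch pattern being recommended (DummyAttention and patched_forward are placeholders, not the actual tome/llmc classes):

    import types

    class DummyAttention:
        def forward(self, x):
            return x

    def patched_forward(self, x):
        # The patch can still delegate to the saved original method.
        return self.original_forward(x) * 2

    attn = DummyAttention()
    # Keep a handle to the original bound method before overwriting it...
    attn.original_forward = attn.forward
    # ...then install the patched version bound to this instance.
    attn.forward = types.MethodType(patched_forward, attn)

    print(attn.forward(3))   # 6, via the patch
    # Reverting the patch is now trivial:
    attn.forward = attn.original_forward
    print(attn.forward(3))   # 3, original behavior

Keeping the original_forward attribute costs one line per patched module and preserves the option to unpatch or to call through to the original implementation.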

block.original_forward = block.forward
block.forward = types.MethodType(
tome_CLIPEncoderLayer_forward,
block
4 changes: 2 additions & 2 deletions llmc/compression/token_reduction/utils.py
@@ -9,8 +9,8 @@

try:
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
except Exception as e:
logger.debug('LLaVA is not installed. Please install LLaVA to use this model.\nError: %s' % e)
except ImportError:
pass
import random


15 changes: 2 additions & 13 deletions llmc/compression/token_reduction/visionzip.py
@@ -585,19 +585,8 @@ def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):
st_idx = torch.nonzero(img_mask, as_tuple=True)[0]

if st_idx.numel() > 0:
discontinuities = torch.where(st_idx[1:] - st_idx[:-1] != 1)[0]
if discontinuities.numel() > 0:
raise ValueError('Visual tokens are not contiguous in input_ids!')
segment_starts = [st_idx[0].item()] + [st_idx[i + 1].item() for i in discontinuities.tolist()] # noqa
segment_ends = [st_idx[i].item() for i in discontinuities.tolist()] + [st_idx[-1].item()] # noqa
offset = 0
for first, last in zip(segment_starts, segment_ends):
length = last - first + 1
# [15 1502] [1505 3289]
img_mask[first: last + 1] = ~select_mask[offset: offset + length]
else:
first, last = st_idx[0].item(), st_idx[-1].item()
img_mask[first: last + 1] = ~select_mask
first, last = st_idx[0].item(), st_idx[-1].item()
img_mask[first: last + 1] = ~select_mask
Comment on lines +588 to +589

Severity: critical

The previous implementation checked for discontinuous visual tokens and raised an error. This check has been removed. The new logic first, last = st_idx[0].item(), st_idx[-1].item() assumes that the visual tokens are always in a contiguous block. If st_idx can contain indices of multiple, separate blocks of visual tokens, this assumption is incorrect. For example, if st_idx is [10, 11, 100, 101], first will be 10 and last will be 101. The slice img_mask[first: last + 1] would incorrectly include non-visual tokens from index 12 to 99. This could lead to silent and incorrect behavior.

The original check was safer. If discontinuous tokens are not supported, it's better to fail loudly with a ValueError. If they are to be supported, the logic should handle them correctly. The original code had unreachable code after raise, but the check itself was valuable.

Please consider restoring the check for contiguity or implementing logic to handle multiple segments of visual tokens.

img_mask = ~img_mask
Comment on lines +588 to 590

Severity: high

The logic to detect and handle discontinuous visual tokens has been removed. The previous code explicitly checked for this and would raise a ValueError, which is a safe way to handle unexpected input. The new implementation assumes that the visual tokens are always in a contiguous block by just using the first and last indices.

If it's possible for visual tokens to be non-contiguous, this change could lead to incorrect behavior and silent errors by applying the select_mask to a range that includes non-visual tokens. The raise statement in the original code was followed by unreachable code, which was a bug, but removing the check entirely might be unsafe.

Could you please confirm that visual tokens are guaranteed to be contiguous for qwenv25vl? If not, it would be safer to restore a check for contiguity.
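
For reference, a standalone sketch of both options — failing loudly or recovering each contiguous segment — using a made-up st_idx; the segment logic mirrors the code that was removed:

    import torch

    st_idx = torch.tensor([10, 11, 12, 100, 101])   # made-up visual-token positions

    # Positions where the index sequence jumps, i.e. boundaries between contiguous segments.
    gaps = torch.where(st_idx[1:] - st_idx[:-1] != 1)[0]

    # Option A: fail loudly, as the old code did.
    #     if gaps.numel() > 0:
    #         raise ValueError('Visual tokens are not contiguous in input_ids!')

    # Option B: recover every contiguous (first, last) segment and process each one.
    segment_starts = [st_idx[0].item()] + [st_idx[i + 1].item() for i in gaps.tolist()]
    segment_ends = [st_idx[i].item() for i in gaps.tolist()] + [st_idx[-1].item()]
    print(list(zip(segment_starts, segment_ends)))   # [(10, 12), (100, 101)]

Either option avoids silently widening the masked range across the gap between segments.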

contextual_input_idx = false_pos[target_indices] + first
