QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py (128 changes: 66 additions & 62 deletions)
@@ -196,20 +196,21 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
         )
         cu_seqlens = torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), cu_seqlens])
 
-        deepstack_feature_lists = []
+        # deepstack_feature_lists = []
         for layer_num, blk in enumerate(self.blocks):
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens,
                 position_embeddings=position_embeddings,
             )
-            if layer_num in self.deepstack_visual_indexes:
-                deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
-                    hidden_states
-                )
-                deepstack_feature_lists.append(deepstack_feature)
+            # if layer_num in self.deepstack_visual_indexes:
+            #     deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
+            #         hidden_states
+            #     )
+            #     deepstack_feature_lists.append(deepstack_feature)
         hidden_states = self.merger(hidden_states)
-        return hidden_states, deepstack_feature_lists
+        # return hidden_states, deepstack_feature_lists
+        return hidden_states
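Note: with this hunk the vision tower's forward returns a single tensor instead of the former (hidden_states, deepstack_feature_lists) tuple. A minimal sketch of the new contract, using only names from this diff (`visual` is a stand-in for the patched vision model instance):

    # Sketch only: the patched forward now yields one tensor of merged patch embeddings.
    image_embeds = visual(pixel_values, grid_thw=image_grid_thw)
    # Formerly: image_embeds, deepstack_feature_lists = visual(pixel_values, grid_thw=image_grid_thw)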


class QEffQwen3VLVisionAttention(Qwen3VLVisionAttention):
@@ -492,7 +493,7 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
-        layer_idx = 0
+        # layer_idx = 0
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -516,13 +517,13 @@ def forward(
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
 
-            if deepstack_visual_embeds is not None and layer_idx in range(deepstack_visual_embeds.shape[0]):
-                hidden_states = self._deepstack_process(
-                    hidden_states,
-                    visual_pos_masks,
-                    deepstack_visual_embeds[layer_idx],
-                )
-                layer_idx += 1
+            # if deepstack_visual_embeds is not None and layer_idx in range(deepstack_visual_embeds.shape[0]):
+            #     hidden_states = self._deepstack_process(
+            #         hidden_states,
+            #         visual_pos_masks,
+            #         deepstack_visual_embeds[layer_idx],
+            #     )
+            #     layer_idx += 1
 
         hidden_states = self.norm(hidden_states)
         if output_hidden_states:
@@ -540,20 +541,20 @@ def forward(
 
         return (hidden_states, past_key_values)
 
-    def _deepstack_process(
-        self,
-        hidden_states: torch.Tensor,
-        visual_pos_masks: torch.Tensor,
-        visual_embeds: torch.Tensor,
-    ):
-        visual_pos_masks = visual_pos_masks.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
-        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
-        hidden_states = hidden_states.clone()
-        mixed_embeds = hidden_states + visual_embeds
+    # def _deepstack_process(
+    #     self,
+    #     hidden_states: torch.Tensor,
+    #     visual_pos_masks: torch.Tensor,
+    #     visual_embeds: torch.Tensor,
+    # ):
+    #     visual_pos_masks = visual_pos_masks.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
+    #     visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
+    #     hidden_states = hidden_states.clone()
+    #     mixed_embeds = hidden_states + visual_embeds
 
-        local_this = torch.where(visual_pos_masks, mixed_embeds, hidden_states)
+    #     local_this = torch.where(visual_pos_masks, mixed_embeds, hidden_states)
 
-        return local_this
+    #     return local_this


class QEffQwen3VLEncoderWrapper(nn.Module):
@@ -572,15 +573,17 @@ def get_submodules_for_export(self) -> Type[nn.Module]:
         return {self.model.visual.blocks[0].__class__}
 
     def forward(self, pixel_values, image_grid_thw):
-        image_embeds, deepstack_feature_lists = self.model.visual(pixel_values, grid_thw=image_grid_thw)
+        # image_embeds, deepstack_feature_lists = self.model.visual(pixel_values, grid_thw=image_grid_thw)
+        image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)
         bs = image_grid_thw.shape[0]
         split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs)
         image_embeds = image_embeds.reshape(bs, split_size, image_embeds.size(1))
-        deepstack_features = torch.stack(
-            [feature.reshape(bs, split_size, feature.size(1)) for feature in deepstack_feature_lists],
-            dim=0,  # new axis for "features"
-        )
-        return image_embeds, deepstack_features
+        # deepstack_features = torch.stack(
+        #     [feature.reshape(bs, split_size, feature.size(1)) for feature in deepstack_feature_lists],
+        #     dim=0,  # new axis for "features"
+        # )
+        # return image_embeds, deepstack_features
+        return image_embeds
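Note: the encoder wrapper now exposes a single output as well. A sketch of the updated interface, assuming `encoder` is a QEffQwen3VLEncoderWrapper constructed around the patched model:

    # Returns (bs, split_size, hidden) image embeddings; deepstack_features is no longer produced.
    image_embeds = encoder(pixel_values, image_grid_thw)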


class QEffQwen3VLDecoderWrapper(nn.Module):
@@ -602,7 +605,7 @@ def forward(
         self,
         input_ids,
         vision_embeds,
-        deepstack_features,
+        # deepstack_features,
         position_ids,
         image_idx,
         past_key_values,
@@ -617,20 +620,20 @@ def forward(
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
 
-        num_features, bs, split_size, C = deepstack_features.shape
-        x = deepstack_features.reshape(num_features, bs * split_size, C)
-        deepstack_features_expanded = x[:, indices1, :]
+        # num_features, bs, split_size, C = deepstack_features.shape
+        # x = deepstack_features.reshape(num_features, bs * split_size, C)
+        # deepstack_features_expanded = x[:, indices1, :]
         image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
 
-        image_mask = selected.clone()
+        # image_mask = selected.clone()
 
-        visual_pos_masks = None
-        deepstack_visual_embeds = None
+        # visual_pos_masks = None
+        # deepstack_visual_embeds = None
 
-        if image_mask is not None:
-            visual_pos_masks = image_mask
-            deepstack_visual_embeds = deepstack_features_expanded
+        # if image_mask is not None:
+        #     visual_pos_masks = image_mask
+        #     deepstack_visual_embeds = deepstack_features_expanded
 
         outputs = self.language_model(
             inputs_embeds=inputs_embeds,
Expand All @@ -639,14 +642,15 @@ def forward(
comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
use_cache=True,
visual_pos_masks=visual_pos_masks,
deepstack_visual_embeds=deepstack_visual_embeds,
# visual_pos_masks=visual_pos_masks,
# deepstack_visual_embeds=deepstack_visual_embeds,
)
logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True)
hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index]
logits = self.model.lm_head(hidden_states)
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
return logits, vision_embeds, deepstack_features, image_idx, outputs.past_key_values
# return logits, vision_embeds, deepstack_features, image_idx, outputs.past_key_values
return logits, vision_embeds, image_idx, outputs.past_key_values
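Note: the decoder wrapper's return shrinks from a 5-tuple to a 4-tuple, so downstream unpacking changes accordingly. A sketch, assuming `decoder` is a QEffQwen3VLDecoderWrapper instance (remaining keyword arguments such as comp_ctx_lengths and batch_index elided):

    # deepstack_features drops out of both the inputs and the returned state.
    logits, vision_embeds, image_idx, past_key_values = decoder(
        input_ids, vision_embeds, position_ids, image_idx, past_key_values
    )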


class QEffQwen3VLModel(Qwen3VLModel):
@@ -770,12 +774,12 @@ def get_dummy_inputs(
         inputs_shapes["pixel_values"] = (748, 1536)
         inputs_shapes["image_idx"] = (1, 1)
         inputs_shapes["image_sizes"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 2)
-        inputs_shapes["deepstack_features"] = (
-            len(self.config.vision_config.deepstack_visual_indexes),
-            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
-            vision_size,
-            self.model.config.vision_config.out_hidden_size,
-        )
+        # inputs_shapes["deepstack_features"] = (
+        #     len(self.config.vision_config.deepstack_visual_indexes),
+        #     constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+        #     vision_size,
+        #     self.model.config.vision_config.out_hidden_size,
+        # )
 
         vision_inputs = {}
         lang_inputs = {}
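Note: because deepstack_features is no longer a model input, its dummy shape here and its dummy tensor in the next hunk are dropped, and the specializations, dynamic axes, and output names in the hunks below are trimmed to match.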
@@ -793,7 +797,7 @@ def get_dummy_inputs(
             .repeat(4, 1, 1)
         )
         lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)
-        lang_inputs["deepstack_features"] = torch.zeros((inputs_shapes["deepstack_features"]), dtype=torch.float32)
+        # lang_inputs["deepstack_features"] = torch.zeros((inputs_shapes["deepstack_features"]), dtype=torch.float32)
         # Add data for KV
 
         bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
@@ -921,7 +925,7 @@ def smart_resize(
                 "time": time,
                 "grid_h": grid_h,
                 "grid_w": grid_w,
-                "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes),
+                # "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes),
             }
         ]

@@ -936,7 +940,7 @@ def smart_resize(
             "vision_size": vision_size,
             "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
             "vision_batch_size": batch_size,
-            "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes),
+            # "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes),
         }
 
         if continuous_batching:
@@ -956,7 +960,7 @@ def smart_resize(
             "vision_size": vision_size,
             "comp_ctx_lengths": comp_ctx_lengths_decode[i],
             "vision_batch_size": batch_size,
-            "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes),
+            # "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes),
         }
 
         if continuous_batching:
@@ -972,7 +976,7 @@ def smart_resize(
             "ctx_len": ctx_len,
             "vision_size": vision_size,
             "vision_batch_size": batch_size,
-            "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes),
+            # "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes),
         }
 
         if continuous_batching:
@@ -988,7 +992,7 @@ def smart_resize(
            "ctx_len": ctx_len,
            "vision_size": vision_size,
            "vision_batch_size": batch_size,
-            "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes),
+            # "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes),
        }
 
        if continuous_batching:
@@ -1017,14 +1021,14 @@ def get_onnx_dynamic_axes(
         vision_dynamic_axes = {
             "pixel_values": {0: "grid_height", 1: "grid_width"},
             "image_grid_thw": {0: "batch_size", 1: "time", 2: "grid_h", 3: "grid_w"},
-            "deepstack_features": {0: "num_feature_layers", 1: "batch_size", 2: "vision_size"},
+            # "deepstack_features": {0: "num_feature_layers", 1: "batch_size", 2: "vision_size"},
         }
 
         lang_dynamic_axes = {
             "input_ids": {0: "batch_size", 1: "seq_len"},
             "position_ids": {1: "batch_size", 2: "seq_len"},
             "vision_embeds": {0: "vision_batch_size", 1: "vision_size"},
-            "deepstack_features": {0: "num_feature_layers", 1: "vision_batch_size", 2: "vision_size"},
+            # "deepstack_features": {0: "num_feature_layers", 1: "vision_batch_size", 2: "vision_size"},
         }
 
         for i in range(num_layers):
@@ -1055,7 +1059,7 @@ def get_onnx_dynamic_axes(
 
     def get_output_names(self, kv_offload: bool = False):
         vision_output_names = ["vision_embeds"]
-        vision_output_names.append("deepstack_features")
+        # vision_output_names.append("deepstack_features")
        lang_output_names = ["logits"]
        for i in range(self.model.config.text_config.num_hidden_layers):
            for kv in ["key", "value"]:
@@ -1065,7 +1069,7 @@ def get_output_names(self, kv_offload: bool = False):
         if kv_offload:
             lang_output_names.insert(1, "vision_embeds_RetainedState")
             lang_output_names.insert(2, "image_idx_output")
-            lang_output_names.insert(2, "deepstack_features_RetainedState")
+            # lang_output_names.insert(2, "deepstack_features_RetainedState")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
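Note: with "deepstack_features" dropped from vision_output_names and the "deepstack_features_RetainedState" insert removed, the kv_offload output layout implied by the surviving insert calls becomes (sketch, inferred from this hunk):

    vision_output_names = ["vision_embeds"]
    # "logits" first, then the two retained states, then the per-layer KV RetainedState names.
    lang_output_names = ["logits", "vision_embeds_RetainedState", "image_idx_output", ...]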