Skip to content

Commit 1a01d57

Browse files
committed
Update to align with recent VLM CB changes
Signed-off-by: quic-xiyushi <xiyushi@qti.qualcomm.com>
1 parent 3e242ce commit 1a01d57

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -721,14 +721,17 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
721721
]
722722
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
723723

724-
def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
724+
def __init__(self, model, continuous_batching: bool = False, qaic_config: Optional[dict] = None, **kwargs):
725725
"""
726726
Initializes the language decoder component for multimodal models.
727727
728728
Parameters
729729
----------
730730
model : nn.Module
731731
The full HuggingFace multimodal model from which the language decoder is extracted.
732+
continuous_batching : bool, optional
733+
If True, enables continuous batching mode for future compilation and execution.
734+
This setting must be consistent across `from_pretrained` and `compile` calls. Default is False.
732735
qaic_config : dict, optional
733736
A dictionary for QAIC-specific configurations.
734737
Only the following keys are supported by the text model of the dual QPC multimodal model:
@@ -741,6 +744,7 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
741744
super().__init__(model, **kwargs)
742745
self.model = model.get_qeff_language_decoder()
743746
self.hash_params["qeff_auto_class"] = self.__class__.__name__
747+
self.continuous_batching = continuous_batching
744748
self.model.qaic_config = qaic_config
745749
# ---Sampling---
746750
# Note: SamplerTransform should be applied after all other transforms
@@ -804,6 +808,7 @@ def get_sampling_inputs_and_outputs(
804808
sampling-related parameters.
805809
"""
806810
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
811+
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
807812

808813
assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling"
809814

@@ -816,10 +821,10 @@ def get_sampling_inputs_and_outputs(
816821
dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}
817822

818823
example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
819-
(bs, self.model.language_model.config.vocab_size), dtype=torch.bool
824+
(fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool
820825
)
821826
dynamic_axes["past_repetition_penalty_buffer"] = {
822-
0: "batch_size",
827+
0: "full_batch_size" if self.continuous_batching else "batch_size",
823828
}
824829
output_names.append("past_repetition_penalty_buffer_RetainedState")
825830

@@ -829,10 +834,10 @@ def get_sampling_inputs_and_outputs(
829834
dynamic_axes["repetition_penalties"] = {0: "batch_size"}
830835

831836
example_inputs["past_presence_penalty_buffer"] = torch.zeros(
832-
(bs, self.model.language_model.config.vocab_size), dtype=torch.bool
837+
(fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool
833838
)
834839
dynamic_axes["past_presence_penalty_buffer"] = {
835-
0: "batch_size",
840+
0: "full_batch_size" if self.continuous_batching else "batch_size",
836841
}
837842
output_names.append("past_presence_penalty_buffer_RetainedState")
838843

@@ -981,7 +986,7 @@ def __init__(
981986
self.model = model
982987
self.config = model.config
983988
self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
984-
self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
989+
self.lang_model = QEffCausalLMForTextImageToTextModel(model, continuous_batching=continuous_batching, **kwargs)
985990
self.continuous_batching = continuous_batching
986991
self.input_shapes, self.output_names = None, None
987992

QEfficient/transformers/sampler/sampler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,17 @@ def sampler_forward(
175175
Must be in [-1, 1].
176176
"""
177177
if vision_embeds is not None:
178-
logits, vision_embeds, image_idx, past_key_values = self.old_forward(
178+
forward_kwargs = dict(
179179
input_ids=input_ids,
180180
vision_embeds=vision_embeds,
181181
position_ids=position_ids,
182182
image_idx=image_idx,
183183
past_key_values=past_key_values,
184184
)
185+
if batch_index is not None:
186+
forward_kwargs["batch_index"] = batch_index
187+
188+
logits, vision_embeds, image_idx, past_key_values = self.old_forward(**forward_kwargs)
185189
outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values)
186190
if position_ids.dim() == 3: # For models using m-rope
187191
position_ids = position_ids[0]

0 commit comments

Comments (0)