13 changes: 13 additions & 0 deletions QEfficient/generation/vlm_generation.py
@@ -36,6 +36,7 @@
write_io_files,
)
from QEfficient.utils import LRUCache
from QEfficient.utils.constants import Constants
from QEfficient.utils.logging_utils import logger


@@ -303,6 +304,13 @@ def _execute_chunked_prefill(
prefill_ccl_id = 0
lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]

if self.include_sampler:
for op in Constants.SAMPLER_OPS:
if decode_batch_id is not None:
lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
lang_inputs[op] = self.sampling_params[op]

for i in range(num_chunks):
input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
position_ids_slice = lang_inputs["position_ids"][
@@ -328,6 +336,11 @@

chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"]

if self.include_sampler:
chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
for op in Constants.SAMPLER_OPS:
chunk_inputs[op] = lang_inputs[op]

outputs = self._session.run(chunk_inputs)

if "image_idx_output" in outputs:
139 changes: 47 additions & 92 deletions QEfficient/transformers/models/modeling_auto.py
@@ -8,7 +8,7 @@
import warnings
from pathlib import Path
from time import perf_counter
from typing import Dict, List, Optional, Union
from typing import List, Optional, Union

import numpy as np
import torch
@@ -64,6 +64,7 @@
)
from QEfficient.utils.check_ccl_specializations import process_ccl_specializations
from QEfficient.utils.logging_utils import logger
from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs


class QEFFTransformersBase(QEFFBaseModel):
@@ -751,19 +752,22 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model, **kwargs):
def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
"""
Initializes the language decoder component for multimodal models.

Parameters
----------
model : nn.Module
The full HuggingFace multimodal model from which the language decoder is extracted.
qaic_config : dict, optional
A dictionary for QAIC-specific configurations.
**kwargs :
Additional keyword arguments passed to the base class constructor.
"""
super().__init__(model, **kwargs)
super().__init__(model, qaic_config=qaic_config, **kwargs)
self.model = model.get_qeff_language_decoder()
self.model.qaic_config = qaic_config
self.hash_params["qeff_auto_class"] = self.__class__.__name__

def export(
@@ -919,25 +923,30 @@ def __init__(
----------
model : nn.Module
The full HuggingFace multimodal model.
qaic_config : dict, optional
A dictionary for QAIC-specific configurations.
**kwargs :
Additional keyword arguments. `full_batch_size` is not supported here.

Raises
------
NotImplementedError
If `full_batch_size` is provided.
Additional keyword arguments.
"""
if kwargs.pop("full_batch_size", None):
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
continuous_batching = True
warnings.warn(
"full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
)
self.model = model
self.config = model.config

self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)

self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs)
self.continuous_batching = continuous_batching
self.input_shapes, self.output_names = None, None
# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
# are done. The role of the sampler is to just add nodes at the output of the
# previous transform function.
self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs)

@property
def model_name(self) -> str:
@@ -1062,6 +1071,19 @@ def export(
kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode
)
output_names = self.model.get_output_names(kv_offload=True)
if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get(
"include_sampler", False
):
logits_index = output_names["lang"].index("logits")
output_names["lang"][logits_index] = "next_tokens"
inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs(
example_inputs=inputs["lang"],
output_names=output_names["lang"],
dynamic_axes=dynamic_axes["lang"],
continuous_batching=self.continuous_batching,
vocab_size=self.model.language_model.config.vocab_size,
qaic_config=self.lang_model.model.qaic_config,
)

self.vision_model.export(
inputs["vision"],
@@ -1279,6 +1301,7 @@ def generate(
device_ids: List[int] = None,
runtime_ai100: bool = True,
generation_len: Optional[int] = None,
**kwargs,
) -> Union[torch.Tensor, np.ndarray]:
"""
Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards.
@@ -1337,6 +1360,7 @@
full_batch_size=fbs,
comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill,
comp_ctx_lengths_decode=self.comp_ctx_lengths_decode,
**kwargs,
)

# Call generate method
@@ -1616,10 +1640,15 @@ def __init__(
Raises
------
NotImplementedError
If `full_batch_size` is provided.
If `full_batch_size` is provided or `include_sampler` is True.
"""
if kwargs.pop("full_batch_size", None):
warnings.warn(
"full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
)
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
if qaic_config is not None and qaic_config.pop("include_sampler", False):
raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.")
super().__init__(model, **kwargs)

self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)
@@ -2230,6 +2259,8 @@ def from_pretrained(
If True, uses the dual QPC approach (vision encoder KV offloaded).
If False, uses the single QPC approach (entire model in one QPC).
If None, the default behavior of the internal classes is used (typically dual QPC).
qaic_config : dict, optional
A dictionary for QAIC-specific configurations.
**kwargs :
Additional arguments passed to HuggingFace's ``from_pretrained``.

@@ -2600,10 +2631,13 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"}

if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs(
example_inputs=example_inputs,
output_names=output_names,
dynamic_axes=dynamic_axes,
continuous_batching=self.continuous_batching,
vocab_size=self.model.config.vocab_size,
qaic_config=self.model.qaic_config,
)

return self._export(
@@ -2615,85 +2649,6 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
offload_pt_weights=kwargs.get("offload_pt_weights", True),
)

def get_sampling_inputs_and_outputs(
self,
example_inputs: Dict[str, torch.Tensor],
output_names: List[str],
dynamic_axes: Dict[str, Dict[int, str]],
):
"""
Updates the example inputs, output names, and dynamic axes to include
parameters relevant for on-device sampling during ONNX export.

Parameters
----------
example_inputs : Dict[str, torch.Tensor]
Current dictionary of example inputs.
output_names : List[str]
Current list of output names.
dynamic_axes : Dict[str, Dict[int, str]]
Current dictionary of dynamic axes configurations.

Returns
-------
Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
Updated example inputs, output names, and dynamic axes including
sampling-related parameters.
"""
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS

example_inputs["last_accepted_output_tokens"] = torch.zeros(
(bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
)
dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}

example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
(fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
)
dynamic_axes["past_repetition_penalty_buffer"] = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
}
output_names.append("past_repetition_penalty_buffer_RetainedState")

example_inputs["repetition_penalties"] = (
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
)
dynamic_axes["repetition_penalties"] = {0: "batch_size"}

example_inputs["past_presence_penalty_buffer"] = torch.zeros(
(fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
)
dynamic_axes["past_presence_penalty_buffer"] = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
}
output_names.append("past_presence_penalty_buffer_RetainedState")

example_inputs["presence_penalties"] = (
torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
)
dynamic_axes["presence_penalties"] = {0: "batch_size"}

example_inputs["temperatures"] = (
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
)
dynamic_axes["temperatures"] = {0: "batch_size"}

max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
dynamic_axes["top_ks"] = {0: "batch_size"}

example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
dynamic_axes["top_ps"] = {0: "batch_size"}

example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
dynamic_axes["min_ps"] = {0: "batch_size"}

example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
dynamic_axes["random_numbers"] = {0: "batch_size"}

return example_inputs, output_names, dynamic_axes

def build_prefill_specialization(
self,
prefill_seq_len: int = 32,
4 changes: 4 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -291,6 +291,7 @@
QEffGrok1MultiHeadAttention,
)
from QEfficient.transformers.models.internvl.modeling_internvl import (
QEffInternDecoderWrapper,
QEffInternVisionEmbeddings,
QEffInternVLModel,
)
@@ -394,6 +395,7 @@
QEffQwen2_5_VLModel,
QEffQwen2_5_VLTextModel,
QEffQwen2_5_VLVisionAttention,
QEffQwen_2_5_vl_DecoderWrapper,
QEffQwen_2_5_vl_ForConditionalGeneration,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
@@ -710,10 +712,12 @@ class SamplerTransform:
QEffGPTJForCausalLM,
QEffGraniteForCausalLM,
QEffGraniteMoeForCausalLM,
QEffInternDecoderWrapper,
Contributor: Does this mean we are enabling sampling only for the Intern model? Will other VLMs also be supported?

Author: Other VLMs are also supposed to be supported, but currently only InternVL and Qwen 2.5 VL have been tested.
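
For context, enabling this path for one of the tested models would look roughly like the sketch below. This is a minimal illustration, not part of this diff: the checkpoint name and compile arguments are placeholders, and the exact way per-request sampling tensors are forwarded through generate's **kwargs should be verified against VisionLanguageGeneration.

    from QEfficient import QEFFAutoModelForImageTextToText

    # Enable the on-device sampler path added in this PR (max_top_k_ids is optional).
    qaic_config = {"include_sampler": True, "max_top_k_ids": 512}

    model = QEFFAutoModelForImageTextToText.from_pretrained(
        "OpenGVLab/InternVL2_5-1B",  # placeholder checkpoint; only InternVL and Qwen 2.5 VL are tested so far
        kv_offload=True,             # dual-QPC path; the single-QPC path raises NotImplementedError with include_sampler
        qaic_config=qaic_config,
    )
    model.compile(num_cores=16, num_devices=1, prefill_seq_len=128, ctx_len=4096)
    # Per-request sampling tensors (temperatures, top_ks, top_ps, min_ps, repetition/presence
    # penalties, random_numbers) are then passed to the runtime through generate(..., **kwargs).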

QEffLlamaForCausalLM,
QEffMptForCausalLM,
QEffPhi3ForCausalLM,
QEffQwen2ForCausalLM,
QEffQwen_2_5_vl_DecoderWrapper,
}

@classmethod
Expand Down