Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 86 additions & 109 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
write_io_files,
)
from QEfficient.generation.vlm_generation import VisionLanguageGeneration
from QEfficient.proxy.pytorch_transform import QeffProxyModuleTransform
from QEfficient.transformers.modeling_utils import (
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH,
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH,
Expand Down Expand Up @@ -248,6 +249,10 @@ def __init__(self, model: nn.Module, pooling=None, **kwargs):
**kwargs :
Additional keyword arguments passed to the base class constructor.
"""
if kwargs.pop("enable_proxy", False):
self._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")

super().__init__(model, **kwargs)

# Make Embedding specific transforms like appending pooling
Expand Down Expand Up @@ -1027,36 +1032,7 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))

def __update_prefill_transform(
    self,
    enable: Optional[bool] = True,
    enable_chunking: Optional[bool] = False,
    retain_full_kv: Optional[bool] = False,
):
    """Toggle the prefill-only pytorch transform on ``self.model``.

    When ``enable`` is truthy, applies the chunked or plain prefill-only
    transform depending on ``enable_chunking``; otherwise reverts it,
    keeping the full KV attention path when ``retain_full_kv`` is set.
    """
    if enable:
        transform = PrefillOnlyChunkedTransform if enable_chunking else PrefillOnlyTransform
    else:
        transform = RevertPrefillKeepAttentionTransform if retain_full_kv else RevertPrefillOnlyTransform
    # apply() returns (model, transformed_flag); the flag is unused here.
    self.model, tf = transform.apply(self.model)

def export(
self,
inputs,
output_names,
dynamic_axes,
export_dir=None,
offload_pt_weights=True,
prefill_seq_len: Optional[int] = None,
prefill_only: bool = False,
enable_chunking: bool = False,
**kwargs,
):
def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs):
"""
Exports the language decoder component to ONNX format.

Expand All @@ -1080,18 +1056,6 @@ def export(
str
Path to the generated ONNX graph file for the language decoder.
"""
if prefill_only:
assert prefill_seq_len > 1
if not enable_chunking and self.continuous_batching:
raise NotImplementedError(
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
)
self.hash_params["prefill_only"] = True
self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking)
else:
self.hash_params["prefill_only"] = False
self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False))

return self._export(
inputs,
output_names=output_names,
Expand Down Expand Up @@ -1277,15 +1241,28 @@ def onnx_path(self):
"""
return [self.vision_model.onnx_path, self.lang_model.onnx_path]

@property
def qpc_path(self):
    """
    Get the QPC paths for the vision and language model components.

    Returns
    -------
    Union[List[str], str, None]
        A list containing both QPC paths if both are compiled, or just one if only one is,
        or None if neither is compiled.
    """
    vision_qpc = self.vision_model.qpc_path
    lang_qpc = self.lang_model.qpc_path
    if vision_qpc and lang_qpc:
        return [vision_qpc, lang_qpc]
    # Only one (or neither) component is compiled; fall back to whichever exists.
    return vision_qpc if vision_qpc else lang_qpc

def export(
self,
export_dir: Optional[str] = None,
use_onnx_subfunctions: bool = False,
skip_vision: Optional[bool] = False,
skip_lang: Optional[bool] = False,
prefill_seq_len: Optional[int] = None,
prefill_only: bool = False,
enable_chunking: bool = False,
**kwargs,
) -> str:
"""
Expand Down Expand Up @@ -1339,33 +1316,26 @@ def export(
vocab_size=self.model.language_model.config.vocab_size,
qaic_config=self.lang_model.model.qaic_config,
)
if not skip_vision:
self.vision_model.export(
inputs["vision"],
output_names["vision"],
dynamic_axes["vision"],
export_dir=export_dir,
offload_pt_weights=False,
use_onnx_subfunctions=use_onnx_subfunctions,
)

if prefill_only and prefill_seq_len > 1:
offload_pt_weights = False # to keep weight for decode onnx
else:
offload_pt_weights = kwargs.get("offload_pt_weights", True)
self.vision_model.export(
inputs["vision"],
output_names["vision"],
dynamic_axes["vision"],
export_dir=export_dir,
offload_pt_weights=False,
use_onnx_subfunctions=use_onnx_subfunctions,
)

offload_pt_weights = kwargs.get("offload_pt_weights", True)
self.lang_model.export(
inputs["lang"],
output_names["lang"],
dynamic_axes["lang"],
export_dir=export_dir,
offload_pt_weights=offload_pt_weights,
use_onnx_subfunctions=use_onnx_subfunctions,
)

if not skip_lang:
self.lang_model.export(
inputs["lang"],
output_names["lang"],
dynamic_axes["lang"],
export_dir=export_dir,
offload_pt_weights=offload_pt_weights,
use_onnx_subfunctions=use_onnx_subfunctions,
prefill_only=prefill_only,
enable_chunking=enable_chunking,
prefill_seq_len=prefill_seq_len,
)
return self.onnx_path

def compile(
Expand All @@ -1389,8 +1359,6 @@ def compile(
skip_vision: Optional[bool] = False,
skip_lang: Optional[bool] = False,
use_onnx_subfunctions: bool = False,
prefill_only=None,
enable_chunking=False,
**compiler_options,
) -> str:
"""
Expand Down Expand Up @@ -1509,23 +1477,19 @@ def compile(
if lang_onnx_path:
self.lang_model.onnx_path = lang_onnx_path

if vision_onnx_path is None or lang_onnx_path is None:
if (self.vision_model.onnx_path is None and vision_onnx_path is None) or (
self.lang_model.onnx_path is None and lang_onnx_path is None
):
self.export(
use_onnx_subfunctions=use_onnx_subfunctions,
skip_vision=skip_vision,
skip_lang=skip_lang,
prefill_only=prefill_only,
enable_chunking=enable_chunking,
prefill_seq_len=prefill_seq_len,
)

# TODO: this should be removed once continuous batching is supported for all the models.
compiler_options.pop("continuous_batching", None)
compiler_options.pop("kv_cache_batch_size", None)
compiler_options.pop("full_batch_size", None)
self.qpc_paths = {}
if not skip_vision:
vision_qpc_path = self.vision_model._compile(
self.vision_model._compile(
compile_dir=compile_dir,
compile_only=True,
specializations=specializations["vision"],
Expand All @@ -1538,8 +1502,6 @@ def compile(
use_onnx_subfunctions=use_onnx_subfunctions,
**compiler_options,
)
self.qpc_paths["vision_qpc_path"] = vision_qpc_path

# Custom NPI file options
if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options:
compiler_options["node_precision_info"] = self.model.get_npi_file(self.model.name_or_path)
Expand All @@ -1550,34 +1512,18 @@ def compile(
for output_name in output_names["lang"]:
if output_name.endswith("_RetainedState"):
custom_io_lang[output_name[: -len("_RetainedState")]] = (
"float16"
if ("vision_embeds" in output_name or "deepstack_features" in output_name)
else kv_cache_dtype
"float16" if "vision_embeds" in output_name else kv_cache_dtype
)

# outputs
for output_name in output_names["lang"]:
if output_name.endswith("_RetainedState"):
custom_io_lang[output_name] = (
"float16"
if ("vision_embeds" in output_name or "deepstack_features" in output_name)
else kv_cache_dtype
)
if prefill_only:
specializations = specializations["lang"][:1]
qpc_key = "lang_prefill_qpc_path"
elif prefill_seq_len == 1:
specializations = specializations["lang"][-1:]
qpc_key = "lang_decode_qpc_path"
else:
specializations = specializations["lang"]
qpc_key = "lang_qpc_path"

lang_qpc_path = self.lang_model._compile(
custom_io_lang[output_name] = "float16" if "vision_embeds" in output_name else kv_cache_dtype
self.lang_model._compile(
compile_dir=compile_dir,
compile_only=True,
retained_state=True,
specializations=specializations,
specializations=specializations["lang"],
convert_to_fp16=True,
mxfp6_matmul=mxfp6_matmul,
mdp_ts_num_devices=num_devices,
Expand All @@ -1587,8 +1533,7 @@ def compile(
use_onnx_subfunctions=use_onnx_subfunctions,
**compiler_options,
)
self.qpc_paths.update({qpc_key: lang_qpc_path})
return self.qpc_paths
return self.qpc_path

def generate(
self,
Expand Down Expand Up @@ -1720,6 +1665,7 @@ def kv_offload_generate(
AssertionError
If `generation_len` is not greater than zero.
"""
# breakpoint()
if not self.lang_model.qpc_path:
raise TypeError("Please run compile API for language model first!")

Expand Down Expand Up @@ -1751,6 +1697,7 @@ def kv_offload_generate(
[x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes]
+ [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]]
)
# breakpoint()
input_len = inputs["attention_mask"].sum(1, keepdims=True)
input_ids_length = inputs["input_ids"].shape[1]
num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float
Expand Down Expand Up @@ -1787,6 +1734,14 @@ def kv_offload_generate(

vision_inputs_fp16 = {"pixel_values", "image_masks"}
vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs})
pixel_values_shape = list(vision_inputs["pixel_values"].shape)
idx = next(i for i, inner in enumerate(vision_session.allowed_shapes) if (2, pixel_values_shape) in inner)

biffer_set = {
"vision_embeds": np.zeros(vision_session.allowed_shapes[idx][2][1], dtype=np.float16),
"image_grid_thw": np.zeros(vision_session.allowed_shapes[idx][0][1], dtype=np.int64),
}
vision_session.set_buffers(biffer_set)

vision_start = perf_counter()

Expand All @@ -1796,6 +1751,7 @@ def kv_offload_generate(
vision_end = perf_counter()

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
# breakpoint()
if "position_ids" in inputs:
lang_inputs["position_ids"] = inputs["position_ids"]
lang_inputs.pop("attention_mask")
Expand All @@ -1807,10 +1763,21 @@ def kv_offload_generate(
not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama"
if not_mllama:
lang_inputs["image_idx"] = np.array([[0]])
# breakpoint()
if self.vision_model.qpc_path:
vision_session.deactivate()
lang_session.activate()

vision_outputs["vision_embeds"] = np.pad(
vision_outputs["vision_embeds"],
pad_width=(
(0, 0),
(0, lang_session.allowed_shapes[0][1][1][1] - vision_session.allowed_shapes[idx][2][1][1]),
(0, 0),
), # pad axis=1 only
mode="constant",
constant_values=0,
)
lang_session.set_buffers(vision_outputs)

if self.comp_ctx_lengths_prefill is not None:
Expand All @@ -1821,6 +1788,7 @@ def kv_offload_generate(
lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id]

lang_start = perf_counter()
# breakpoint()
# Run prefill
chunk_inputs = lang_inputs.copy()
for i in range(num_chunks):
Expand Down Expand Up @@ -1852,6 +1820,7 @@ def kv_offload_generate(
)
if not_mllama:
lang_session.skip_buffers(vision_outputs.keys())
# breakpoint()
# Get first token
lang_inputs["input_ids"] = outputs["logits"].argmax(2)
lang_inputs["position_ids"] = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1
Expand Down Expand Up @@ -2686,7 +2655,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):

_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __update_prefill_transform(
def prefill(
self,
enable: Optional[bool] = True,
enable_chunking: Optional[bool] = False,
Expand Down Expand Up @@ -2745,6 +2714,10 @@ def __init__(
raise TypeError(f"Required pytorch module for CausalLM or LMHeadModel, got {model_class_name}")
_configure_proxy_for_model(self, kwargs.pop("enable_proxy", False))

if kwargs.pop("enable_proxy", False):
self._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")

# TODO: remove from version 1.20
if kwargs.pop("full_batch_size", None):
continuous_batching = True
Expand Down Expand Up @@ -2981,7 +2954,7 @@ def export(
raise NotImplementedError(
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
)
self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking)
self.prefill(enable=True, enable_chunking=enable_chunking)
self.hash_params.pop("retain_full_kv", None)
seq_len = self.get_seq_len_and_handle_specialized_prefill_model(
prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking
Expand All @@ -2992,7 +2965,7 @@ def export(
else seq_len
)
else:
self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False))
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
self.hash_params.pop("prefill_only", None)
self.hash_params.pop("NUM_Q_BLOCKS", None)
self.hash_params.pop("NUM_FFN_BLOCKS", None)
Expand Down Expand Up @@ -3995,6 +3968,10 @@ class QEFFAutoModelForCTC(QEFFTransformersBase):
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, **kwargs):
if kwargs.pop("enable_proxy", False):
self._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")

super().__init__(model, **kwargs)
self.model.base_model.config.use_cache = True

Expand Down
Loading
Loading