diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index b37fdc74a..c603a60d0 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -36,6 +36,7 @@ write_io_files, ) from QEfficient.utils import LRUCache +from QEfficient.utils.constants import Constants from QEfficient.utils.logging_utils import logger @@ -313,6 +314,13 @@ def _execute_chunked_prefill( prefill_ccl_id = 0 lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + if self.include_sampler: + for op in Constants.SAMPLER_OPS: + if decode_batch_id is not None: + lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] + else: + lang_inputs[op] = self.sampling_params[op] + for i in range(num_chunks): input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len] position_ids_slice = lang_inputs["position_ids"][ @@ -338,6 +346,11 @@ def _execute_chunked_prefill( chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"] + if self.include_sampler: + chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] + for op in Constants.SAMPLER_OPS: + chunk_inputs[op] = lang_inputs[op] + outputs = self._session.run(chunk_inputs) if "image_idx_output" in outputs: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 16a809c96..5f3a2200a 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -8,7 +8,7 @@ import warnings from pathlib import Path from time import perf_counter -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union import numpy as np import torch @@ -62,6 +62,7 @@ ) from QEfficient.utils.check_ccl_specializations import process_ccl_specializations from QEfficient.utils.logging_utils import logger +from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs class QEFFTransformersBase(QEFFBaseModel): @@ -719,7 +720,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model, qaic_config, **kwargs): + def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): """ Initializes the language decoder component for multimodal models. @@ -733,7 +734,7 @@ def __init__(self, model, qaic_config, **kwargs): **kwargs : Additional keyword arguments passed to the base class constructor. """ - super().__init__(model, **kwargs) + super().__init__(model, qaic_config=qaic_config, **kwargs) self.model = model.get_qeff_language_decoder() self.model.qaic_config = qaic_config self.hash_params["qeff_auto_class"] = self.__class__.__name__ @@ -879,16 +880,16 @@ def __init__( ---------- model : nn.Module The full HuggingFace multimodal model. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. **kwargs : - Additional keyword arguments. `full_batch_size` is not supported here. - - Raises - ------ - NotImplementedError - If `full_batch_size` is provided. + Additional keyword arguments. """ if kwargs.pop("full_batch_size", None): - raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + continuous_batching = True + warnings.warn( + "full_batch_size argument is deprecated. 
Use continuous_batching=True instead.", DeprecationWarning, 2 + ) self.model = model self.config = model.config @@ -900,6 +901,11 @@ def __init__( self.ccl_enabled = qaic_config.get("ccl_enabled", False) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.input_shapes, self.output_names = None, None + # ---Sampling--- + # Note: SamplerTransform should be applied after all other transforms + # are done. The role of the sampler is to just add nodes at the output of the + # previous transform function. + self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs): @@ -1010,6 +1016,19 @@ def export( kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode ) output_names = self.model.get_output_names(kv_offload=True) + if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get( + "include_sampler", False + ): + logits_index = output_names["lang"].index("logits") + output_names["lang"][logits_index] = "next_tokens" + inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs( + example_inputs=inputs["lang"], + output_names=output_names["lang"], + dynamic_axes=dynamic_axes["lang"], + continuous_batching=self.continuous_batching, + vocab_size=self.model.language_model.config.vocab_size, + qaic_config=self.lang_model.model.qaic_config, + ) self.vision_model.export( inputs["vision"], @@ -1242,6 +1261,7 @@ def generate( generation_len: Optional[int] = None, image_height: Optional[int] = None, image_width: Optional[int] = None, + **kwargs, ) -> Union[torch.Tensor, np.ndarray]: """ Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards. @@ -1302,6 +1322,7 @@ def generate( comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, image_height=image_height, image_width=image_width, + **kwargs, ) # Call generate method @@ -1584,10 +1605,15 @@ def __init__( Raises ------ NotImplementedError - If `full_batch_size` is provided. + If `full_batch_size` is provided or `include_sampler` is True. """ if kwargs.pop("full_batch_size", None): + warnings.warn( + "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 + ) raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if qaic_config is not None and qaic_config.pop("include_sampler", False): + raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") super().__init__(model, **kwargs) self.model.qaic_config = qaic_config @@ -2204,6 +2230,8 @@ def from_pretrained( If True, uses the dual QPC approach (vision encoder KV offloaded). If False, uses the single QPC approach (entire model in one QPC). If None, the default behavior of the internal classes is used (typically dual QPC). + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. **kwargs : Additional arguments passed to HuggingFace's ``from_pretrained``. 
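The export and generate hooks added above route on-device sampling through the shared helper and forward extra keyword arguments to the text-generation backend, so the caller supplies the per-request sampling tensors from the host. Below is a minimal host-side sketch of such a sampling_params dict; the batch size, max_top_k_ids value, and concrete parameter values are illustrative assumptions taken from the updated examples/performance/on_device_sampling.py, while the key names and the (batch, max_top_k_ids) shape of random_numbers mirror get_sampling_inputs_and_outputs.

import numpy as np

# Illustrative sketch only: bs, max_top_k_ids, and the parameter values are assumptions;
# the key names follow the sampler inputs added by get_sampling_inputs_and_outputs.
bs, max_top_k_ids = 2, 512
np.random.seed(26)  # the updated example seeds NumPy with --random-number

sampling_params = {
    "repetition_penalties": np.full((bs, 1), 1.9, dtype=np.float32),
    "presence_penalties": np.full((bs, 1), 0.8, dtype=np.float32),
    "temperatures": np.full((bs, 1), 0.67, dtype=np.float32),
    "top_ks": np.full((bs, 1), 54, dtype=np.int32),
    "top_ps": np.full((bs, 1), 0.89, dtype=np.float32),
    "min_ps": np.full((bs, 1), 0.6, dtype=np.float32),
    # One row of uniform [0, 1) draws per request, tiled across max_top_k_ids columns,
    # replacing the previous single scalar per request.
    "random_numbers": np.tile(
        np.random.uniform(low=0.0, high=1.0, size=max_top_k_ids), (bs, 1)
    ).astype(np.float32),
}
# Passed as generate(..., sampling_params=sampling_params); during chunked prefill the
# runner indexes each array by the decode batch slot before feeding the session.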
@@ -2564,10 +2592,13 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"} if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): - example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs( + example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs( example_inputs=example_inputs, output_names=output_names, dynamic_axes=dynamic_axes, + continuous_batching=self.continuous_batching, + vocab_size=self.model.config.vocab_size, + qaic_config=self.model.qaic_config, ) return self._export( @@ -2579,85 +2610,6 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = offload_pt_weights=kwargs.get("offload_pt_weights", True), ) - def get_sampling_inputs_and_outputs( - self, - example_inputs: Dict[str, torch.Tensor], - output_names: List[str], - dynamic_axes: Dict[str, Dict[int, str]], - ): - """ - Updates the example inputs, output names, and dynamic axes to include - parameters relevant for on-device sampling during ONNX export. - - Parameters - ---------- - example_inputs : Dict[str, torch.Tensor] - Current dictionary of example inputs. - output_names : List[str] - Current list of output names. - dynamic_axes : Dict[str, Dict[int, str]] - Current dictionary of dynamic axes configurations. - - Returns - ------- - Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] - Updated example inputs, output names, and dynamic axes including - sampling-related parameters. - """ - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - - example_inputs["last_accepted_output_tokens"] = torch.zeros( - (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 - ) - dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} - - example_inputs["past_repetition_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_repetition_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_repetition_penalty_buffer_RetainedState") - - example_inputs["repetition_penalties"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES - ) - dynamic_axes["repetition_penalties"] = {0: "batch_size"} - - example_inputs["past_presence_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_presence_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_presence_penalty_buffer_RetainedState") - - example_inputs["presence_penalties"] = ( - torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES - ) - dynamic_axes["presence_penalties"] = {0: "batch_size"} - - example_inputs["temperatures"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES - ) - dynamic_axes["temperatures"] = {0: "batch_size"} - - max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) - example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) - dynamic_axes["top_ks"] = {0: "batch_size"} - - example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * 
constants.ONNX_EXPORT_EXAMPLE_TOP_PS - dynamic_axes["top_ps"] = {0: "batch_size"} - - example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS - dynamic_axes["min_ps"] = {0: "batch_size"} - - example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) - dynamic_axes["random_numbers"] = {0: "batch_size"} - - return example_inputs, output_names, dynamic_axes - def build_prefill_specialization( self, prefill_seq_len: int = 32, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 07b9fe7e1..460f71f1a 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -296,6 +296,7 @@ QEffGrok1MultiHeadAttention, ) from QEfficient.transformers.models.internvl.modeling_internvl import ( + QEffInternDecoderWrapper, QEffInternVisionEmbeddings, QEffInternVLModel, ) @@ -399,6 +400,7 @@ QEffQwen2_5_VLModel, QEffQwen2_5_VLTextModel, QEffQwen2_5_VLVisionAttention, + QEffQwen_2_5_vl_DecoderWrapper, QEffQwen_2_5_vl_ForConditionalGeneration, ) from QEfficient.transformers.models.qwen3.modeling_qwen3 import ( @@ -719,10 +721,12 @@ class SamplerTransform: QEffGPTJForCausalLM, QEffGraniteForCausalLM, QEffGraniteMoeForCausalLM, + QEffInternDecoderWrapper, QEffLlamaForCausalLM, QEffMptForCausalLM, QEffPhi3ForCausalLM, QEffQwen2ForCausalLM, + QEffQwen_2_5_vl_DecoderWrapper, } @classmethod diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 96846e712..f7473cbd0 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput): probs: torch.FloatTensor = None next_tokens: torch.IntTensor = None + vision_embeds: Optional[torch.FloatTensor] = None # For VLMs + image_idx: Optional[torch.IntTensor] = None # for VLMs past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_repetition_penalty_buffer: Optional[torch.Tensor] = None past_presence_penalty_buffer: Optional[torch.Tensor] = None @@ -103,6 +105,7 @@ def sampler_forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -122,6 +125,8 @@ def sampler_forward( top_ps: Optional[torch.Tensor] = None, min_ps: Optional[torch.Tensor] = None, random_numbers: Optional[torch.Tensor] = None, + vision_embeds: Optional[torch.Tensor] = None, + image_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, SamplerOutput]: r""" Perform the sampling of next tokens on the QAIC device (instead of the host) @@ -170,20 +175,37 @@ def sampler_forward( Sampling parameter that represents the random seeds to use for random sampling. Must be in [-1, 1]. 
""" - - outputs = self.old_forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - batch_index=batch_index, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) + if vision_embeds is not None: + forward_kwargs = dict( + input_ids=input_ids, + vision_embeds=vision_embeds, + position_ids=position_ids, + image_idx=image_idx, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + ) + if batch_index is not None: + forward_kwargs["batch_index"] = batch_index + + logits, vision_embeds, image_idx, past_key_values = self.old_forward(**forward_kwargs) + outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values) + if position_ids.dim() == 3: # For models using m-rope + position_ids = position_ids[0] + else: + outputs = self.old_forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) logits = outputs.get("logits", None) assert logits is not None, f"{self.model.__class__.__name__} does not return logits." @@ -230,7 +252,9 @@ def sampler_forward( return SamplerOutput( probs=None, next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), past_repetition_penalty_buffer=past_repetition_penalty_buffer, past_presence_penalty_buffer=past_presence_penalty_buffer, ) @@ -300,9 +324,8 @@ def sampler_forward( ) # (batch_size, spec_length, vocab_size) # Random Sampling - topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, max_top_k_ids) gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick - y = topk_probs_asc + gumbel_noise + y = topk_values_asc + gumbel_noise # (batch_size * spec_length, max_top_k_ids) random_samples_indices = torch.argmax(y, dim=1, keepdim=True) random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1) @@ -314,7 +337,9 @@ def sampler_forward( return SamplerOutput( probs=probs, next_tokens=next_tokens, # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), past_repetition_penalty_buffer=past_repetition_penalty_buffer, past_presence_penalty_buffer=past_presence_penalty_buffer, ) diff --git a/QEfficient/utils/sampler_utils.py b/QEfficient/utils/sampler_utils.py index 6fb1b326f..0460eeb3a 100644 --- a/QEfficient/utils/sampler_utils.py +++ b/QEfficient/utils/sampler_utils.py @@ -5,8 +5,11 @@ # # ----------------------------------------------------------------------------- -from typing import Optional, Set +from typing import Dict, List, Optional, Set +import torch + +from QEfficient.utils import constants from QEfficient.utils.constants 
import Constants from QEfficient.utils.logging_utils import logger @@ -56,3 +59,89 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[ ) return session_includes_sampler + + +def get_sampling_inputs_and_outputs( + example_inputs: Dict[str, torch.Tensor], + output_names: List[str], + dynamic_axes: Dict[str, Dict[int, str]], + continuous_batching: bool, + vocab_size: int, + qaic_config: Dict, +): + """ + Updates the example inputs, output names, and dynamic axes to include + parameters relevant for on-device sampling during ONNX export. + + Parameters + ---------- + example_inputs : Dict[str, torch.Tensor] + Current dictionary of example inputs. + output_names : List[str] + Current list of output names. + dynamic_axes : Dict[str, Dict[int, str]] + Current dictionary of dynamic axes configurations. + continuous_batching : bool + Whether this model will be used for continuous batching in the future. + vocab_size: int + Vocabulary size for this model. + qaic_config : Dict + QAIC config dictionary. + + Returns + ------- + Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] + Updated example inputs, output names, and dynamic axes including + sampling-related parameters. + """ + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + + example_inputs["last_accepted_output_tokens"] = torch.zeros( + (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 + ) + dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} + + example_inputs["past_repetition_penalty_buffer"] = torch.zeros( + (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool + ) + dynamic_axes["past_repetition_penalty_buffer"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + } + output_names.append("past_repetition_penalty_buffer_RetainedState") + + example_inputs["repetition_penalties"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES + ) + dynamic_axes["repetition_penalties"] = {0: "batch_size"} + + example_inputs["past_presence_penalty_buffer"] = torch.zeros( + (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool + ) + dynamic_axes["past_presence_penalty_buffer"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + } + output_names.append("past_presence_penalty_buffer_RetainedState") + + example_inputs["presence_penalties"] = ( + torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES + ) + dynamic_axes["presence_penalties"] = {0: "batch_size"} + + example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES + dynamic_axes["temperatures"] = {0: "batch_size"} + + max_top_k_ids = qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) + example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) + dynamic_axes["top_ks"] = {0: "batch_size"} + + example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS + dynamic_axes["top_ps"] = {0: "batch_size"} + + example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS + dynamic_axes["min_ps"] = {0: "batch_size"} + + example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) + dynamic_axes["random_numbers"] = {0: "batch_size"} + + return example_inputs, output_names, dynamic_axes diff --git 
a/examples/performance/on_device_sampling.py b/examples/performance/on_device_sampling.py index 6cc72b715..b4e1f4e27 100644 --- a/examples/performance/on_device_sampling.py +++ b/examples/performance/on_device_sampling.py @@ -28,6 +28,7 @@ def main(args, **kwargs): if include_sampler is not None: return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true" max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) + np.random.seed(int(args.random_number)) sampling_params = { "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), @@ -36,7 +37,9 @@ def main(args, **kwargs): "top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1), "top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1), "min_ps": np.array(args.min_p, dtype=np.float32).repeat(bs).reshape(-1, 1), - "random_numbers": np.array(args.random_number, dtype=np.float32).repeat(bs).reshape(-1, 1), + "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=max_top_k_ids), (bs, 1)).astype( + np.float32 + ), } qaic_config = { k: v @@ -110,10 +113,10 @@ def main(args, **kwargs): --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ - --top-k 54720 \ + --top-k 54 \ --top-p 0.89 \ --min-p 0.6 \ - --random-number 0.26 + --random-number 26 2. For non-continuous batching: python3.10 examples/on_device_sampling.py \ @@ -130,10 +133,10 @@ def main(args, **kwargs): --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ - --top-k 54720 \ + --top-k 54 \ --top-p 0.89 \ --min-p 0.6 \ - --random-number 0.26 + --random-number 26 """ parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling") diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 9335e1d91..99eb98a73 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -5,12 +5,13 @@ # # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Union import numpy as np import pytest +from transformers import AutoProcessor -from QEfficient import QEFFAutoModelForCausalLM +from QEfficient import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import load_hf_tokenizer from QEfficient.utils.constants import Constants @@ -24,6 +25,20 @@ 20, # generation_len 2, # full_batch_size 1, # spec_length + False, # is_vlm + ), + pytest.param( + "Qwen/Qwen2.5-VL-3B-Instruct", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm ), ] greedy_sampling_configs = [ @@ -35,6 +50,20 @@ 20, # generation_len 4, # full_batch_size 1, # spec_length + False, # is_vlm + ), + pytest.param( + "Qwen/Qwen2.5-VL-3B-Instruct", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm ), ] random_sampling_configs = [ @@ -46,23 +75,38 @@ 20, # generation_len 4, # 
full_batch_size 1, # spec_length + False, # is_vlm ), + # pytest.param( + # "Qwen/Qwen2.5-VL-3B-Instruct", # model + # ( + # ["https://picsum.photos/id/237/536/354"] * 2, + # ["Can you describe the image in detail."] * 2, + # ), # images and prompts + # 128, # prefill_seq_len + # 4096, # ctx_len + # 20, # generation_len + # 2, # full_batch_size + # None, # spec_length + # True, # is_vlm + # ), ] @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", sampler_transform_configs, ) def test_sampler_transform( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ Test if `SamplerTransform` adds nodes at the output of a `QEffForCausalLM model` to enable the @@ -70,45 +114,56 @@ def test_sampler_transform( next tokens and/or probability distributions. """ # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + else: + additional_configs["num_hidden_layers"] = 2 + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=2, qaic_config={ "include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=2, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) - model_w_sampler_qpc_path: str = model_w_sampler.compile( + model_w_sampler_qpc_path = model_w_sampler.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) - model_wo_sampler_qpc_path: str = model_wo_sampler.compile( + model_wo_sampler_qpc_path = model_wo_sampler.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) + if is_vlm: + model_w_sampler_qpc_path = model_w_sampler_qpc_path[1] + model_wo_sampler_qpc_path = model_wo_sampler_qpc_path[1] # Init qaic session model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path) @@ -139,40 +194,54 @@ def test_sampler_transform( @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", greedy_sampling_configs, ) def test_greedy_sampling( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ Test greedy sampling with QPC compiled with and without On Device Sampling. 
""" # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + additional_params = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + assert isinstance(prompts, tuple) + additional_params["images"] = prompts[0] + additional_params["processor"] = AutoProcessor.from_pretrained(model) + prompts = prompts[1] + else: + additional_configs["num_hidden_layers"] = 4 + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=4, qaic_config={ "include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=4, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) model_w_sampler.compile( prefill_seq_len=prefill_seq_len, @@ -180,7 +249,7 @@ def test_greedy_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -190,7 +259,7 @@ def test_greedy_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -211,8 +280,9 @@ def test_greedy_sampling( "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32), }, + **additional_params, ) model_wo_sampler_exec_info = model_wo_sampler.generate( tokenizer=tokenizer, @@ -221,6 +291,7 @@ def test_greedy_sampling( include_sampler=False, return_pdfs=False, sampling_params=None, + **additional_params, ) # Compare generated texts and ids @@ -233,25 +304,37 @@ def test_greedy_sampling( @pytest.mark.on_qaic -@pytest.mark.skip @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", random_sampling_configs, ) def test_random_sampling( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ Test random sampling with QPC compiled with and without On Device Sampling. 
""" # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + additional_params = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + assert isinstance(prompts, tuple) + additional_params["images"] = prompts[0] + additional_params["processor"] = AutoProcessor.from_pretrained(model) + prompts = prompts[1] + else: + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ @@ -259,14 +342,16 @@ def test_random_sampling( "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) model_w_sampler.compile( prefill_seq_len=prefill_seq_len, @@ -274,7 +359,7 @@ def test_random_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -284,13 +369,14 @@ def test_random_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) # Generate texts from prompts tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) + np.random.seed(0) model_w_sampler_exec_info = model_w_sampler.generate( tokenizer=tokenizer, prompts=prompts, @@ -301,12 +387,15 @@ def test_random_sampling( "repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.array(0.26, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=512), (full_batch_size, 1)).astype( + np.float32 + ), }, + **additional_params, ) model_wo_sampler_exec_info = model_wo_sampler.generate( tokenizer=tokenizer, @@ -315,36 +404,37 @@ def test_random_sampling( include_sampler=False, return_pdfs=False, sampling_params=None, + **additional_params, ) # Compare generated texts golden_texts = { - "w_sampler": "Raymond and my favorite color, alongside reds or purples (I can’t have them both", + "w_sampler": "Aiden and I am a freelance writer who loves to explore the world. With over", "wo_sampler": "John Smith and I am a software engineer. 
I have been working in the industry for the past ", } golden_ids = { "w_sampler": [ [ - 21380, + 319, + 3615, 322, - 590, - 25448, - 2927, - 29892, - 19963, - 2654, - 29879, - 470, - 3708, - 2701, - 313, - 29902, - 508, - 30010, - 29873, - 505, - 963, - 1716, + 306, + 626, + 263, + 3005, + 295, + 749, + 9227, + 1058, + 12355, + 267, + 304, + 26987, + 278, + 3186, + 29889, + 2973, + 975, ] ], "wo_sampler": [
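One behavioral change earlier in this diff is worth a standalone illustration: in QEfficient/transformers/sampler/sampler.py the random-sampling path now adds the Gumbel noise to the ascending top-k values themselves instead of to their softmax, which is the standard Gumbel-Max formulation (argmax of log-scores plus Gumbel(0, 1) noise yields a sample from the corresponding softmax distribution), and it is consistent with random_numbers becoming a (batch_size, max_top_k_ids) tensor of uniform draws. The sketch below reproduces only that step; the shapes and tensor contents are illustrative assumptions, and only the formula mirrors the hunk.

import torch

# Illustrative shapes; only the sampling formula mirrors the sampler.py change.
batch_size, spec_length, max_top_k_ids = 2, 1, 8
# Temperature-scaled, top-k/top-p/min-p masked scores in ascending order, and the
# vocabulary ids they correspond to (random stand-ins here).
topk_values_asc = torch.randn(batch_size * spec_length, max_top_k_ids)
topk_indices_asc = torch.randint(0, 32000, (batch_size * spec_length, max_top_k_ids))

# Uniform [0, 1) draws, one row per request, converted to Gumbel(0, 1) noise.
random_numbers = torch.rand(batch_size, max_top_k_ids)
gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1)))

# Gumbel-Max trick: argmax over (scores + noise) samples from softmax(scores);
# adding the noise to softmax probabilities, as the code did before, does not.
y = topk_values_asc + gumbel_noise
random_samples_indices = torch.argmax(y, dim=1, keepdim=True)
random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices)  # sampled token ids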