diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index e1b3d8a1049..1b884be6607 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -7580,6 +7580,7 @@ def test_static_llm_model(self): # noqa: C901 "1024", "--max_context_len", "1024", + "--skip_user_prompt_calibration", ] match self.static_llm_eval_method: @@ -7588,9 +7589,13 @@ def test_static_llm_model(self): # noqa: C901 [ "--eval_methods", "tasks_eval", - "--tasks", + "--eval_tasks", "wikitext", - "--limit", + "--eval_limit", + "1", + "--calib_tasks", + "wikitext", + "--calib_limit", "1", ] ) @@ -7599,25 +7604,33 @@ def test_static_llm_model(self): # noqa: C901 [ "--eval_methods", "tasks_eval", - "--tasks", + "--eval_tasks", + "hellaswag", + "--eval_limit", + "10", + "--calib_tasks", "hellaswag", - "--limit", + "--calib_limit", "10", ] ) case "sqnr": cmds.extend( [ - "--skip_user_prompt_calibration", - "--tasks", + "--eval_tasks", "wikitext", - "--limit", + "--eval_limit", "1", "--eval_methods", "sqnr_eval", + "--calib_tasks", + "wikitext", + "--calib_limit", + "1", ] ) case _: + cmds.remove("--skip_user_prompt_calibration") logging.warning( "No llm eval method chosen. Only generate model output." ) @@ -7883,9 +7896,13 @@ def test_attention_sink(self): "1024", "--eval_methods", "tasks_eval", - "--tasks", + "--eval_tasks", + "wikitext", + "--eval_limit", + "1", + "--calib_tasks", "wikitext", - "--limit", + "--calib_limit", "1", "--use_attention_sink", "4,32", diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index c606b3641b5..7bd1ef10efe 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -123,13 +123,13 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### LLAMA3.2 1B Instruct Default example using hybrid mode. 
```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-1b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-1b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` #### LLAMA3.2 3B Instruct Default example using hybrid mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" 
--calib_tasks wikitext --calib_limit 1 ``` #### Codegen2 @@ -141,73 +141,73 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### Gemma 2B Default example using hybrid mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma-2b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma-2b --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` #### Gemma2 2B Default example using hybrid mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma2-2b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma2-2b --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` #### Gemma3 1B Default example using hybrid mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma3-1b --prompt "I would like to learn python, could you teach me with a simple example?" 
--tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma3-1b --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` #### GLM 1.5B Default example using hybrid mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model glm-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model glm-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` #### Granite3.3 2B Default example using hybrid mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model granite_3_3-2b_instruct --prompt "I would like to learn python, could you teach me with a simple example?" --eval_methods tasks_eval --task hellaswag --limit 10 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model granite_3_3-2b_instruct --prompt "I would like to learn python, could you teach me with a simple example?" --eval_methods tasks_eval --eval_tasks hellaswag --eval_limit 10 --calib_tasks hellaswag --calib_limit 10 ``` #### Phi4-mini-instruct Default example using hybrid mode. 
```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model phi_4_mini --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model phi_4_mini --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` #### QWEN2.5 0.5B Default example using hybrid mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model qwen2_5-0_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model qwen2_5-0_5b --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` #### QWEN2.5 1.5B Default example using hybrid mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --decoder_model qwen2_5-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" 
--tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --decoder_model qwen2_5-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` #### QWEN3 0.6B Default example using hybrid mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model qwen3-0_6b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model qwen3-0_6b --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` #### QWEN3 1.7B Default example using hybrid mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` #### SmolLM2 Default example using hybrid mode. 
```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` #### SmolLM3 Default example using hybrid mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 ``` ## Multimodal Support @@ -472,7 +472,7 @@ The VLM inference pipeline consists of: - KV cache is updated for efficient subsequent token generation -### KV Cache update mechanism +## KV Cache update mechanism We use Smart Mask mechanisms for updating the key-value (KV) cache. 
#### Smart Mask mechanism: @@ -538,23 +538,23 @@ To evaluate the perplexity across all 3 phases, users should provide the `--eval For example, using the Qwen model and 1 wikitext sample as the evaluation task, users can assess all 3 phases perplexity score in a single run by including the appropriate configuration: ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_methods tasks_eval --tasks wikitext --limit 1 --verbose +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_methods tasks_eval --calib_tasks wikitext --calib_limit 1 --eval_tasks wikitext --eval_limit 1 --verbose ``` From the example script above, 1 wikitext sample is used to evaluate all 3 phases. However, there are cases where a user may want to use one sample for quantization calibration and multiple samples for perplexity evaluation. In this case, the process should be split into two runs. In the 1st run, the model is compiled using one sample. In the 2nd run, the user can provide a different configuration for QNN device execution. Example: ```bash -# 1st run to compile with --limit 1 -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" 
--temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_methods tasks_eval --tasks wikitext --limit 1 --compile_only +# 1st run to compile with --calib_limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_methods tasks_eval --calib_tasks wikitext --calib_limit 1 --eval_tasks wikitext --eval_limit 1 --compile_only ``` ```bash -# 2nd run to perform QNN device execution with --limit 3 -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_methods tasks_eval --tasks wikitext --limit 3 --pre_gen_pte ${PATH_TO_ARTIFACT_IN_1ST_RUN} --quant_attrs_path ${PATH_TO_ARTIFACT_IN_1ST_RUN}/kv_llama_qnn_quant_attrs.json +# 2nd run to perform QNN device execution with --eval_limit 3 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_methods tasks_eval --eval_tasks wikitext --eval_limit 3 --pre_gen_pte ${PATH_TO_ARTIFACT_IN_1ST_RUN} --quant_attrs_path ${PATH_TO_ARTIFACT_IN_1ST_RUN}/kv_llama_qnn_quant_attrs.json ``` #### Tasks quantization calibration -If `--tasks ${TASK}` is not provided, the program will use `--prompt ${PROMPT}` as the dataset for quantization calibration. -Regardless of whether `--eval_methods tasks_eval` is provided, as long as `--tasks ${TASK}` is specified, the specified tasks will be used for model quantization calibration instead of the prompt. 
+If `--calib_tasks ${TASK}` is not provided, the program will use `--prompt ${PROMPT}` as the dataset for quantization calibration. +`--calib_tasks` and `--eval_tasks` are independent flags. `--calib_tasks` controls which tasks are used for quantization calibration, while `--eval_tasks` controls which tasks are used for perplexity evaluation. They can be set to different tasks or limits as needed. #### SQNR Evalution To evaluate QNN's output logits against the golden logits from `nn.Module`, users can provide the flag `--sqnr_eval`. Please note that SQNR evaluation will only compare the logits of the user's prompt and will not compare the new tokens generated by the model. @@ -572,7 +572,7 @@ To automatically identify sensitive layers and generate a mixed-precision recipe Example: ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --tasks wikitext --limit 1 --quant_recipe_suggestion --compile_only +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --calib_tasks wikitext --calib_limit 1 --quant_recipe_suggestion --compile_only ``` After the run, pick one of the generated classes from `qwen3-1_7b_suggest_recipe.py` as your new recipe. For a full walkthrough, see [quantization_guidance.md](quantization_guidance.md). 
@@ -601,7 +601,7 @@ This feature supports fluent multi-turn conversations and manages long-context s Example: ```bash # Compile llama pte file and attention sink evictor pte file with sink_size = 4 and batch_eviction_size = 64 -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-1b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 4096 --max_context_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 --use_attention_sink 4,64 --compile_only +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-1b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 4096 --max_context_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --calib_tasks wikitext --calib_limit 1 --use_attention_sink 4,64 --compile_only ``` After running this, the `attention_sink_evictor.pte` file will be generated in the artifacts directory. This file is necessary for using the attention sink feature, as it handles removing the `eviction_batch_size` tokens from the kv cache, retaining the first `sink_size` tokens, and re-rotating the remaining tokens in the kv cache. 
diff --git a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py index ddd9ac68f00..6e04bdca61c 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py @@ -685,9 +685,9 @@ def __init__( is_multimodal=is_multimodal, ) self.inference_speed = None - self.tasks = args.tasks - self.num_fewshot = args.num_fewshot - self.limit = args.limit + self.tasks = args.eval_tasks + self.num_fewshot = args.eval_num_fewshot + self.limit = args.eval_limit adb = self._get_adb() self.eval_wrapper = TaskEval.QnnRunnerEvalWrapper( args=args, diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 92e6c43e642..ea09451a697 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -533,25 +533,48 @@ def _build_parser(): ) parser.add_argument( - "--tasks", + "--eval_tasks", nargs="+", type=str, default=None, - help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", + help="list of lm-eluther tasks to evaluate usage: --eval_tasks task1 task2", ) parser.add_argument( - "--limit", + "--eval_limit", type=int, default=1, help="number of samples to evalulate. If not set, evaluate all samples", ) parser.add_argument( - "--num_fewshot", + "--eval_num_fewshot", type=int, default=None, metavar="N", - help="Number of examples in few-shot context", + help="Number of examples to eval in few-shot context", + ) + + parser.add_argument( + "--calib_tasks", + nargs="+", + type=str, + default=None, + help="list of lm-eluther tasks to calibrate usage: --calib_tasks task1 task2", + ) + + parser.add_argument( + "--calib_limit", + type=int, + default=1, + help="number of samples to calibrate. 
If not set, calibrate all samples", + ) + + parser.add_argument( + "--calib_num_fewshot", + type=int, + default=None, + metavar="N", + help="Number of examples to calibrate in few-shot context", ) parser.add_argument( @@ -598,8 +621,8 @@ def export_llama(args) -> None: raise RuntimeError( "Eval device perplexity is only supported for KV mode. Hybrid mode will only use KV mode when evaluating tasks/sqnr." ) - if TASKS_EVAL in args.eval_methods and args.tasks is None: - raise RuntimeError("Please provide --tasks to eval perplexity") + if TASKS_EVAL in args.eval_methods and args.eval_tasks is None: + raise RuntimeError("Please provide --eval_tasks to eval perplexity") assert ( args.decoder_model in SUPPORTED_LLM_MODELS ), f"Unknown decoder_model: {args.decoder_model}." diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py index 135fabd7f7b..98d8764f711 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py @@ -591,7 +591,7 @@ def _calibrate( is_multimodal = tok_embedding is not None # Determine if task-based calibration is requested - has_task_calibration = self.control_args.tasks is not None + has_task_calibration = self.control_args.calib_tasks is not None # Task-based calibration: Only for text-only LLMs # Multimodal models (VLMs) cannot use task-based evaluation currently. 
@@ -604,9 +604,9 @@ def _calibrate( tokenizer=tokenizer, ar_len=self.meta["get_ar_len"], max_seq_len=self.meta["get_max_context_len"], - tasks=self.control_args.tasks, - tasks_limit=self.control_args.limit, - num_fewshot=self.control_args.num_fewshot, + tasks=self.control_args.calib_tasks, + tasks_limit=self.control_args.calib_limit, + num_fewshot=self.control_args.calib_num_fewshot, use_i64_token=self.control_args.embedding_quantize is not None, event_name=f"{event}_tasks", seq_mse_candidates=self.config.seq_mse_candidates, @@ -828,7 +828,12 @@ def __init__( self.apply_embedding = apply_embedding - def _encoding_override(self, quantized_model, unquantized_model): # noqa: C901 + def _encoding_override( # noqa: C901 + self, + quantized_model, + unquantized_model, + override_kv_cache, + ): pbq_target = { torch.ops.torchao.dequantize_affine, torch.ops.torchao.quantize_affine, @@ -920,51 +925,54 @@ def parameter_override(quantized_node, unquantized_node): for param_quantized, param_unquantized in zip(*[p.keys() for p in parameters]): parameter_override(param_quantized, param_unquantized) - k_input_cache_nodes = [] - v_input_cache_nodes = [] - for node in unquantized_model.graph.nodes: - if node.op != "placeholder": - continue + if override_kv_cache: + k_input_cache_nodes = [] + v_input_cache_nodes = [] + for node in unquantized_model.graph.nodes: + if node.op != "placeholder": + continue - if "args_" in node.name: - args_idx = int(node.name.split("_")[-1]) + if "args_" in node.name: + args_idx = int(node.name.split("_")[-1]) - if args_idx >= self.decode.meta["get_n_layers"]: - v_input_cache_nodes.append(node) - else: - k_input_cache_nodes.append(node) + if args_idx >= self.decode.meta["get_n_layers"]: + v_input_cache_nodes.append(node) + else: + k_input_cache_nodes.append(node) - if not k_input_cache_nodes or not v_input_cache_nodes: - raise RuntimeError( - "KV cache input detection failed. This likely means the model naming " - "does not match expected prefixes." 
- ) + if not k_input_cache_nodes or not v_input_cache_nodes: + raise RuntimeError( + "KV cache input detection failed. This likely means the model naming " + "does not match expected prefixes." + ) - k_output_cache_nodes = [] - v_output_cache_nodes = [] - for node in quantized_model.graph.nodes: - if not is_graph_output(node): - continue - cache_output_node = node.args[0].args[0] - if is_node_src_start_with_name(cache_output_node, kv_cache_prefix="k_"): - k_output_cache_nodes.append(cache_output_node) - elif is_node_src_start_with_name(cache_output_node, kv_cache_prefix="v_"): - v_output_cache_nodes.append(cache_output_node) - - if not k_output_cache_nodes or not v_output_cache_nodes: - raise RuntimeError( - "KV cache detection failed. This likely means the model naming " - "does not match expected prefixes." - ) + k_output_cache_nodes = [] + v_output_cache_nodes = [] + for node in quantized_model.graph.nodes: + if not is_graph_output(node): + continue + cache_output_node = node.args[0].args[0] + if is_node_src_start_with_name(cache_output_node, kv_cache_prefix="k_"): + k_output_cache_nodes.append(cache_output_node) + elif is_node_src_start_with_name( + cache_output_node, kv_cache_prefix="v_" + ): + v_output_cache_nodes.append(cache_output_node) - for input_k_cache_node, output_k_cache_node in zip( - k_input_cache_nodes, k_output_cache_nodes - ): - activation_override(output_k_cache_node, input_k_cache_node) - for input_v_cache_node, output_v_cache_node in zip( - v_input_cache_nodes, v_output_cache_nodes - ): - activation_override(output_v_cache_node, input_v_cache_node) + if not k_output_cache_nodes or not v_output_cache_nodes: + raise RuntimeError( + "KV cache detection failed. This likely means the model naming " + "does not match expected prefixes." 
+ ) + + for input_k_cache_node, output_k_cache_node in zip( + k_input_cache_nodes, k_output_cache_nodes + ): + activation_override(output_k_cache_node, input_k_cache_node) + for input_v_cache_node, output_v_cache_node in zip( + v_input_cache_nodes, v_output_cache_nodes + ): + activation_override(output_v_cache_node, input_v_cache_node) unquantized_model.recompile() @@ -1127,6 +1135,7 @@ def compile(self, request: Request): # noqa: C901 self._encoding_override( quantized_model=self.calibration_prefill.decoder, unquantized_model=self.decode.decoder, + override_kv_cache=True, ) # save logit's quantization attributes to meta @@ -1139,6 +1148,7 @@ def compile(self, request: Request): # noqa: C901 self._encoding_override( quantized_model=self.calibration_prefill.tok_embedding, unquantized_model=self.decode.tok_embedding, + override_kv_cache=False, ) # Saving Decode QDQ Model EP for SQNR evaluation @@ -1157,12 +1167,14 @@ def compile(self, request: Request): # noqa: C901 self._encoding_override( quantized_model=self.decode.decoder, unquantized_model=self.prefill.decoder, + override_kv_cache=True, ) if self.apply_embedding: self._encoding_override( quantized_model=self.decode.tok_embedding, unquantized_model=self.prefill.tok_embedding, + override_kv_cache=False, ) # calibration_prefill is only used for encoding override