Added fp16/bf16 based export and compile support for VLMs #819
Conversation
  retained_state=True,
  specializations=specializations["lang"],
- convert_to_fp16=True,
+ convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
nit: why is this condition required? Is it for AI200? @quic-rishinr
This condition is required in case the user wants bf16 support, which will come with AI200. I have updated the code so that convert_to_fp16 is True when the passed dtype is either fp16 or fp32.
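To make the discussion concrete, here is a minimal sketch of the decision described in the reply. The map contents and helper name are illustrative assumptions, not the actual QEfficient source:

```python
# Illustrative sketch only; DTYPE_TO_STRING_MAP contents are assumed.
DTYPE_TO_STRING_MAP = {
    "torch.float16": "float16",
    "torch.bfloat16": "bfloat16",
    "torch.float32": "float32",
}

def should_convert_to_fp16(needed_dtype: str) -> bool:
    # Convert fp16 and fp32 models to fp16; keep bf16 untouched so that
    # bf16 support (planned for AI200) is not silently downcast.
    return DTYPE_TO_STRING_MAP[needed_dtype] in ("float16", "float32")

print(should_convert_to_fp16("torch.bfloat16"))  # False
print(should_convert_to_fp16("torch.float32"))   # True
```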
QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
(outdated; comment resolved)
  self.model, transformed = transform.apply(self.model)
  any_transformed = any_transformed or transformed

+ self._normalize_torch_dtype()
nit: does this also take care of embedding and ASR models?
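As a rough illustration of what a dtype-normalization step like `_normalize_torch_dtype()` might do for models that never set `torch_dtype` (such as some embedding and ASR configs), here is a hypothetical, framework-free sketch; the real QEfficient implementation may differ:

```python
# Hypothetical sketch; names and fallback policy are assumptions.
SUPPORTED_EXPORT_DTYPES = {"float16", "bfloat16", "float32"}

def normalize_torch_dtype(config_dtype):
    """Map a config's torch_dtype to a supported export dtype.

    Falls back to float32 when the config has no dtype or an
    unsupported one (e.g. float64), which would also cover embedding
    and ASR models that leave torch_dtype unset.
    """
    name = str(config_dtype).replace("torch.", "") if config_dtype else None
    return name if name in SUPPORTED_EXPORT_DTYPES else "float32"

print(normalize_torch_dtype("torch.bfloat16"))  # bfloat16
print(normalize_torch_dtype(None))              # float32
```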
  "allenai/Molmo-7B-D-0924",
  "meta-llama/Llama-3.2-11B-Vision-Instruct",
  ]:
+     pytest.skip("Test skipped for this model due to some issues.")
nit: with our dummy configs, can we run all sample LM models through this test quickly?
Force-pushed: d84668e to c658e0f
  torch.nn.functional.pad(inputs["input_values"], (0, self.seq_len - input_ids_len), "constant", 0)
  )
  needed_dtype = self.model.config.torch_dtype
+ input_values = input_values.astype(CUSTOM_IO_DTYPE_MAP[needed_dtype])
Since the inputs are in NumPy format, shouldn't we be using Torch_to_numpy_map here?
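For context, a torch-to-numpy dtype map has to special-case bfloat16, since NumPy has no native bfloat16 dtype. The sketch below is keyed by dtype name so it runs without torch installed; the map name and contents are assumptions, not the library's actual table:

```python
import numpy as np

# Sketch of a torch-dtype -> numpy-dtype map as suggested in the review.
# NumPy has no native bfloat16, so it is mapped to float16 here (the same
# kind of workaround the AI200 inference patch in this PR applies).
TORCH_TO_NUMPY_MAP = {
    "torch.float16": np.float16,
    "torch.bfloat16": np.float16,  # nearest numpy-representable dtype
    "torch.float32": np.float32,
}

input_values = np.zeros(4, dtype=np.float64)
input_values = input_values.astype(TORCH_TO_NUMPY_MAP["torch.float16"])
print(input_values.dtype)  # float16
```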
  router_logits = self.gate(hidden_states)

- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights = F.softmax(router_logits, dim=1, dtype=self.gate.weight.dtype)
Can we update the softmax to run in the original precision?
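The precision question matters because softmax in half precision accumulates rounding error in the exponentials and the normalizing sum. A small NumPy demo (illustrative only, not the QEfficient kernel) shows the fp16 result deviating slightly from an fp32 reference:

```python
import numpy as np

def softmax(x, dtype):
    # Numerically stable softmax computed entirely in the given dtype.
    x = x.astype(dtype)
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.normal(scale=4.0, size=(8, 64)).astype(np.float32)

ref = softmax(logits, np.float32)
half = softmax(logits, np.float16).astype(np.float32)
print(float(np.abs(ref - half).max()))  # small but nonzero rounding error
```

For router logits in an MoE gate this error shifts the routing weights, which is why the review asks whether the original precision is safe here.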
  def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float):
      hidden_states = hidden_states.to(torch.float32)
- div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32))
+ div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1]))
RMSNorm would create issues if we set it to the default precision, similar to softmax. We should verify whether it is causing the issue; if it is, we should revert this change.
I tested with the softmax changed to the default precision, and it ran successfully.
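For reference, the usual reason RMSNorm is kept in float32 is that the mean-of-squares reduction loses accuracy in fp16. A generic NumPy sketch of the standard fp32-accumulation pattern (not the exact custom op in the diff above) looks like this:

```python
import numpy as np

def rms_norm(hidden, weight, eps=1e-6):
    # Cast to float32 for the reduction, as the kernel in the diff does,
    # then return the result in the input's original dtype.
    orig_dtype = hidden.dtype
    h = hidden.astype(np.float32)
    variance = (h * h).mean(axis=-1, keepdims=True)
    h = h / np.sqrt(variance + eps)
    return (h * weight).astype(orig_dtype)

x = np.arange(8, dtype=np.float16) / 8.0
w = np.ones(8, dtype=np.float16)
print(rms_norm(x, w))
```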
Force-pushed: 1e15c1a to 9df3b31
Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
… and inference. Almost all LLMs can now be compiled and run for inference in fp16. The test_causal_lm_models script uses the following notation to record how each model was tested:
# means the model wasn't tested due to its size; it is unclear whether it would run through or hit an accuracy mismatch.
## means the fp16 outputs match and everything worked fine.
### means outputs are produced but don't match the HF tokens properly.
#### means the model is quantized and additional effort is needed to enable it.
These commits cover almost all currently supported LLMs.
Enabled CI tests for fp16-based LMs, embedding, and sequence-classification models. Modified the CI config for LLM tests. Embedding models show a high MAD (~0.015) for fp16-exported models. Certain CausalLMs produce a token mismatch after a few tokens in the fp16 setup. The Whisper model has a clip-operator issue for fp16-exported models, so it is not enabled yet.
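The ~0.015 figure presumably refers to the maximum absolute difference between reference (e.g. HF fp32) outputs and the fp16-exported model's outputs; the metric name and exact definition are an assumption here. A minimal sketch of that check:

```python
import numpy as np

def mad(reference, candidate):
    # "MAD" assumed to mean maximum absolute difference between the
    # reference outputs and the exported model's outputs.
    return float(np.abs(np.asarray(reference, dtype=np.float64)
                        - np.asarray(candidate, dtype=np.float64)).max())

ref = np.array([0.10, 0.20, 0.30])   # hypothetical reference embeddings
out = np.array([0.11, 0.195, 0.315]) # hypothetical fp16-export outputs
print(round(mad(ref, out), 3))  # 0.015
```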
Added a try/except around the post-load dtype casting of model weights, since GPTQ-type models don't allow such conversion. Fixed a few dtype-related issues for audio-based models.
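A guarded cast of the kind this commit describes might look like the sketch below; the stub class and fallback behavior are illustrative assumptions, not the actual QEfficient code:

```python
import warnings

class QuantizedModelStub:
    """Stands in for a GPTQ-style model that refuses dtype conversion."""
    def to(self, dtype):
        raise TypeError(f"quantized weights cannot be cast to {dtype}")

def safe_cast(model, dtype):
    # Guarded cast per the commit message: GPTQ-type models raise when
    # their packed weights are cast, so keep the loaded dtype instead.
    try:
        return model.to(dtype)
    except (TypeError, RuntimeError):
        warnings.warn(f"model does not support casting to {dtype}; keeping loaded dtype")
        return model

model = QuantizedModelStub()
assert safe_cast(model, "float16") is model  # original model kept
```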
…both in bfloat16. Added a patch in cloud infer to map the bfloat16 (or 11) key type to np.float16 for AI200 inference.
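Since NumPy has no bfloat16 dtype, a bfloat16 buffer can be handled by widening its bits into the upper half of a float32 and then downcasting. The function below illustrates that kind of bfloat16-to-np.float16 mapping; it is a sketch of the general technique, not the actual patch:

```python
import numpy as np

def bf16_to_fp16(raw_bf16: np.ndarray) -> np.ndarray:
    """Convert a uint16 view of bfloat16 data to np.float16.

    bfloat16 is the top 16 bits of a float32, so shifting the raw bits
    left by 16 and reinterpreting them as float32 recovers the value,
    which can then be downcast to float16 (with possible range loss).
    """
    as_f32 = (raw_bf16.astype(np.uint32) << 16).view(np.float32)
    return as_f32.astype(np.float16)

# 1.5 is exactly representable in bfloat16, so it round-trips exactly:
raw = (np.array([1.5], dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)
out = bf16_to_fp16(raw)
print(out)  # [1.5]
```

Note that bfloat16 has a wider exponent range than float16, so values outside float16's range overflow to inf under this mapping.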
…e to False for _compile when the appropriate params are missing.