
Added fp16/bf16 based export and compile support for VLMs #819

Open
asmigosw wants to merge 35 commits into quic:main from asmigosw:custom_dtype

Conversation

Contributor

@asmigosw asmigosw commented Mar 2, 2026

Added fp16/bf16 based export and compile support for VLMs

  retained_state=True,
  specializations=specializations["lang"],
- convert_to_fp16=True,
+ convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
Contributor

nit: why is this condition needed? Is it required for AI200? @quic-rishinr

Contributor Author

@asmigosw asmigosw Mar 9, 2026

This condition is required in case the user wants bf16 support, which will come in AI200. I have updated the code to set convert_to_fp16 = True when the passed dtype is either fp16 or fp32.
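
The rule described in this reply (convert for fp16 and fp32, leave bf16 alone) can be sketched as follows. This is only an illustration: `should_convert_to_fp16` and the string dtype names are hypothetical and are not taken from the PR, which works with torch dtypes and `DTYPE_TO_STRING_MAP`.

```python
# Hypothetical helper, not the PR's code: derive the convert_to_fp16 flag
# from a dtype name, following the rule stated in the review reply.
SUPPORTED_DTYPES = {"float16", "float32", "bfloat16"}
FP16_CONVERTIBLE = {"float16", "float32"}


def should_convert_to_fp16(dtype_name: str) -> bool:
    """Return True when the exported model should be converted to fp16."""
    if dtype_name not in SUPPORTED_DTYPES:
        raise ValueError(f"unsupported dtype: {dtype_name!r}")
    # bf16 models are assumed to stay in bf16, so they skip the conversion.
    return dtype_name in FP16_CONVERTIBLE


flags = {name: should_convert_to_fp16(name) for name in sorted(SUPPORTED_DTYPES)}
print(flags)
```

Rejecting unknown dtype names up front is a design choice of this sketch; it keeps a typo from silently selecting the bf16 path.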

self.model, transformed = transform.apply(self.model)
any_transformed = any_transformed or transformed

self._normalize_torch_dtype()
Contributor

nit: does this take care of embedding and ASR models too?

"allenai/Molmo-7B-D-0924",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
]:
pytest.skip("Test skipped for this model due to some issues.")
Contributor

nit: with our dummy configs, can we run all sample LM models with this test quickly?

@quic-rishinr quic-rishinr marked this pull request as ready for review March 13, 2026 08:34
torch.nn.functional.pad(inputs["input_values"], (0, self.seq_len - input_ids_len), "constant", 0)
)
needed_dtype = self.model.config.torch_dtype
input_values = input_values.astype(CUSTOM_IO_DTYPE_MAP[needed_dtype])
Contributor

Since the inputs are in numpy format, we should be using Torch_to_numpy_map, right?

  router_logits = self.gate(hidden_states)

- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights = F.softmax(router_logits, dim=1, dtype=self.gate.weight.dtype)
Contributor

Can we update the softmax to be in the original precision?

  def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float):
      hidden_states = hidden_states.to(torch.float32)
-     div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32))
+     div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1]))
Contributor

RMS norm would create an issue if we set it to the default precision, similar to softmax. We should verify whether it's causing the issue. If it is, we should revert this.

Contributor Author

I tested with softmax changed to the default precision, and it ran successfully.
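
As background for the precision concern in this thread, here is a minimal sketch of one general risk of computing a mean of squares in fp16: a single large activation can overflow when squared, because fp16 tops out at 65504. This is not the PR's code and it does not show that this is the failure the reviewer had in mind. The helper names and input values are made up, and fp16 rounding is emulated with the standard library's half-precision `struct` format.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value (inf on overflow)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack finite values beyond the fp16 range.
        return math.copysign(math.inf, x)


def mean_square_naive_fp16(values):
    """Square first, then average, rounding every intermediate to fp16."""
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + to_fp16(v * v))
    return to_fp16(acc / len(values))


def mean_square_scaled_fp16(values):
    """Scale by 1/sqrt(n) first, then square and sum, rounding to fp16."""
    scale = to_fp16(1.0 / math.sqrt(len(values)))
    acc = 0.0
    for v in values:
        scaled = to_fp16(v * scale)
        acc = to_fp16(acc + to_fp16(scaled * scaled))
    return acc


# One large activation among small ones (illustrative values only).
hidden = [300.0] + [1.0] * 15

exact = sum(v * v for v in hidden) / len(hidden)  # float64 reference
naive = mean_square_naive_fp16(hidden)
scaled = mean_square_scaled_fp16(hidden)

print(exact, naive, scaled)
```

In this sketch, squaring first overflows to infinity. Scaling before squaring stays finite, but the small terms are lost when they are added to the large accumulator, so the result is still less accurate than the higher-precision reference.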

asmigosw and others added 17 commits March 18, 2026 05:49
Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
Signed-off-by: asmigosw <asmigosw@qti.qualcomm.com>
… and inference.

Almost all LLMs can now be compiled and run for inference in fp16. The test_causal_lm_models script uses the following notation to record how each model's test went:
# means the model wasn't tested due to its size, so it is unclear whether it will run through or show an accuracy mismatch.
## means the outputs match for fp16 and everything worked fine.
### means outputs are produced but don't match the HF tokens properly.
#### means the model is quantized and additional effort is needed to enable it.
These commits cover almost all LLMs currently supported.

Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
asmigosw and others added 16 commits March 18, 2026 05:49
Enabled CI tests for fp16-based LMs, embedding and sequence classification models.

Modified the CI-based config for LLM tests.

Embedding models have a high MAD (~0.015) for fp16-exported models.

Certain CausalLMs produce a token mismatch after a few tokens in the fp16 setup.
The Whisper model has a clip operator issue for fp16-exported models, so it's not enabled yet.
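
For readers unfamiliar with the MAD figure quoted above, the sketch below shows how such a number could be computed, assuming MAD here means the mean absolute difference between reference embeddings and the fp16 model's embeddings. That reading is an assumption. The function name, the sample vectors, and the tolerance are all made up for illustration and are not measurements from this PR.

```python
def mean_abs_diff(reference, candidate):
    """Mean absolute difference between two equal-length sequences of floats."""
    if len(reference) != len(candidate):
        raise ValueError("sequences must have the same length")
    if not reference:
        raise ValueError("sequences must not be empty")
    return sum(abs(r - c) for r, c in zip(reference, candidate)) / len(reference)


# Made-up embedding values, for illustration only.
reference_embedding = [0.10, -0.25, 0.40, 0.05]
fp16_embedding = [0.11, -0.27, 0.42, 0.04]

mad = mean_abs_diff(reference_embedding, fp16_embedding)

# Hypothetical tolerance; a real test would choose its own threshold.
MAD_TOLERANCE = 0.02
within_tolerance = mad <= MAD_TOLERANCE

print(f"MAD = {mad:.4f}, within tolerance: {within_tolerance}")
```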

Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
Added a try/except around dtype casting of model weights after loading, since GPTQ-type models don't allow such a conversion.
Fixed a few dtype-related issues for audio-based models.
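
The guarded cast described in this commit message can be pictured roughly as below. This is a sketch under stated assumptions, not the PR's implementation: `PlainModel`, `QuantizedModel`, and `cast_if_supported` are hypothetical stand-ins, and the set of caught exception types is a guess at how a quantized model might refuse a cast.

```python
class PlainModel:
    """Stand-in for a regular model whose weights can be cast."""

    def __init__(self):
        self.dtype = "float32"

    def to(self, dtype):
        self.dtype = dtype
        return self


class QuantizedModel(PlainModel):
    """Stand-in for a quantized model that rejects dtype casts."""

    def to(self, dtype):
        raise ValueError("casting is not supported for quantized weights")


def cast_if_supported(model, dtype):
    """Try to cast the model; on refusal, return it unchanged.

    Returns (model, was_cast).
    """
    try:
        return model.to(dtype), True
    except (ValueError, TypeError, RuntimeError):
        # Assumed failure modes; the model keeps the dtype it was loaded with.
        return model, False


plain, plain_cast = cast_if_supported(PlainModel(), "float16")
quant, quant_cast = cast_if_supported(QuantizedModel(), "float16")

print(plain.dtype, plain_cast)
print(quant.dtype, quant_cast)
```

Returning a flag alongside the model lets the caller log or branch on the skipped cast instead of failing silently.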

Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
…both in bfloat16.

Added a patch in cloud infer to map the bfloat16 or 11 key type to np.float16 for AI200 inference.

Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
…e to False for _compile when appropriate params are missing.

Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>