Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
f76e043
Added fp16/bf16 based export and compile for InternVL Model
asmigosw Feb 26, 2026
edef200
Ruff format
asmigosw Feb 26, 2026
853d999
Added bf16/fp16/fp32 support for mistral3
asmigosw Feb 26, 2026
ebda0e8
Added changes for Llama4
asmigosw Mar 2, 2026
28d6499
Ruff check
asmigosw Mar 2, 2026
aa659cb
Added custom dtype support for Molmo
asmigosw Mar 2, 2026
848577e
Added custom dtype support for llava_next
asmigosw Mar 2, 2026
8da9eac
Ruff format
asmigosw Mar 2, 2026
41addfa
Added custom_dtype support for Qwen2_5_vl
asmigosw Mar 3, 2026
c274445
Added custom_dtype support for mllama
asmigosw Mar 3, 2026
fd9a5a7
Ruff format
asmigosw Mar 3, 2026
2ad6706
Added custom_dtype support for Gemma3
asmigosw Mar 5, 2026
0326bd0
BF16 changes to be used
quic-dhirajku Mar 3, 2026
5c15fa0
Added modifications and changes to enable fp16/bf16 based compilation…
quic-dhirajku Mar 6, 2026
64fa655
Added custom dtype support for llava
asmigosw Mar 9, 2026
4f81aa8
Ruff format
asmigosw Mar 9, 2026
97b515a
Updated logits to dtype float32
asmigosw Mar 9, 2026
43b351c
Updated the test file
asmigosw Mar 10, 2026
e938886
Added custom_dtype support for wav2vec2
asmigosw Mar 10, 2026
6764007
Updated SoftMax calculation precision for all modeling files.
quic-dhirajku Mar 11, 2026
9cdcb08
Comments Addressed
asmigosw Mar 13, 2026
60a088f
Addressed Comments
asmigosw Mar 13, 2026
8d21947
Updated custom fp16 models for causalLM
asmigosw Mar 13, 2026
ae00b72
Updated needed_dtype to handle edge cases
asmigosw Mar 13, 2026
aaf8dcc
Made changes for CI tests.
quic-dhirajku Mar 14, 2026
fc9e39b
Removed grok config from CI models list
quic-dhirajku Mar 14, 2026
d376e7b
Fixed grok1 model CI issue, added the custom config back
quic-dhirajku Mar 16, 2026
2b19a39
Added default dtype for string and None case
asmigosw Mar 16, 2026
541f9df
Updating QAIC LLM Test Time
asmigosw Mar 16, 2026
babf230
Replaced some model configs for quicker CI tests for LLMs
quic-dhirajku Mar 16, 2026
61422cb
CI failures addressed
asmigosw Mar 17, 2026
2915e49
removing comments
asmigosw Mar 17, 2026
8680d1a
Added check to not pass Custom_IO yaml when model weight and pkv are …
quic-dhirajku Mar 17, 2026
a9494aa
Added additional check to default bf16 model dtype and pkv cache dtyp…
quic-dhirajku Mar 18, 2026
fc45acd
Undo unit test for HL API Tests.
quic-dhirajku Mar 18, 2026
cb88f6c
Merge branch 'main' into custom_dtype
asmigosw Mar 25, 2026
2186375
Added logger warning for bf16
asmigosw Mar 25, 2026
8284da6
Merge branch 'main' into custom_dtype
asmigosw Mar 25, 2026
f68881c
Merge branch 'main' into custom_dtype
asmigosw Mar 30, 2026
a534872
Skipping sampler tests
asmigosw Apr 1, 2026
69cdd88
Merge branch 'main' into custom_dtype
asmigosw Apr 1, 2026
677fb64
Merge branch 'main' into custom_dtype
asmigosw Apr 2, 2026
52feedc
Added aic-hw-version args
asmigosw Apr 2, 2026
f22fd7d
Updated parsed args
asmigosw Apr 2, 2026
6429f8e
Merge branch 'main' into custom_dtype
asmigosw Apr 3, 2026
44809ea
Merge branch 'main' into custom_dtype
asmigosw Apr 7, 2026
7f10015
Merge branch 'main' into custom_dtype
asmigosw Apr 7, 2026
7f53e42
Updated Gemma3 for TF v4.57
asmigosw Apr 8, 2026
d73766b
Ruff format
asmigosw Apr 8, 2026
f92dd12
Removed Comments
asmigosw Apr 8, 2026
da0cb15
Comments Addressed
asmigosw Apr 9, 2026
838f0e7
ruff check fix
asmigosw Apr 9, 2026
b2532ba
Merge branch 'main' into custom_dtype
asmigosw Apr 13, 2026
d2d4127
Merge branch 'main' into custom_dtype
asmigosw Apr 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _transform_names(self) -> List[str]:
def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
self.model = model
self.config = model.config
self.hash_params = create_model_params(self, **kwargs)
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
Expand All @@ -77,11 +78,51 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
self.model, transformed = transform.apply(self.model)
any_transformed = any_transformed or transformed

self._normalize_torch_dtype()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: does this take care of embedding and ASR models too?


if not any_transformed:
warnings.warn(f"No transforms applied to model: {self.model_name}. It may be an unsupported model!")
else:
logger.info(f"Pytorch transforms applied to model: {self.model_name}")

if self.config.torch_dtype == torch.bfloat16:
logger.warning("BFloat16 dtype is not yet supported; converting to float16 precision!")

def _normalize_torch_dtype(self):
    """
    Normalize ``torch_dtype`` across the top-level config and all nested
    sub-configs so multimodal models see one consistent dtype.

    The top-level ``config.torch_dtype`` is resolved to a ``torch.dtype``
    (``None`` or an unrecognized string falls back to ``torch.float32``)
    and then propagated to ``llm_config``, ``vision_config`` and
    ``text_config`` when present. Sub-configs that expose a
    ``use_bfloat16`` flag have it kept in sync with the resolved dtype.
    """
    dtype = getattr(self.config, "torch_dtype", torch.float32)
    if dtype is None:
        dtype = torch.float32
    elif isinstance(dtype, str):
        # HF configs may store the dtype as a string such as "bfloat16";
        # unknown names default to float32.
        dtype = getattr(torch, dtype, torch.float32)

    self.config.torch_dtype = dtype

    # Propagate to the nested configs used by multimodal models
    # (e.g. llm/vision for InternVL-style models, text for Qwen2.5-VL).
    for sub_name in ("llm_config", "vision_config", "text_config"):
        sub_config = getattr(self.config, sub_name, None)
        if sub_config is None:
            # Attribute absent, or present but unset — nothing to normalize.
            continue
        sub_config.torch_dtype = dtype
        if hasattr(sub_config, "use_bfloat16"):
            sub_config.use_bfloat16 = dtype == torch.bfloat16

    logger.info(f"Normalized all config torch_dtype to: {dtype}")

def _offload_model_weights(self, offload_pt_weights: bool) -> bool:
"""Clear PyTorch model weights to reduce memory usage after ONNX export."""
if offload_pt_weights and not self._is_weights_offloaded:
Expand Down Expand Up @@ -506,12 +547,21 @@ def _compile(
command.append(f"-network-specialization-config={specializations_json}")

# Write custom_io.yaml file
model_in_bfloat16 = hasattr(self, "config") and (self.config.torch_dtype == torch.bfloat16)
pkv_in_bfloat16 = (custom_io is not None) and any(
"past_" in key and "bfloat16" in value for key, value in custom_io.items()
)
if custom_io is not None:
custom_io_yaml = compile_dir / "custom_io.yaml"
with open(custom_io_yaml, "w") as fp:
for io_name, dtype in custom_io.items():
fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n")
command.append(f"-custom-IO-list-file={custom_io_yaml}")
if model_in_bfloat16 and pkv_in_bfloat16:
logger.warning(
"Model and Past KV types are both bfloat16. Custom IO list file will be ignored during compile."
)
else:
command.append(f"-custom-IO-list-file={custom_io_yaml}")

command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
Expand Down
1 change: 1 addition & 0 deletions QEfficient/generation/cloud_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(

# Build dtype mapping once (depends on aicapi constants)
self.aic_to_np_dtype_mapping = {
getattr(aicapi, "BFLOAT16_TYPE", 11): np.dtype(np.float16),
aicapi.FLOAT_TYPE: np.dtype(np.float32),
aicapi.FLOAT_16_TYPE: np.dtype(np.float16),
aicapi.INT8_Q_TYPE: np.dtype(np.int8),
Expand Down
8 changes: 7 additions & 1 deletion QEfficient/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,10 +734,16 @@ def from_legacy_cache(
) -> "HybridCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility."""

# Get the sliding_window_pattern from config
sliding_window_pattern = getattr(
config, "_sliding_window_pattern", getattr(config, "sliding_window_pattern", None)
)

cache = cls(
config,
batch_size=past_key_values[0][0].shape[0],
max_cache_len=past_key_values[config.sliding_window_pattern - 1][0].shape[2],
max_cache_len=past_key_values[sliding_window_pattern - 1][0].shape[2],
sliding_window_len=past_key_values[0][0].shape[2],
)
if past_key_values is not None:
Expand Down
7 changes: 3 additions & 4 deletions QEfficient/transformers/models/codegen/modeling_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def _attn(
head_mask=None,
):
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)
query = query.to(value.dtype)
key = key.to(value.dtype)

attn_weights = torch.matmul(query, key.transpose(-1, -2))

Expand Down Expand Up @@ -349,8 +349,7 @@ def forward(
# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = transformer_outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
lm_logits = self.lm_head(hidden_states)

lm_logits = self.lm_head(hidden_states).float()
return CausalLMOutputWithPast(
loss=None,
logits=lm_logits,
Expand Down
8 changes: 5 additions & 3 deletions QEfficient/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,11 @@ def forward(
attention_scores = query_layer @ key_layer.transpose(-1, -2)
attention_scores /= math.sqrt(self.head_dim)
attention_scores = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attention_scores
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=self.config.torch_dtype), attention_scores
)
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=torch.float32).to(
query_layer.dtype
)
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
# It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
attn_output = attention_scores @ value_layer

Expand Down Expand Up @@ -401,7 +403,7 @@ def forward(
# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = transformer_outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
lm_logits = self.lm_head(hidden_states)
lm_logits = self.lm_head(hidden_states).float()

return CausalLMOutputWithCrossAttentions(
loss=None,
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
Expand Down
4 changes: 2 additions & 2 deletions QEfficient/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
Expand Down Expand Up @@ -448,7 +448,7 @@ def forward(
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping

logits = logits.float()
return CausalLMOutputWithPast(
loss=None,
logits=logits,
Expand Down
50 changes: 29 additions & 21 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
class GemmaRMSNormFunc(torch.autograd.Function):
@staticmethod
def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float):
hidden_states = hidden_states.to(torch.float32)
div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32))
div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=hidden_states.dtype))
variance = div_first.pow(2).sum(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + epsilon)
return weight * hidden_states
Expand All @@ -61,7 +60,7 @@ class QEffGemma3CustomRMSNormAIC(nn.Module):
def forward(self, hidden_states):
return GemmaRMSNormFunc.apply(
hidden_states,
self.weight.float() + 1.0,
(self.weight).to(hidden_states.dtype) + 1.0,
self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps,
)

Expand Down Expand Up @@ -164,7 +163,7 @@ def eager_attention_forward(

if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
Expand Down Expand Up @@ -198,7 +197,7 @@ def __qeff_init__(self):
config = copy.deepcopy(self.config)
config.rope_theta = config.rope_local_base_freq
config.rope_scaling = {"rope_type": "default", "factor": 1.0}
self.is_local = _is_local(self.layer_idx, self.config.sliding_window_pattern)
self.is_local = _is_local(self.layer_idx, self.config._sliding_window_pattern)
self.window = self.config.sliding_window if self.is_local else None

self.rotary_emb_local = QEffGemma3RotaryEmbedding(
Expand Down Expand Up @@ -253,7 +252,7 @@ def forward(
"batch_index": batch_index,
"position_ids": position_ids,
"is_sliding": self.is_sliding,
"sliding_window_pattern": self.config.sliding_window_pattern,
"sliding_window_pattern": self.config._sliding_window_pattern,
"sliding_window": past_key_values.sliding_window_len,
}
if comp_ctx_lengths is not None:
Expand All @@ -272,7 +271,9 @@ def forward(

if attention_mask is not None: # no matter the length, we just slice it
attn_weights = torch.where(
attention_mask.bool(), torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask.bool(),
torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=self.config.torch_dtype),
attn_weights,
)

# upcast attention to fp32
Expand Down Expand Up @@ -322,7 +323,7 @@ def forward(
else:
attention_mask = _create_causal_mask(
position_ids=position_ids,
target_length=past_key_value.key_cache[self.config.sliding_window_pattern - 1].shape[-2],
target_length=past_key_value.key_cache[self.config._sliding_window_pattern - 1].shape[-2],
)

hidden_states, self_attn_weights = self.self_attn(
Expand Down Expand Up @@ -534,6 +535,9 @@ def forward(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if self.config.torch_dtype == torch.float16:
logger.warning("Accuracy might drop with float16 as torch_dtype")

Comment thread
asmigosw marked this conversation as resolved.
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
Expand All @@ -551,7 +555,7 @@ def forward(
)
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
logits = self.lm_head(hidden_states)
logits = self.lm_head(hidden_states).float()

if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
Expand All @@ -569,7 +573,9 @@ def forward(
def get_dummy_pkv_cache(self, config, batch_size, seq_len):
n_heads = config.num_key_value_heads
d_head = config.head_dim
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
layer_switch = (
config._sliding_window_pattern if hasattr(config, "_sliding_window_pattern") else 2
) # 2 is for BC
is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
Expand All @@ -581,8 +587,8 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
for i in range(config.num_hidden_layers):
if hasattr(config, "sliding_window"):
cache_shape = global_cache_shape if not is_sliding[i] else sliding_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=torch.float32)
new_layer_value_cache = torch.zeros(cache_shape, dtype=torch.float32)
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.config.torch_dtype)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.config.torch_dtype)
pkv = (new_layer_key_cache, new_layer_value_cache)
past_key_values.append(pkv)
return past_key_values
Expand Down Expand Up @@ -835,15 +841,15 @@ def get_onnx_dynamic_axes(
pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"}
pkv_dynamic_sliding_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"}
layer_switch = (
self.language_model.config.sliding_window_pattern
if hasattr(self.language_model.config, "sliding_window_pattern")
self.language_model.config._sliding_window_pattern
if hasattr(self.language_model.config, "_sliding_window_pattern")
else 2
)
for i in range(self.language_model.config.num_hidden_layers):
for kv in ["key", "value"]:
apply_dynamic_axes = (
pkv_dynamic_sliding_axes
if ((i + 1) % layer_switch and hasattr(self.language_model.config, "sliding_window_pattern"))
if ((i + 1) % layer_switch and hasattr(self.language_model.config, "_sliding_window_pattern"))
else pkv_dynamic_axes
)
lang_dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes
Expand Down Expand Up @@ -881,7 +887,9 @@ def get_output_names(self, kv_offload: bool = False):
def get_dummy_pkv_cache(self, config, batch_size, seq_len):
n_heads = config.num_key_value_heads
d_head = config.head_dim
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
layer_switch = (
config._sliding_window_pattern if hasattr(config, "_sliding_window_pattern") else 2
) # 2 is for BC
is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
Expand All @@ -893,8 +901,8 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
for i in range(config.num_hidden_layers):
if hasattr(config, "sliding_window"):
cache_shape = global_cache_shape if not is_sliding[i] else sliding_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=torch.float32)
new_layer_value_cache = torch.zeros(cache_shape, dtype=torch.float32)
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.config.torch_dtype)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.config.torch_dtype)
pkv = (new_layer_key_cache, new_layer_value_cache)
past_key_values.append(pkv)
return past_key_values
Expand Down Expand Up @@ -931,9 +939,9 @@ def get_dummy_inputs(
# Define inputs
vision_inputs = {}
lang_inputs = {}
vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=self.config.torch_dtype)
lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype)
lang_inputs["position_ids"] = (
torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
Expand Down Expand Up @@ -972,7 +980,7 @@ def get_inputs_info(self):
IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
IOInfo(
name="pixel_values",
datatype=torch.float32,
datatype=self.config.torch_dtype,
shape=("batch_size", 3, "img_size", "img_size"),
),
]
4 changes: 2 additions & 2 deletions QEfficient/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
if attention_mask is not None:
# Apply the attention mask
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
Comment thread
asmigosw marked this conversation as resolved.
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)

# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
attn_weights = attn_weights.type(value.dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def eager_attention_forward(

if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
Expand Down Expand Up @@ -439,7 +439,7 @@ def forward(
# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = transformer_outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
lm_logits = self.lm_head(hidden_states)
lm_logits = self.lm_head(hidden_states).float()

return CausalLMOutputWithCrossAttentions(
loss=None,
Expand Down
Loading
Loading