From d95bbebc7c90ab7e6c7f4124792d39092d67a379 Mon Sep 17 00:00:00 2001 From: shagsood Date: Mon, 10 Nov 2025 23:24:12 -0800 Subject: [PATCH 1/4] Onbaord GLM model Signed-off-by: shagsood --- .../transformers/models/glm4_moe/__init__.py | 7 + .../models/glm4_moe/modeling_glm4_moe.py | 360 ++++++++++++++++++ .../transformers/models/pytorch_transforms.py | 22 ++ 3 files changed, 389 insertions(+) create mode 100644 QEfficient/transformers/models/glm4_moe/__init__.py create mode 100644 QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py diff --git a/QEfficient/transformers/models/glm4_moe/__init__.py b/QEfficient/transformers/models/glm4_moe/__init__.py new file mode 100644 index 000000000..51b99e6ed --- /dev/null +++ b/QEfficient/transformers/models/glm4_moe/__init__.py @@ -0,0 +1,7 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + diff --git a/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py b/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py new file mode 100644 index 000000000..184eb1526 --- /dev/null +++ b/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -0,0 +1,360 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import Callable, List, Optional, Union + +import torch +from torch import nn +from transformers.cache_utils import Cache +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.glm4_moe.modeling_glm4_moe import ( + Glm4MoeAttention, + Glm4MoeConfig, + Glm4MoeDecoderLayer, + Glm4MoeForCausalLM, + Glm4MoeModel, + Glm4MoeRotaryEmbedding, + repeat_kv, + rotate_half, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.logging_utils import logger + + +class QEffGlm4MoeRotaryEmbedding(Glm4MoeRotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. 
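+    - Precompute `cos_cached`/`sin_cached` buffers for `original_max_seq_len` positions in `__init__`
+      and index them in `forward`, so the sin/cos tables keep static shapes for export.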
+ """ + + def __init__(self, config: Glm4MoeConfig, device=None): + super().__init__(config=config) + + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + +def qeff_apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + unsqueeze_dim: int = 1, +): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + +class QEffGlm4MoeAttention(Glm4MoeAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + if self.use_qk_norm: # main diff from Llama + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class QEffGlm4MoeDecoderLayer(Glm4MoeDecoderLayer): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = 
None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class QEffGlm4MoeModel(Glm4MoeModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = QEffDynamicCache() + else: + cache_kwargs = {"batch_index" : batch_index, "position_ids": position_ids} + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values,cache_kwargs) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = _create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + batch_index=batch_index, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + +class QEffGlm4MoeForCausalLM(Glm4MoeForCausalLM): + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Glm4MoeForCausalLM + + >>> model = Glm4MoeForCausalLM.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 773ce178c..71dbd7be7 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -44,6 +44,14 @@ Gemma3RMSNorm, Gemma3TextModel, ) +from transformers.models.glm4_moe.modeling_glm4_moe import ( + Glm4MoeAttention, + Glm4MoeDecoderLayer, + Glm4MoeForCausalLM, + Glm4MoeModel, + Glm4MoeRMSNorm, + Glm4MoeRotaryEmbedding, +) from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model from transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( GPTBigCodeAttention, @@ -240,6 +248,13 @@ QEffGemma3ForConditionalGeneration, QEffGemma3TextModel, ) +from QEfficient.transformers.models.glm4_moe.modeling_glm4_moe import ( + QEffGlm4MoeAttention, + QEffGlm4MoeDecoderLayer, + QEffGlm4MoeForCausalLM, + QEffGlm4MoeModel, + QEffGlm4MoeRotaryEmbedding, +) from QEfficient.transformers.models.gpt2.modeling_gpt2 import ( QEffGPT2Attention, QEffGPT2Block, @@ -451,6 +466,7 @@ class CustomOpsTransform(ModuleMappingTransform): Qwen3MoeRMSNorm: CustomRMSNormAIC, Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, Olmo2RMSNorm: CustomRMSNormAIC, + Glm4MoeRMSNorm:CustomRMSNormAIC, } @@ -539,6 +555,12 @@ class KVCacheTransform(ModuleMappingTransform): GraniteMoeParallelExperts: QEffGraniteMoeParallelExperts, GraniteMoeTopKGating: QEffGraniteMoeTopKGating, GraniteMoeMoE: QEffGraniteMoeMoE, + # GLMMoe + Glm4MoeModel: QEffGlm4MoeModel, + Glm4MoeForCausalLM: QEffGlm4MoeForCausalLM, + Glm4MoeAttention: QEffGlm4MoeAttention, + Glm4MoeDecoderLayer: QEffGlm4MoeDecoderLayer, + Glm4MoeRotaryEmbedding: QEffGlm4MoeRotaryEmbedding, # mllama MllamaTextRMSNorm: CustomRMSNormAIC, MllamaTextSelfAttention: QEffMllamaTextSelfAttention, From 982a6e85bccd28d972772e9be31f2e2ac87d56ae Mon Sep 17 00:00:00 2001 From: shagsood Date: Wed, 12 Nov 2025 05:46:18 -0800 Subject: [PATCH 2/4] Fix modeling file issue Signed-off-by: shagsood --- .../models/glm4_moe/modeling_glm4_moe.py | 87 +++++++++---------- .../transformers/models/pytorch_transforms.py | 2 +- 2 files changed, 40 insertions(+), 49 deletions(-) diff --git a/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py b/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py index 184eb1526..9f31bf1bb 100644 --- a/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -54,7 +54,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): 
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): + def forward(self, x: torch.Tensor, seq_len: int = None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) @@ -64,6 +64,7 @@ def forward(self, x, seq_len=None): self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, ) + def qeff_apply_rotary_pos_emb( q: torch.Tensor, k: torch.Tensor, @@ -95,12 +96,23 @@ def qeff_apply_rotary_pos_emb( cos = cos[position_ids].unsqueeze(unsqueeze_dim) sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + # Cast back to original dtype return q_embed.to(q.dtype), k_embed.to(k.dtype) + def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -126,8 +138,13 @@ def eager_attention_forward( return attn_output, attn_weights + class QEffGlm4MoeAttention(Glm4MoeAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGlm4MoeRotaryEmbedding(config=self.config) + def forward( self, hidden_states: torch.Tensor, @@ -136,6 +153,7 @@ def forward( past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -153,12 +171,20 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - cos, sin = position_embeddings - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "batch_index": batch_index, + "position_ids": position_ids, + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -180,7 +206,6 @@ def forward( class QEffGlm4MoeDecoderLayer(Glm4MoeDecoderLayer): - def forward( self, hidden_states: torch.Tensor, @@ -232,20 +257,17 @@ def forward( ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - + 
use_cache = use_cache if use_cache is not None else self.config.use_cache - + if inputs_embeds is None: inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) - return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True if past_key_values is None: past_key_values = QEffDynamicCache() else: - cache_kwargs = {"batch_index" : batch_index, "position_ids": position_ids} - past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values,cache_kwargs) + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) logger.warning_once( "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " @@ -261,43 +283,30 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = _create_causal_mask( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - batch_index=batch_index, - position_ids=position_ids, - ) + attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens) hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers[: self.config.num_hidden_layers]: hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, batch_index=batch_index, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.norm(hidden_states) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, ) + class QEffGlm4MoeForCausalLM(Glm4MoeForCausalLM): - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -312,23 +321,6 @@ def forward( logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: - r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, Glm4MoeForCausalLM - - >>> model = Glm4MoeForCausalLM.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -357,4 +349,3 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 71dbd7be7..e307bdcef 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -466,7 +466,7 @@ class CustomOpsTransform(ModuleMappingTransform): Qwen3MoeRMSNorm: CustomRMSNormAIC, Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, Olmo2RMSNorm: CustomRMSNormAIC, - Glm4MoeRMSNorm:CustomRMSNormAIC, + Glm4MoeRMSNorm: CustomRMSNormAIC, } From 3f19177c21f5d43412b7a663c0a2fc5cbae8894e Mon Sep 17 00:00:00 2001 From: shagsood Date: Thu, 13 Nov 2025 00:52:33 -0800 Subject: [PATCH 3/4] Fix modeling file issue Signed-off-by: shagsood --- .../models/glm4_moe/modeling_glm4_moe.py | 98 +++++++++++++++++-- .../transformers/models/pytorch_transforms.py | 6 ++ 2 files changed, 94 insertions(+), 10 deletions(-) diff --git a/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py b/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py index 9f31bf1bb..ac73dc54b 100644 --- a/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -8,6 +8,7 @@ from typing import Callable, List, Optional, Union import torch +import torch.nn.functional as F from torch import nn from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -17,7 +18,9 @@ Glm4MoeDecoderLayer, Glm4MoeForCausalLM, Glm4MoeModel, + Glm4MoeMoE, Glm4MoeRotaryEmbedding, + Glm4MoeTopkRouter, repeat_kv, rotate_half, ) @@ -252,9 +255,14 @@ def forward( batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, cache_position: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -263,7 +271,9 @@ def forward( if inputs_embeds is None: inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True if past_key_values is None: past_key_values = QEffDynamicCache() else: @@ -287,7 +297,13 @@ def forward( hidden_states = inputs_embeds + # decoder layers + all_hidden_states = () if output_hidden_states else None + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + hidden_states = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -300,12 +316,78 @@ def forward( hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, + hidden_states=all_hidden_states, ) +class QEffGlm4MoeTopkRouter(Glm4MoeTopkRouter): 
+ def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + scores = router_logits.sigmoid() + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + + # Create expert mask similar to Granite + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=self.n_routed_experts) + expert_mask = expert_mask.permute(2, 0, 1) # [num_experts, batch*seq, top_k] + + return topk_weights, expert_mask, router_logits, self.n_routed_experts + + +class QEffGlm4MoeMoE(Glm4MoeMoE): + """ + Optimized mixed expert module for ONNX export. + """ + + def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int): + """ + Optimized MoE forward pass avoiding dynamic operations. + """ + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + + for expert_idx in range(num_experts): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx].to(hidden_states.dtype) + mask_weight = (topk_weights * mask).sum(dim=1, keepdim=True) + gate_out = expert.gate_proj(hidden_states) + up_out = expert.up_proj(hidden_states) + hidden = expert.act_fn(gate_out) * up_out + expert_output = expert.down_proj(hidden) + current_hidden_states = expert_output * mask_weight + final_hidden_states += current_hidden_states + + return final_hidden_states.type(hidden_states.dtype) + + def forward(self, hidden_states): + """ + Forward pass of the mixture of experts layer. 
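+        Per token this computes y = sum_i w_i * expert_i(x) + shared_experts(x), where w_i is taken
+        from `topk_weights` and is zero for experts the router did not select (via `expert_mask`).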
+ """ + residuals = hidden_states + orig_shape = hidden_states.shape + + topk_weights, expert_mask, router_logits, num_experts = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.moe(hidden_states, topk_weights, expert_mask, num_experts).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + + return hidden_states + + class QEffGlm4MoeForCausalLM(Glm4MoeForCausalLM): def forward( self, @@ -315,10 +397,9 @@ def forward( past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: outputs: BaseModelOutputWithPast = self.model( @@ -329,21 +410,18 @@ def forward( batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, + output_hidden_states=output_hidden_states, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states).float() return CausalLMOutputWithPast( - loss=loss, + loss=None, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index e307bdcef..0eb912935 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -49,8 +49,10 @@ Glm4MoeDecoderLayer, Glm4MoeForCausalLM, Glm4MoeModel, + Glm4MoeMoE, Glm4MoeRMSNorm, Glm4MoeRotaryEmbedding, + Glm4MoeTopkRouter, ) from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model from transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( @@ -253,7 +255,9 @@ QEffGlm4MoeDecoderLayer, QEffGlm4MoeForCausalLM, QEffGlm4MoeModel, + QEffGlm4MoeMoE, QEffGlm4MoeRotaryEmbedding, + QEffGlm4MoeTopkRouter, ) from QEfficient.transformers.models.gpt2.modeling_gpt2 import ( QEffGPT2Attention, @@ -561,6 +565,8 @@ class KVCacheTransform(ModuleMappingTransform): Glm4MoeAttention: QEffGlm4MoeAttention, Glm4MoeDecoderLayer: QEffGlm4MoeDecoderLayer, Glm4MoeRotaryEmbedding: QEffGlm4MoeRotaryEmbedding, + Glm4MoeMoE: QEffGlm4MoeMoE, + Glm4MoeTopkRouter: QEffGlm4MoeTopkRouter, # mllama MllamaTextRMSNorm: CustomRMSNormAIC, MllamaTextSelfAttention: QEffMllamaTextSelfAttention, From d80edb67be5e2c17361cdcb0b650e7d5dd847404 Mon Sep 17 00:00:00 2001 From: Shagun Date: Fri, 14 Nov 2025 01:33:30 -0800 Subject: [PATCH 4/4] Add Glm4MoeForCausalLM support Signed-off-by: shagsood --- .../models/glm4_moe/modeling_glm4_moe.py | 25 ++++++------------- QEfficient/utils/run_utils.py | 3 
+++ .../models/test_causal_lm_models.py | 1 + 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py b/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py index ac73dc54b..85ce65f57 100644 --- a/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -30,12 +30,11 @@ from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE -from QEfficient.utils.logging_utils import logger class QEffGlm4MoeRotaryEmbedding(Glm4MoeRotaryEmbedding): """ - Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + Copied from Glm4MoeForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4_moe/modeling_glm4_moe.py The only differences are: - Add static sin/cos computations. """ @@ -151,7 +150,6 @@ def __qeff_init__(self): def forward( self, hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, @@ -274,15 +272,7 @@ def forward( return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): return_legacy_cache = True - if past_key_values is None: - past_key_values = QEffDynamicCache() - else: - past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -351,13 +341,10 @@ def forward(self, hidden_states): class QEffGlm4MoeMoE(Glm4MoeMoE): """ - Optimized mixed expert module for ONNX export. + MoE Block """ def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int): - """ - Optimized MoE forward pass avoiding dynamic operations. - """ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) for expert_idx in range(num_experts): @@ -375,7 +362,7 @@ def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_ma def forward(self, hidden_states): """ - Forward pass of the mixture of experts layer. + Forward pass of MoE block. 
""" residuals = hidden_states orig_shape = hidden_states.shape @@ -389,6 +376,10 @@ def forward(self, hidden_states): class QEffGlm4MoeForCausalLM(Glm4MoeForCausalLM): + """ + Copied from Glm4MoeForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4_moe/modeling_glm4_moe.py + """ + def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index c54dadeac..448aba57a 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -104,6 +104,9 @@ def run_hf_model_on_pytorch(self, model_hf): """ model_inputs = self.input_handler.tokenizer(self.input_handler.prompt[0], return_tensors="pt") + if "token_type_ids" in model_inputs: + model_inputs.pop("token_type_ids") + input_len = model_inputs["input_ids"].shape[-1] with torch.inference_mode(): diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 321a466ab..658c2d8a9 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -53,6 +53,7 @@ "hpcai-tech/grok-1", "Snowflake/Llama-3.1-SwiftKV-8B-Instruct", "allenai/OLMo-2-0425-1B", + "zai-org/GLM-4.5-Air", ] test_models_qnn = [