diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 0c44070b7..f5b67460e 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -1,6 +1,7 @@ import math from typing import Dict, List, Optional, Tuple, Union +import os import torch import torch.nn.functional as F from torch import nn @@ -25,6 +26,7 @@ # logger, # ) from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE def rotate_half(x): @@ -234,7 +236,14 @@ def __qeff_init__( # self.register_buffer("per_head_k_up", per_head_k_up.detach().clone(), persistent=False) self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone()) self.per_head_k_up = torch.nn.Parameter(per_head_k_up.detach().clone()) - fusedqk = torch.bmm(per_head_q_up, per_head_k_up) + + out = torch.matmul(per_head_q_up[0,:,:], per_head_k_up[0,:,:]) + for i in range(1, self.num_heads): + x = torch.matmul(per_head_q_up[i,:,:], per_head_k_up[i,:,:]) + out = torch.cat((out,x), 0) + fusedqk = out.reshape(self.num_heads, -1, self.kv_lora_rank) + + #fusedqk = torch.bmm(per_head_q_up, per_head_k_up) # self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False) self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) kv_a_proj_with_mqa_ckv, kv_a_proj_with_mqa_k_pe = self.kv_a_proj_with_mqa.weight.T.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) @@ -290,30 +299,39 @@ def fused_forward( else: enable_absorption = False - if enable_absorption: - if absorb_online: - print("online absorption") - atn = torch.matmul( - torch.matmul(q_a_proj_out.unsqueeze(1), torch.bmm(self.per_head_q_up, self.per_head_k_up)), - kva.transpose(1, 2).unsqueeze(1), - ) + x = [] + for i in range(self.num_heads): + if enable_absorption: + if 
absorb_online: + if i==0: + print("online absorption") + out = torch.matmul(self.per_head_q_up[i,:,:], self.per_head_k_up[i,:,:]) + out = out.reshape(1, -1, self.kv_lora_rank) + out2 = torch.matmul(q_a_proj_out.unsqueeze(1), out) + else: + if i==0: + print("using fused qk") + out2 = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk[i,:,:]) + + out3 = torch.cat((out2, q_pe[:,i,:,:].unsqueeze(1)), -1) + kva_kpe = torch.cat((kva,k_pe.squeeze(1)), -1) + attn_weights = torch.matmul(out3, kva_kpe.transpose(1, 2).unsqueeze(1)) * self.softmax_scale else: - print("using fused qk") - atn = torch.matmul( - torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk), kva.transpose(1, 2).unsqueeze(1) - ) - else: - print("no absorption") - atn = torch.matmul(q_nope, k_nope.transpose(2, 3)) + if i==0: + print("no absorption") + query_states = torch.cat((q_nope[:,i,:,:], q_pe[:,i,:,:]), -1) + key_states = torch.cat((k_nope[:,i,:,:].unsqueeze(1), k_pe), -1) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - atr = torch.matmul(q_pe, k_pe.expand(-1, self.num_heads, -1, -1).transpose(2, 3)) - attn_weights = (atn + atr) * self.softmax_scale + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) + attn_output = torch.matmul(attn_weights, value_states[:,i,:,:]) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) - attn_output = torch.matmul(attn_weights, value_states) + x.append(attn_output) + + attn_output = torch.cat(x, dim=1) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * 
self.v_head_dim) attn_output = self.o_proj(attn_output) @@ -356,7 +374,7 @@ def forward( q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) query_states = torch.cat((q_nope, q_pe), -1) - k_pe_new = k_pe.repeat(1, self.num_heads, 1, 1) + k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) key_states = torch.cat((k_nope, k_pe_new), -1) if past_key_value is not None: @@ -366,7 +384,7 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) @@ -459,16 +477,18 @@ class QEffPrefillOnlyDeepseekV3MoE(nn.Module): def __qeff_init__( self, ): - self.all_gate_proj = torch.nn.Parameter( - torch.cat([exp.gate_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) - ) - self.all_up_proj = torch.nn.Parameter( - torch.cat([exp.up_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) - ) - self.all_down_proj = torch.nn.Parameter( - torch.cat([exp.down_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) - ) - self.act_fn = self.experts[0].act_fn + for exp in self.experts: + gate_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) + up_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) + down_proj = torch.nn.Linear(self.config.moe_intermediate_size, self.config.hidden_size, bias=False) + + gate_proj.weight = torch.nn.Parameter(exp.gate_proj.compressor.decompress_module(exp.gate_proj)) + up_proj.weight = 
torch.nn.Parameter(exp.up_proj.compressor.decompress_module(exp.up_proj))
+            down_proj.weight = torch.nn.Parameter(exp.down_proj.compressor.decompress_module(exp.down_proj))
+
+            setattr(exp,"gate_proj", gate_proj)
+            setattr(exp,"up_proj", up_proj)
+            setattr(exp,"down_proj", down_proj)
 
     def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int):
         final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 5ff05fed2..cd9d2ca9f 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -48,10 +48,13 @@
     KVCacheExternalModuleMapperTransform,
     KVCacheTransform,
     PoolingTransform,
     PrefillOnlyChunkedTransform,
+    PrefillOnlyExternalModuleMapperTransform,
     PrefillOnlyTransform,
+    ReplicateKVHeadTransform,
     RevertPrefillKeepAttentionTransform,
+    RevertPrefillOnlyExternalModuleMapperTransform,
     RevertPrefillOnlyTransform,
     SamplerTransform,
     SpDTransform,
     VlmKVOffloadTransform,
@@ -2325,12 +2328,14 @@ def prefill(
         retain_full_kv: Optional[bool] = False,
     ):
         if enable:
+            self.model, tf = PrefillOnlyExternalModuleMapperTransform.apply(self.model)
             if enable_chunking:
                 self.model, tf = PrefillOnlyChunkedTransform.apply(self.model)
             else:
                 self.model, tf = PrefillOnlyTransform.apply(self.model)
         else:
+            self.model, tf = RevertPrefillOnlyExternalModuleMapperTransform.apply(self.model)
             if retain_full_kv:
                 self.model, tf = 
RevertPrefillKeepAttentionTransform.apply(self.model) else: @@ -2406,6 +2411,10 @@ def __init__( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.hash_params["max_seq_len_cached"] = max_seq_len_cached + self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) + if replicate_kv_transformed: + self.hash_params["config"] = model.config.to_diff_dict() + # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms # are done. The role of the sampler is to just add nodes at the output of the @@ -2617,6 +2626,14 @@ def export( self.hash_params["mla_absorption_config"] = mla_absorption_config setattr(self.model.model, "mla_absorption_config", mla_absorption_config) + if self.model.config.model_type in {"kimi_k2"}: + if prefill_only: + self.prefill(enable=True) + self.hash_params["prefill_only"] = True + else: + self.prefill(enable=False) + self.hash_params.pop("prefill_only", None) + # TODO: move this to a DA Serving utility class if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: if prefill_only: diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 216dd6cd3..e676fbc46 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -10,6 +10,7 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union +import torch from torch import nn from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, @@ -456,6 +457,7 @@ from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry from QEfficient.transformers.sampler.sampler import sampler_forward from QEfficient.transformers.spd.spd_transform_forward import tlm_forward +from QEfficient.utils.logging_utils import logger SPD_TARGET = "target" @@ -666,7 +668,6 @@ class 
PrefillOnlyTransform(ModuleMappingTransform):
         QEffGptOssModel: QEffPrefillOnlyGptOssModel,
         QEffGptOssAttention: QEffPrefillOnlyGptOssAttention,
         QEffGptOssMLP: QEffPrefillOnlyGptOssMLP,
-        QEffDeepseekV3MoE: QEffPrefillOnlyDeepseekV3MoE,
     }
 
 
@@ -675,7 +676,6 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform):
         QEffGptOssModel: QEffPrefillOnlyGptOssModel,
         QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
         QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP,
-        QEffDeepseekV3MoE: QEffPrefillOnlyDeepseekV3MoE,
     }
 
 
@@ -686,7 +686,6 @@ class RevertPrefillKeepAttentionTransform(ModuleMappingTransform):
         QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
         QEffPrefillOnlyGptOssMLP: QEffGptOssMLP,
         QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP,
-        QEffPrefillOnlyDeepseekV3MoE: QEffDeepseekV3MoE,
     }
 
 
@@ -697,6 +696,78 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform):
     }
 
 
+class ReplicateKVHeadTransform:
+    """
+    Replicates KV heads in attention modules to match the number of KV heads in the target model.
+    This transform is used when the source model has fewer KV heads than required in target model.
+    """
+
+    def _duplicate_weights_for_linear_layer(
+        layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
+    ):
+        new_kv_heads = repeat #for mla
+
+        layer.weight.data = torch.repeat_interleave(
+            layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+        ).view(new_kv_heads * head_dim, hidden_size)
+        if layer.bias is not None:
+            layer.bias.data = torch.repeat_interleave(
+                layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
+            ).view(new_kv_heads * head_dim)
+
+    def _get_text_model(model):
+        """
+        Determine and return the appropriate text_model from a given model object. 
+ """ + # Check for VLMs + if hasattr(model, "language_model"): + if hasattr(model.language_model, "model"): + return model.language_model.model + else: + return model.language_model + # Check for CausalLMs + if hasattr(model, "model"): + return model.model + + raise AttributeError("No suitable text model found in the provided model.") + + @classmethod + def apply(cls, model: nn.Module, **kwargs) -> nn.Module: + """ + Replicates KV heads in attention modules based on provided multiplier. + + Args: + model: The model to apply the transform to. + kwargs: Additional arguments for the transformation. Includes: + - num_kv_heads_repeat: The number of times to repeat the KV heads. + """ + n_repeat = kwargs.pop("num_kv_heads_repeat", 1) + transformed = False + if n_repeat is not None and n_repeat > 1: + text_model = cls._get_text_model(model) + + orig_kv_heads = 1 # for mla #text_model.config.num_key_value_heads + new_kv_heads = n_repeat*orig_kv_heads + text_model.config.orig_kv_heads = orig_kv_heads + text_model.config.num_key_value_heads = new_kv_heads + + num_attention_heads = text_model.config.num_attention_heads + hidden_size = text_model.config.hidden_size + + logger.warning(f"Original KV heads: {orig_kv_heads}") + logger.warning(f"Modified KV heads: {new_kv_heads}") + transformed = True + for block in text_model.layers: + attn = getattr(block, "cross_attn", getattr(block, "self_attn", None)) + attn.num_key_value_heads = new_kv_heads + head_dim = attn.kv_lora_rank+attn.qk_rope_head_dim + + cls._duplicate_weights_for_linear_layer( + attn.kv_a_proj_with_mqa, orig_kv_heads, n_repeat, head_dim, hidden_size + ) + return model, transformed + + class SpDTransform: """ Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. 
@@ -895,6 +970,25 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
         },
 }
 
+class PrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform):
+    _match_class_replace_method = {}
+    _match_string_replace_method = {
+        "DeepseekV3MoE": {
+            "forward": QEffPrefillOnlyDeepseekV3MoE.forward,
+            "moe": QEffPrefillOnlyDeepseekV3MoE.moe,
+            "__qeff_init__": QEffPrefillOnlyDeepseekV3MoE.__qeff_init__,
+        },
+    }
+
+class RevertPrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform):
+    _match_class_replace_method = {}
+    _match_string_replace_method = {
+        "DeepseekV3MoE": {
+            "forward": QEffDeepseekV3MoE.forward,
+            "moe": QEffDeepseekV3MoE.moe,
+            "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__,
+        },
+    }
 
 class T5ModelTransform(ModuleMappingTransform):
     # supported architectures
diff --git a/examples/run_kimik2.py b/examples/run_kimik2.py
index 0df64cceb..227109a85 100644
--- a/examples/run_kimik2.py
+++ b/examples/run_kimik2.py
@@ -5,10 +5,15 @@
 from QEfficient import QEFFAutoModelForCausalLM
 
 prompt = "Once upon a time,"
+num_kv_heads_repeat=4 #TS=4
+num_hidden_layers=2
+enable_mla=True
+mla_absorption_config={"enable": True, "online": True}
 
-model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/"
+#model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/"
+model_path ="/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd"
 model = AutoModelForCausalLM.from_pretrained(
-    model_path, torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True
+    model_path, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True
 )
tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) @@ -27,8 +32,8 @@ out = model(**inputs) predictions = torch.argmax(out.logits, dim=-1) -qeff_model = QEFFAutoModelForCausalLM(model) -qeff_model.mla(enable_mla=True, mla_absorption_config={"enable": True, "online": True}) +qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat=num_kv_heads_repeat) +qeff_model.mla(enable_mla=enable_mla, mla_absorption_config=mla_absorption_config) inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) @@ -84,17 +89,22 @@ print("Completion:", repr(predicted_string)) + +prefill_seq_len = 128 +ctx_len = 2048 + onnx_path = qeff_model.export( - prefill_seq_len=1, enable_mla=True, mla_absorption_config={"enable": True, "online": True} + prefill_seq_len=prefill_seq_len, enable_mla=enable_mla, mla_absorption_config=mla_absorption_config ) + qpc_path = qeff_model.compile( - prefill_seq_len=1, - ctx_len=1024, - enable_mla=True, - mla_absorption_config={"enable": True, "online": True}, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + enable_mla=enable_mla, + mla_absorption_config=mla_absorption_config, mxfp6_matmul=True, mxint8_kv_cache=False, - num_devices=1, + num_devices=num_kv_heads_repeat, num_cores=16, )