Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from typing import Dict, List, Optional, Tuple, Union

import os
import torch
import torch.nn.functional as F
from torch import nn
Expand All @@ -25,6 +26,7 @@
# logger,
# )
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


def rotate_half(x):
Expand Down Expand Up @@ -234,7 +236,14 @@ def __qeff_init__(
# self.register_buffer("per_head_k_up", per_head_k_up.detach().clone(), persistent=False)
self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone())
self.per_head_k_up = torch.nn.Parameter(per_head_k_up.detach().clone())
fusedqk = torch.bmm(per_head_q_up, per_head_k_up)

out = torch.matmul(per_head_q_up[0,:,:], per_head_k_up[0,:,:])
for i in range(1, self.num_heads):
x = torch.matmul(per_head_q_up[i,:,:], per_head_k_up[i,:,:])
out = torch.cat((out,x), 0)
fusedqk = out.reshape(self.num_heads, -1, self.kv_lora_rank)

#fusedqk = torch.bmm(per_head_q_up, per_head_k_up)
# self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False)
self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone())
kv_a_proj_with_mqa_ckv, kv_a_proj_with_mqa_k_pe = self.kv_a_proj_with_mqa.weight.T.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
Expand Down Expand Up @@ -290,30 +299,39 @@ def fused_forward(
else:
enable_absorption = False

if enable_absorption:
if absorb_online:
print("online absorption")
atn = torch.matmul(
torch.matmul(q_a_proj_out.unsqueeze(1), torch.bmm(self.per_head_q_up, self.per_head_k_up)),
kva.transpose(1, 2).unsqueeze(1),
)
x = []
for i in range(self.num_heads):
if enable_absorption:
if absorb_online:
if i==0:
print("online absorption")
out = torch.matmul(self.per_head_q_up[i,:,:], self.per_head_k_up[i,:,:])
out = out.reshape(1, -1, self.kv_lora_rank)
out2 = torch.matmul(q_a_proj_out.unsqueeze(1), out)
else:
if i==0:
print("using fused qk")
out2 = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk[i,:,:])

out3 = torch.cat((out2, q_pe[:,i,:,:].unsqueeze(1)), -1)
kva_kpe = torch.cat((kva,k_pe.squeeze(1)), -1)
attn_weights = torch.matmul(out3, kva_kpe.transpose(1, 2).unsqueeze(1)) * self.softmax_scale
else:
print("using fused qk")
atn = torch.matmul(
torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk), kva.transpose(1, 2).unsqueeze(1)
)
else:
print("no absorption")
atn = torch.matmul(q_nope, k_nope.transpose(2, 3))
if i==0:
print("no absorption")
query_states = torch.cat((q_nope[:,i,:,:], q_pe[:,i,:,:]), -1)
key_states = torch.cat((k_nope[:,i,:,:].unsqueeze(1), k_pe), -1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale

atr = torch.matmul(q_pe, k_pe.expand(-1, self.num_heads, -1, -1).transpose(2, 3))
attn_weights = (atn + atr) * self.softmax_scale
if attention_mask is not None: # no matter the length, we just slice it
attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights)

if attention_mask is not None: # no matter the length, we just slice it
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype)
attn_output = torch.matmul(attn_weights, value_states[:,i,:,:])

attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype)
attn_output = torch.matmul(attn_weights, value_states)
x.append(attn_output)

attn_output = torch.cat(x, dim=1)

attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
Expand Down Expand Up @@ -356,7 +374,7 @@ def forward(
q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

query_states = torch.cat((q_nope, q_pe), -1)
k_pe_new = k_pe.repeat(1, self.num_heads, 1, 1)
k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1)
key_states = torch.cat((k_nope, k_pe_new), -1)

if past_key_value is not None:
Expand All @@ -366,7 +384,7 @@ def forward(
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale

if attention_mask is not None: # no matter the length, we just slice it
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights)

attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
Expand Down Expand Up @@ -459,16 +477,18 @@ class QEffPrefillOnlyDeepseekV3MoE(nn.Module):
def __qeff_init__(
self,
):
self.all_gate_proj = torch.nn.Parameter(
torch.cat([exp.gate_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0)
)
self.all_up_proj = torch.nn.Parameter(
torch.cat([exp.up_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0)
)
self.all_down_proj = torch.nn.Parameter(
torch.cat([exp.down_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0)
)
self.act_fn = self.experts[0].act_fn
for exp in self.experts:
gate_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False)
up_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False)
down_proj = torch.nn.Linear(self.config.moe_intermediate_size, self.config.hidden_size, bias=False)

gate_proj.weight = torch.nn.Parameter(exp.gate_proj.compressor.decompress_module(exp.gate_proj))
up_proj.weight = torch.nn.Parameter(exp.up_proj.compressor.decompress_module(exp.up_proj))
down_proj.weight = torch.nn.Parameter(exp.down_proj.compressor.decompress_module(exp.down_proj))

setattr(exp,"gate_proj", gate_proj)
setattr(exp,"up_proj", up_proj)
setattr(exp,"down_proj", down_proj)

def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int):
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
Expand All @@ -481,6 +501,7 @@ def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_ma
current_hidden_states = expert_output * expert_mask[:, expert_idx].unsqueeze(-1)
final_hidden_states += current_hidden_states

print("\n\ninside prefill only moe\n")
return final_hidden_states.type(hidden_states.dtype)

def orig_moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
Expand Down
17 changes: 17 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@
KVCacheExternalModuleMapperTransform,
KVCacheTransform,
PoolingTransform,
PrefillOnlyExternalModuleMapperTransform,
PrefillOnlyChunkedTransform,
PrefillOnlyTransform,
ReplicateKVHeadTransform,
RevertPrefillKeepAttentionTransform,
RevertPrefillOnlyTransform,
RevertPrefillOnlyExternalModuleMapperTransform,
SamplerTransform,
SpDTransform,
VlmKVOffloadTransform,
Expand Down Expand Up @@ -2325,12 +2328,14 @@ def prefill(
retain_full_kv: Optional[bool] = False,
):
if enable:
self.model, tf = PrefillOnlyExternalModuleMapperTransform.apply(self.model)
if enable_chunking:
self.model, tf = PrefillOnlyChunkedTransform.apply(self.model)
else:
self.model, tf = PrefillOnlyTransform.apply(self.model)

else:
self.model, tf = RevertPrefillOnlyExternalModuleMapperTransform.apply(self.model)
if retain_full_kv:
self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model)
else:
Expand Down Expand Up @@ -2406,6 +2411,10 @@ def __init__(
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
self.hash_params["max_seq_len_cached"] = max_seq_len_cached

self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
if replicate_kv_transformed:
self.hash_params["config"] = model.config.to_diff_dict()

# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
# are done. The role of the sampler is to just add nodes at the output of the
Expand Down Expand Up @@ -2617,6 +2626,14 @@ def export(
self.hash_params["mla_absorption_config"] = mla_absorption_config
setattr(self.model.model, "mla_absorption_config", mla_absorption_config)

if self.model.config.model_type in {"kimi_k2"}:
if prefill_only:
self.prefill(enable=True)
self.hash_params["prefill_only"] = True
else:
self.prefill(enable=False)
self.hash_params.pop("prefill_only", None)

# TODO: move this to a DA Serving utility class
if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH:
if prefill_only:
Expand Down
104 changes: 101 additions & 3 deletions QEfficient/transformers/models/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from types import MethodType
from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
Expand Down Expand Up @@ -456,6 +457,7 @@
from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
from QEfficient.transformers.sampler.sampler import sampler_forward
from QEfficient.transformers.spd.spd_transform_forward import tlm_forward
from QEfficient.utils.logging_utils import logger

SPD_TARGET = "target"

Expand Down Expand Up @@ -666,7 +668,6 @@ class PrefillOnlyTransform(ModuleMappingTransform):
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
QEffGptOssAttention: QEffPrefillOnlyGptOssAttention,
QEffGptOssMLP: QEffPrefillOnlyGptOssMLP,
QEffDeepseekV3MoE: QEffPrefillOnlyDeepseekV3MoE,
}


Expand All @@ -675,7 +676,6 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform):
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP,
QEffDeepseekV3MoE: QEffPrefillOnlyDeepseekV3MoE,
}


Expand All @@ -686,7 +686,6 @@ class RevertPrefillKeepAttentionTransform(ModuleMappingTransform):
QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffPrefillOnlyGptOssMLP: QEffGptOssMLP,
QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP,
QEffPrefillOnlyDeepseekV3MoE: QEffDeepseekV3MoE,
}


Expand All @@ -697,6 +696,82 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform):
}


class ReplicateKVHeadTransform:
"""
Replicates KV heads in attention modules to match the number of KV heads in the target model.
This transform is used when the source model has fewer KV heads than required in target model.
"""

def _duplicate_weights_for_linear_layer(
layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
):
new_kv_heads = repeat #for mla

layer.weight.data = torch.repeat_interleave(
layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)
if layer.bias is not None:
layer.bias.data = torch.repeat_interleave(
layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
).view(new_kv_heads * head_dim)
if layer.bias is not None:
layer.bias.data = torch.repeat_interleave(
layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
).view(new_kv_heads * head_dim)

def _get_text_model(model):
"""
Determine and return the appropriate text_model from a given model object.
"""
# Check for VLMs
if hasattr(model, "language_model"):
if hasattr(model.language_model, "model"):
return model.language_model.model
else:
return model.language_model
# Check for CausalLMs
if hasattr(model, "model"):
return model.model

raise AttributeError("No suitable text model found in the provided model.")

@classmethod
def apply(cls, model: nn.Module, **kwargs) -> nn.Module:
"""
Replicates KV heads in attention modules based on provided multiplier.

Args:
model: The model to apply the transform to.
kwargs: Additional arguments for the transformation. Includes:
- num_kv_heads_repeat: The number of times to repeat the KV heads.
"""
n_repeat = kwargs.pop("num_kv_heads_repeat", 1)
transformed = False
if n_repeat is not None and n_repeat > 1:
text_model = cls._get_text_model(model)

orig_kv_heads = 1 # for mla #text_model.config.num_key_value_heads
new_kv_heads = n_repeat*orig_kv_heads
text_model.config.orig_kv_heads = orig_kv_heads
text_model.config.num_key_value_heads = new_kv_heads

num_attention_heads = text_model.config.num_attention_heads
hidden_size = text_model.config.hidden_size

logger.warning(f"Original KV heads: {orig_kv_heads}")
logger.warning(f"Modified KV heads: {new_kv_heads}")
transformed = True
for block in text_model.layers:
attn = getattr(block, "cross_attn", getattr(block, "self_attn", None))
attn.num_key_value_heads = new_kv_heads
head_dim = attn.kv_lora_rank+attn.qk_rope_head_dim

cls._duplicate_weights_for_linear_layer(
attn.kv_a_proj_with_mqa, orig_kv_heads, n_repeat, head_dim, hidden_size
)
return model, transformed


class SpDTransform:
"""
Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill.
Expand Down Expand Up @@ -895,6 +970,29 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
},
}

class PrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform):
_match_class_replace_method = {}
_match_string_replace_method = {
"DeepseekV3MoE": {
"forward": QEffPrefillOnlyDeepseekV3MoE.forward,
"moe": QEffPrefillOnlyDeepseekV3MoE.moe,
"__qeff_init__": QEffPrefillOnlyDeepseekV3MoE.__qeff_init__,
},
}

class RevertPrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform):
_match_class_replace_method = {}
_match_string_replace_method = {
"DeepseekV3MoE": {
"forward": QEffDeepseekV3MoE.forward,
"moe": QEffDeepseekV3MoE.moe,
"__qeff_init__": QEffDeepseekV3MoE.__qeff_init__,
},
}
'''_match_string_replace_method = {
**{v: k for k, v in PrefillOnlyExternalModuleMapperTransform._match_string_replace_method.items()},
}
'''

class T5ModelTransform(ModuleMappingTransform):
# supported architectures
Expand Down
Loading