5656
# Probe the flash-attn API once at import time: older flash-attn builds do not
# accept the `window_size` kwarg needed for sliding-window attention.
# `Signature.parameters` is a mapping, so membership works directly — the
# intermediate `list(...)` copy is unnecessary.
_flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
5858
# Optional Moreh accelerator extension: `torch.ops.moreh` only resolves when
# the Moreh runtime is installed; otherwise attribute access raises
# AttributeError and we fall back to None so later code can feature-test.
try:
    # Moreh extension
    moreh_ops = torch.ops.moreh
    # Aliased as an RMSNorm replacement below — presumably Moreh's fused
    # T5-style (no-mean-subtraction) layer norm kernel; TODO confirm.
    MorehRMSNorm = moreh_ops.T5LayerNorm
except AttributeError:
    # Extension unavailable: callers check `MorehRMSNorm is not None`.
    MorehRMSNorm = None
5965
# Module-level logger following the transformers convention.
logger = logging.get_logger(__name__)

# Config class name interpolated into the auto-generated docstrings.
_CONFIG_FOR_DOC = "MistralConfig"

# NOTE(review): this warns when the Moreh fused RMSNorm *is* available, yet the
# message reads as an incompatibility warning — confirm whether the intended
# condition is `is None` or whether Moreh RMSNorm genuinely conflicts with
# Masked Structured Growth Training.
if MorehRMSNorm is not None:
    logger.warning(
        "You can't use Masked Structured Growth Training..! You should avoid using rmsnorm in any way. "
    )
6474
6575# Copied from transformers.models.llama.modeling_llama._get_unpad_data
6676def _get_unpad_data (attention_mask ):
@@ -74,6 +84,12 @@ def _get_unpad_data(attention_mask):
7484 max_seqlen_in_batch ,
7585 )
7686
def get_moreh_fused_rms_norm(config):
    """Return the ``fused_rms_norm`` toggle from the model's Moreh config.

    Args:
        config: Model configuration object. May optionally carry a
            ``moreh_config`` mapping of Moreh-specific feature flags.

    Returns:
        The value stored under ``"fused_rms_norm"`` in ``config.moreh_config``,
        or ``False`` when the attribute is missing/None or the key is absent.
    """
    moreh_config = getattr(config, "moreh_config", None)
    if moreh_config is None:
        return False
    # `dict.get` replaces the membership-test-then-index double lookup.
    return moreh_config.get("fused_rms_norm", False)
7793
7894# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
7995class MistralRMSNorm (nn .Module ):
@@ -710,8 +726,12 @@ def __init__(self, config: MistralMorehConfig, layer_idx: int):
710726 self .self_attn = MISTRAL_ATTENTION_CLASSES [config ._attn_implementation ](config = config , layer_idx = layer_idx )
711727
712728 self .mlp = MistralMLP (config )
713- self .input_layernorm = MistralRMSNorm (config .hidden_size , eps = config .rms_norm_eps )
714- self .post_attention_layernorm = MistralRMSNorm (config .hidden_size , eps = config .rms_norm_eps )
729+ if get_moreh_fused_rms_norm (config ):
730+ self .input_layernorm = MorehRMSNorm (config .hidden_size , eps = config .rms_norm_eps )
731+ self .post_attention_layernorm = MorehRMSNorm (config .hidden_size , eps = config .rms_norm_eps )
732+ else :
733+ self .input_layernorm = MistralRMSNorm (config .hidden_size , eps = config .rms_norm_eps )
734+ self .post_attention_layernorm = MistralRMSNorm (config .hidden_size , eps = config .rms_norm_eps )
715735
716736 def forward (
717737 self ,
@@ -912,7 +932,10 @@ def __init__(self, config: MistralMorehConfig):
912932 [MistralDecoderLayer (config , layer_idx ) for layer_idx in range (config .num_hidden_layers )]
913933 )
914934 self ._attn_implementation = config ._attn_implementation
915- self .norm = MistralRMSNorm (config .hidden_size , eps = config .rms_norm_eps )
935+ if get_moreh_fused_rms_norm (config ):
936+ self .norm = MorehRMSNorm (config .hidden_size , eps = config .rms_norm_eps )
937+ else :
938+ self .norm = MistralRMSNorm (config .hidden_size , eps = config .rms_norm_eps )
916939
917940 self .gradient_checkpointing = False
918941 # Initialize weights and apply final processing
0 commit comments