
Commit f71559b

Add fused_rms_norm usage
1 parent aca147c · commit f71559b

1 file changed: src/transformers/models/mistral/modeling_mistral_moreh.py
Lines changed: 26 additions & 3 deletions
@@ -56,11 +56,21 @@
 
 _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 
+try:
+    # Moreh extension
+    moreh_ops = torch.ops.moreh
+    MorehRMSNorm = moreh_ops.T5LayerNorm
+except AttributeError:
+    MorehRMSNorm = None
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "MistralConfig"
 
+if MorehRMSNorm is not None:
+    logger.warning(
+        "You can't use Masked Structured Growth Training..! You should avoid using rmsnorm in any way. "
+    )
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
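
The fused MorehRMSNorm op (torch.ops.moreh.T5LayerNorm) is resolved lazily and falls back to None when the Moreh extension is not loaded. For reference, here is a non-fused sketch of the RMS normalization it stands in for, mirroring the upstream MistralRMSNorm forward pass; this is not the Moreh kernel itself.

import torch

# Reference (non-fused) RMS normalization, following the upstream MistralRMSNorm:
# compute the variance in float32, rescale, then cast back to the input dtype.
def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)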
@@ -74,6 +84,12 @@ def _get_unpad_data(attention_mask):
         max_seqlen_in_batch,
     )
 
+def get_moreh_fused_rms_norm(config):
+    moreh_config = getattr(config, "moreh_config", None)
+    if moreh_config is not None and "fused_rms_norm" in moreh_config:
+        return moreh_config["fused_rms_norm"]
+    return False
+
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
 class MistralRMSNorm(nn.Module):
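
The new get_moreh_fused_rms_norm helper only reads an optional moreh_config dict on the config object. A minimal sketch of its behavior, using stand-in config objects that are not part of this commit:

from types import SimpleNamespace

# Stand-in configs to exercise the helper (illustrative only).
cfg_enabled = SimpleNamespace(moreh_config={"fused_rms_norm": True})
cfg_missing = SimpleNamespace()  # no moreh_config attribute

print(get_moreh_fused_rms_norm(cfg_enabled))  # True
print(get_moreh_fused_rms_norm(cfg_missing))  # False (falls through to the default)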
@@ -710,8 +726,12 @@ def __init__(self, config: MistralMorehConfig, layer_idx: int):
         self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
         self.mlp = MistralMLP(config)
-        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_moreh_fused_rms_norm(config):
+            self.input_layernorm = MorehRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.post_attention_layernorm = MorehRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -912,7 +932,10 @@ def __init__(self, config: MistralMorehConfig):
             [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self._attn_implementation = config._attn_implementation
-        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_moreh_fused_rms_norm(config):
+            self.norm = MorehRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
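
Putting the pieces together, opting in might look like the sketch below. It assumes the Moreh extension is installed (so torch.ops.moreh.T5LayerNorm resolves and MorehRMSNorm is not None) and that the fork's MistralMorehConfig accepts an extra moreh_config attribute, as the helper above suggests; none of this is confirmed by the commit itself.

from transformers import MistralConfig

# Hypothetical opt-in: attach a moreh_config dict before building the model.
config = MistralConfig(hidden_size=1024, num_hidden_layers=2, num_attention_heads=8)
config.moreh_config = {"fused_rms_norm": True}

# With the flag set, the __init__ paths above select MorehRMSNorm for input_layernorm,
# post_attention_layernorm, and the final model norm; otherwise MistralRMSNorm is used.
# The fused path requires the extension: if MorehRMSNorm is None, model construction fails.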
