51 changes: 40 additions & 11 deletions megatron/core/transformer/attention.py
@@ -35,7 +35,7 @@
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.torch_norm import LayerNormBuilder
from megatron.core.transformer.torch_norm import L2Norm, LayerNormBuilder
from megatron.core.typed_torch import apply_module, not_none
from megatron.core.utils import (
deprecate_inference_params,
@@ -102,10 +102,11 @@
from megatron.core.extensions.transformer_engine import (
SplitAlongDim,
TELinear,
TENorm,
set_save_original_input,
)
else:
SplitAlongDim, TELinear, set_save_original_input = None, None, None
SplitAlongDim, TELinear, TENorm, set_save_original_input = None, None, None, None

try:
from transformer_engine.pytorch.attention.rope import apply_fused_qkv_rotary_pos_emb
@@ -1288,23 +1289,51 @@ def __init__(
tp_group=self.pg_collection.tp,
)

if submodules.q_layernorm is not None:
self.q_layernorm = submodules.q_layernorm(
# Resolve which norm class to use for Q and K.
# Config selects the default norm class; spec overrides if set.
if self.config.qk_l2_norm:
q_norm_cls = submodules.q_layernorm or L2Norm
k_norm_cls = submodules.k_layernorm or L2Norm
elif self.config.qk_layernorm:
# TODO(yuzhongw, janpabloe): Support local backend.
q_norm_cls = submodules.q_layernorm or TENorm
k_norm_cls = submodules.k_layernorm or TENorm
if q_norm_cls is None or k_norm_cls is None:
raise ValueError(
"qk_layernorm requires Transformer Engine (for TENorm) or "
"q_layernorm/k_layernorm set in the spec."
)
else:
if submodules.q_layernorm not in (None, IdentityOp):
raise ValueError(
f"spec sets q_layernorm={submodules.q_layernorm} but "
"qk_layernorm/qk_l2_norm are disabled"
)
if submodules.k_layernorm not in (None, IdentityOp):
raise ValueError(
f"spec sets k_layernorm={submodules.k_layernorm} but "
"qk_layernorm/qk_l2_norm are disabled"
)
q_norm_cls = k_norm_cls = None

self.q_layernorm = (
q_norm_cls(
hidden_size=self.hidden_size_per_attention_head,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.q_layernorm = None

if submodules.k_layernorm is not None:
self.k_layernorm = submodules.k_layernorm(
if q_norm_cls is not None
else None
)
self.k_layernorm = (
k_norm_cls(
hidden_size=self.hidden_size_per_attention_head,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.k_layernorm = None
if k_norm_cls is not None
else None
)

def run_realtime_tests(self):
"""Performs a consistency check.
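
For readers skimming the diff, the new precedence is: an explicit q_layernorm/k_layernorm in the spec always wins; otherwise config.qk_l2_norm selects L2Norm, config.qk_layernorm selects TENorm (which requires Transformer Engine), and with both flags off a concrete norm in the spec is rejected. A minimal standalone sketch of that order follows; it mirrors the diff but is illustrative only, not the Megatron implementation (the real code also tolerates IdentityOp in the disabled case).

```python
# Illustrative sketch of the Q/K norm resolution order above -- not the actual
# Megatron code. spec_norm stands for submodules.q_layernorm or
# submodules.k_layernorm; l2_norm_cls/te_norm_cls stand for L2Norm/TENorm.
def resolve_qk_norm_cls(spec_norm, qk_l2_norm, qk_layernorm, l2_norm_cls, te_norm_cls):
    if qk_l2_norm:
        return spec_norm or l2_norm_cls      # spec override, else L2Norm
    if qk_layernorm:
        norm_cls = spec_norm or te_norm_cls  # spec override, else TENorm
        if norm_cls is None:                 # te_norm_cls is None without Transformer Engine
            raise ValueError("qk_layernorm requires Transformer Engine or a spec-provided norm")
        return norm_cls
    if spec_norm is not None:                # the real code also allows IdentityOp here
        raise ValueError("spec sets a q/k norm but qk_layernorm/qk_l2_norm are disabled")
    return None
```
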
129 changes: 129 additions & 0 deletions tests/unit_tests/models/test_mamba_model.py
@@ -292,6 +292,135 @@ def test_with_custom_process_groups(self, tmp_path, tp_size, cp_size, pp_size):
assert logits.shape[2] == divide(model.vocab_size, tp_size)


class TestMambaQKLayernorm:

def setup_method(self, method):
Utils.initialize_model_parallel(1, 1)
model_parallel_cuda_manual_seed(123)

def teardown_method(self, method):
Utils.destroy_model_parallel()

def _build_model(self, **config_overrides):
config = TransformerConfig(
num_layers=3,
hidden_size=256,
num_attention_heads=4,
use_cpu_initialization=True,
**config_overrides,
)
return MambaModel(
config=config,
mamba_stack_spec=mamba_stack_spec,
vocab_size=100,
max_sequence_length=4,
hybrid_layer_pattern="M*-",
)

def _get_attention_layer(self, model):
"""Return the SelfAttention submodule from the attention layer."""
for layer in model.decoder.layers:
if hasattr(layer, 'self_attention') and hasattr(layer.self_attention, 'q_layernorm'):
return layer.self_attention
return None

def test_no_qk_norm_by_default(self):
"""Without qk_layernorm, attention has no q/k layernorm."""
model = self._build_model()
attn = self._get_attention_layer(model)
assert attn is not None
assert attn.q_layernorm is None
assert attn.k_layernorm is None

def test_qk_layernorm_from_config(self):
"""config.qk_layernorm=True creates q/k layernorm even with static spec."""
model = self._build_model(qk_layernorm=True)
attn = self._get_attention_layer(model)
assert attn is not None
# TENorm is a factory (__new__ returns a TE LayerNorm/RMSNorm), so we
# verify the norm was created rather than checking for a specific type.
assert attn.q_layernorm is not None
assert attn.k_layernorm is not None

def test_qk_l2_norm_from_config(self):
"""config.qk_l2_norm=True creates L2Norm q/k layernorm."""
from megatron.core.transformer.torch_norm import L2Norm

model = self._build_model(qk_l2_norm=True)
attn = self._get_attention_layer(model)
assert attn is not None
assert isinstance(attn.q_layernorm, L2Norm)
assert isinstance(attn.k_layernorm, L2Norm)

def test_spec_provided_norm_not_overwritten(self):
"""When the spec already provides q/k layernorm, config doesn't override it."""
import copy

from megatron.core.transformer.identity_op import IdentityOp

# Build a spec that explicitly sets q/k layernorm to IdentityOp
spec = copy.deepcopy(mamba_stack_spec)
spec.submodules.attention_layer.submodules.self_attention.submodules.q_layernorm = (
IdentityOp
)
spec.submodules.attention_layer.submodules.self_attention.submodules.k_layernorm = (
IdentityOp
)

config = TransformerConfig(
num_layers=3,
hidden_size=256,
num_attention_heads=4,
use_cpu_initialization=True,
qk_layernorm=True,
)
model = MambaModel(
config=config,
mamba_stack_spec=spec,
vocab_size=100,
max_sequence_length=4,
hybrid_layer_pattern="M*-",
)
attn = self._get_attention_layer(model)
assert attn is not None
assert isinstance(attn.q_layernorm, IdentityOp)
assert isinstance(attn.k_layernorm, IdentityOp)

def test_forward_with_qk_layernorm(self):
"""MambaModel forward pass works with qk_layernorm enabled."""
model = self._build_model(qk_layernorm=True)
model.cuda()

sequence_length = 4
micro_batch_size = 2
data = list(range(sequence_length))
input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
attention_mask = torch.ones(
(micro_batch_size, 1, sequence_length, sequence_length), dtype=bool
).cuda()

logits = model.forward(
input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask
)

assert logits.shape[0] == micro_batch_size
assert logits.shape[1] == sequence_length
assert logits.shape[2] == 100
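
The comment in test_qk_layernorm_from_config above points out that TENorm behaves like a factory, so asserting a concrete type on q_layernorm/k_layernorm would be brittle. A minimal sketch of that pattern (illustrative only, with torch.nn.LayerNorm standing in for the Transformer Engine norms) shows why the test settles for an is-not-None check:

```python
import torch

# Sketch of a __new__-based factory: instances are of another class, so
# isinstance(norm, NormFactory) is False even though construction succeeded.
class NormFactory:
    def __new__(cls, hidden_size: int, eps: float = 1e-5):
        return torch.nn.LayerNorm(hidden_size, eps=eps)

norm = NormFactory(hidden_size=64)
assert isinstance(norm, torch.nn.LayerNorm)
assert not isinstance(norm, NormFactory)
```
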


class TestMambaWithDynamicInference:
"""Tests MambaModel with dynamic inference."""

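As an aside on the hybrid_layer_pattern="M*-" used throughout the new Mamba tests: assuming Megatron's usual hybrid-layer symbols ('M' for a Mamba mixer layer, '*' for self-attention, '-' for MLP), the three-layer model contains exactly one attention layer, which is why _get_attention_layer scans for the single layer exposing self_attention. A tiny hedged decoding helper:

```python
# Hedged reading of the pattern string; the symbol meanings are an assumption
# based on Megatron's hybrid-layer conventions, not taken from this PR.
LAYER_SYMBOLS = {"M": "mamba", "*": "attention", "-": "mlp"}

def describe_pattern(pattern: str) -> list[str]:
    return [LAYER_SYMBOLS[symbol] for symbol in pattern]

assert describe_pattern("M*-") == ["mamba", "attention", "mlp"]
```
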
71 changes: 71 additions & 0 deletions tests/unit_tests/transformer/test_attention.py
@@ -724,3 +724,74 @@ def test_parallel_attention_correctness_num_query_groups_less_than_tp_size(
seed=123,
sequence_length=256,
)


def test_qk_layernorm_from_config_fallback():
"""config.qk_layernorm=True with spec q/k_layernorm=None builds TENorm."""
te_pytorch = pytest.importorskip("transformer_engine.pytorch")
from dataclasses import replace

Utils.initialize_model_parallel(1, 1)
model_parallel_cuda_manual_seed(123)
try:
config = TransformerConfig(
num_layers=1,
hidden_size=128,
num_attention_heads=4,
use_cpu_initialization=True,
qk_layernorm=True,
)
base = get_gpt_layer_with_transformer_engine_submodules().self_attention.submodules
submodules = replace(base, q_layernorm=None, k_layernorm=None)
attn = SelfAttention(config, submodules, layer_number=1)
assert isinstance(attn.q_layernorm, te_pytorch.LayerNorm)
assert isinstance(attn.k_layernorm, te_pytorch.LayerNorm)
finally:
Utils.destroy_model_parallel()


def test_qk_l2_norm_from_config_fallback():
"""config.qk_l2_norm=True with spec q/k_layernorm=None builds L2Norm."""
pytest.importorskip("transformer_engine.pytorch")
from dataclasses import replace

from megatron.core.transformer.torch_norm import L2Norm

Utils.initialize_model_parallel(1, 1)
model_parallel_cuda_manual_seed(123)
try:
config = TransformerConfig(
num_layers=1,
hidden_size=128,
num_attention_heads=4,
use_cpu_initialization=True,
qk_l2_norm=True,
)
base = get_gpt_layer_with_transformer_engine_submodules().self_attention.submodules
submodules = replace(base, q_layernorm=None, k_layernorm=None)
attn = SelfAttention(config, submodules, layer_number=1)
assert isinstance(attn.q_layernorm, L2Norm)
assert isinstance(attn.k_layernorm, L2Norm)
finally:
Utils.destroy_model_parallel()


def test_qk_layernorm_spec_config_mismatch_raises():
"""Spec sets a concrete norm but config disables qk_layernorm/qk_l2_norm -> ValueError."""
pytest.importorskip("transformer_engine")
from dataclasses import replace

from megatron.core.transformer.torch_norm import L2Norm

Utils.initialize_model_parallel(1, 1)
model_parallel_cuda_manual_seed(123)
try:
config = TransformerConfig(
num_layers=1, hidden_size=128, num_attention_heads=4, use_cpu_initialization=True
)
base = get_gpt_layer_with_transformer_engine_submodules().self_attention.submodules
submodules = replace(base, q_layernorm=L2Norm, k_layernorm=L2Norm)
with pytest.raises(ValueError, match="qk_layernorm"):
SelfAttention(config, submodules, layer_number=1)
finally:
Utils.destroy_model_parallel()
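
For intuition about what the qk_l2_norm assertions cover: presumably L2Norm rescales each query/key head vector to unit L2 norm along the head dimension (that is an assumption for illustration; megatron/core/transformer/torch_norm.py is authoritative). A minimal sketch:

```python
import torch

# Conceptual sketch of L2 query/key normalization (assumed behavior, for
# illustration only): scale each head-dim vector to unit L2 norm.
def l2_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x / x.norm(p=2, dim=-1, keepdim=True).clamp_min(eps)

q = torch.randn(2, 8, 4, 64)  # [seq, batch, num_heads, head_dim]
assert torch.allclose(l2_normalize(q).norm(dim=-1), torch.ones(2, 8, 4), atol=1e-4)
```
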
6 changes: 5 additions & 1 deletion tests/unit_tests/transformer/test_spec_customization.py
@@ -34,7 +34,11 @@ def setup_method(self, method):
Utils.initialize_model_parallel(1, 1)
model_parallel_cuda_manual_seed(123)
self.config = TransformerConfig(
num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
num_layers=2,
hidden_size=12,
num_attention_heads=4,
use_cpu_initialization=True,
qk_l2_norm=True,
)

# specify Transformer Layer spec with all identity ops