27 changes: 17 additions & 10 deletions megatron/core/extensions/transformer_engine_spec_provider.py
@@ -3,7 +3,9 @@

import warnings
from functools import partial
from typing import Optional, cast
from typing import cast

from typing_extensions import final, override

from megatron.core.extensions.transformer_engine import (
TEActivationOp,
@@ -33,29 +35,31 @@ def __new__(cls, *args, **kwargs):
return TENorm(*args, has_residual=True, **kwargs)


@final
class TESpecProvider(BackendSpecProvider):
"""A protocol for providing the submodules used in Spec building."""

def linear(self) -> type:
@override
def linear(self) -> type[TELinear]:
"""Which linear module TE backend uses"""
return TELinear

def column_parallel_linear(self) -> type:
@override
def column_parallel_linear(self) -> type[TEColumnParallelLinear]:
"""Which column parallel linear module TE backend uses"""
return TEColumnParallelLinear

def row_parallel_linear(self) -> type:
@override
def row_parallel_linear(self) -> type[TERowParallelLinear]:
"""Which row parallel linear module TE backend uses"""
return TERowParallelLinear

def fuse_layernorm_and_linear(self) -> bool:
"""TE backend chooses a single module for layernorm and linear"""
return True

def column_parallel_layer_norm_linear(self) -> Optional[type]:
@override
def column_parallel_layer_norm_linear(self) -> type[TELayerNormColumnParallelLinear]:
"""Which module for sequential layernorm and linear"""
return TELayerNormColumnParallelLinear

@override
def layer_norm(
self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False
) -> LayerNormBuilder:
@@ -68,10 +72,12 @@ def layer_norm(
# Keep returning a class so this path stays aligned with build_module's class handling.
return _TENormWithResidual if has_residual else TENorm

def core_attention(self) -> type:
@override
def core_attention(self) -> type[TEDotProductAttention]:
"""Which module to use for attention"""
return TEDotProductAttention

@override
def grouped_mlp_modules(self, moe_use_grouped_gemm: bool) -> ExpertsBuilder:
"""Which module and submodules to use for grouped mlp"""
if moe_use_grouped_gemm and TEColumnParallelGroupedLinear is not None:
@@ -107,6 +113,7 @@ def grouped_mlp_modules(self, moe_use_grouped_gemm: bool) -> ExpertsBuilder:
),
)

@override
def activation_func(self) -> TEActivationFunctionBuilder | None:
"""Which module to use for activation function"""
# transformer_engine.BasicOperation.forward has an overly permissive return type, but by
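The typing changes in this file follow one pattern: mark the concrete provider `@final`, mark every implementation with `@override`, and narrow each return annotation from bare `type` to `type[ConcreteClass]`. A minimal, self-contained sketch of that pattern with illustrative stand-in classes (not the real Megatron/TE modules):

```python
from typing_extensions import final, override


class Linear:                # stand-in for a generic linear module
    ...


class FusedLinear(Linear):   # stand-in for the backend-specific module
    ...


class Provider:
    def linear(self) -> type[Linear]:
        """Which linear module the backend uses."""
        return Linear


@final                       # type checkers flag any further subclassing
class FusedProvider(Provider):
    @override                # checker errors if the base method is renamed or removed
    def linear(self) -> type[FusedLinear]:   # covariant narrowing of the return type
        return FusedLinear


provider = FusedProvider()
linear_cls = provider.linear()   # inferred as type[FusedLinear], not bare type
layer = linear_cls()             # callers get a concrete class they can instantiate
```

At runtime both decorators are effectively no-ops; the benefit is entirely in static checking.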
68 changes: 45 additions & 23 deletions megatron/core/models/backends.py
@@ -6,10 +6,18 @@
from functools import partial
from typing import Optional, Protocol, cast

from typing_extensions import final, override

from megatron.core.extensions.transformer_engine import (
TEColumnParallelGroupedLinear,
TERowParallelGroupedLinear,
)
from megatron.core.models.protocols import (
ColumnParallelLinearBuilder,
CoreAttentionBuilder,
LinearBuilder,
RowParallelLinearBuilder,
)
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.mlp import MLPSubmodules, TEActivationFunctionBuilder
@@ -54,22 +62,22 @@ class BackendSpecProvider(Protocol):
"""A protocol for providing the submodules used in Spec building."""

@abstractmethod
def column_parallel_linear(self) -> type:
"""Which column parallel linear module the backend uses"""
def linear(self) -> LinearBuilder:
"""Which linear module the backend uses"""
...

@abstractmethod
def row_parallel_linear(self) -> type:
"""Which row parallel linear module the backend uses"""
def column_parallel_linear(self) -> ColumnParallelLinearBuilder:
"""Which column parallel linear module the backend uses"""
...

@abstractmethod
def fuse_layernorm_and_linear(self) -> bool:
"""Does the backend support a single module for layernorm and linear"""
def row_parallel_linear(self) -> RowParallelLinearBuilder:
"""Which row parallel linear module the backend uses"""
...

@abstractmethod
def column_parallel_layer_norm_linear(self) -> Optional[type]:
def column_parallel_layer_norm_linear(self) -> Optional[ColumnParallelLinearBuilder]:
"""Which module for sequential layernorm and linear"""
...

@@ -81,7 +89,7 @@ def layer_norm(
...

@abstractmethod
def core_attention(self) -> type:
def core_attention(self) -> CoreAttentionBuilder:
"""Which module to use for attention"""
...

@@ -96,25 +104,31 @@ def activation_func(self) -> TEActivationFunctionBuilder | None:
...


@final
class LocalSpecProvider(BackendSpecProvider):
"""A protocol for providing Local submodules used in Spec building."""

def column_parallel_linear(self) -> type:
@override
def linear(self) -> LinearBuilder:
"""Which linear module the backend uses"""
raise NotImplementedError("LocalSpecProvider does not have a linear module")

@override
def column_parallel_linear(self) -> type[ColumnParallelLinear]:
"""Which column parallel linear module the backend uses"""
return ColumnParallelLinear

def row_parallel_linear(self) -> type:
@override
def row_parallel_linear(self) -> type[RowParallelLinear]:
"""Which row parallel linear module the backend uses"""
return RowParallelLinear

def fuse_layernorm_and_linear(self) -> bool:
"""Does the backend choose a single module for layernorm and linear"""
return False

def column_parallel_layer_norm_linear(self) -> Optional[type]:
@override
def column_parallel_layer_norm_linear(self) -> None:
"""Which module for sequential layernorm and linear"""
return None

@override
def layer_norm(
self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False
) -> LayerNormBuilder:
@@ -126,10 +140,12 @@ def layer_norm(
LNImpl = WrappedTorchNorm
return LNImpl

def core_attention(self) -> type:
@override
def core_attention(self) -> type[DotProductAttention]:
"""Which module to use for attention"""
return DotProductAttention

@override
def grouped_mlp_modules(self, moe_use_grouped_gemm: bool) -> ExpertsBuilder:
"""Which module and submodules to use for grouped mlp"""
return partial(
@@ -141,34 +157,37 @@ def grouped_mlp_modules(self, moe_use_grouped_gemm: bool) -> ExpertsBuilder:
),
)

@override
def activation_func(self) -> TEActivationFunctionBuilder | None:
"""Which module to use for activation function"""
return None


@final
class InferenceSpecProvider(BackendSpecProvider):
"""A protocol for providing the submodules used in Spec building."""

def linear(self) -> type:
@override
def linear(self) -> type[TELinear]:
"""Which linear module TE backend uses"""
return TELinear

def column_parallel_linear(self) -> type:
@override
def column_parallel_linear(self) -> type[InferenceColumnParallelLinear]:
"""Which column parallel linear module TE backend uses"""
return InferenceColumnParallelLinear

def row_parallel_linear(self) -> type:
@override
def row_parallel_linear(self) -> type[InferenceRowParallelLinear]:
"""Which row parallel linear module TE backend uses"""
return InferenceRowParallelLinear

def fuse_layernorm_and_linear(self) -> bool:
"""TE backend chooses a single module for layernorm and linear"""
return True

@override
def column_parallel_layer_norm_linear(self) -> type[InferenceLayerNormColumnParallelLinear]:
"""Which module for sequential layernorm and linear"""
return InferenceLayerNormColumnParallelLinear

@override
def layer_norm(
self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False
) -> LayerNormBuilder:
@@ -180,16 +199,19 @@ def layer_norm(
return not_none(FusedLayerNorm)
return TENorm

@override
def core_attention(self) -> type[TEDotProductAttention]:
"""Which module to use for attention"""
return TEDotProductAttention

@override
def activation_func(self) -> TEActivationFunctionBuilder | None:
"""Which module to use for activation function"""
# transformer_engine.BasicOperation.forward has an overly permissive return type, but by
# design these classes always meet the interface.
return cast(TEActivationFunctionBuilder, TEActivationOp)

@override
def grouped_mlp_modules(self, moe_use_grouped_gemm: bool) -> ExpertsBuilder:
"""Which module and submodules to use for grouped mlp"""
return partial(
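For context, a compressed sketch of the shape backends.py takes after this change, using assumed, simplified names rather than the real builder aliases from megatron.core.models.protocols: the interface is a Protocol with @abstractmethod members, concrete providers are @final and use @override, and a capability a backend lacks is signalled either by raising (linear on the local provider) or by returning None (the fused layernorm+linear):

```python
from abc import abstractmethod
from typing import Optional, Protocol

from typing_extensions import final, override


class PlainLinear: ...


class FusedNormLinear: ...


class SpecProvider(Protocol):
    @abstractmethod
    def linear(self) -> type:
        """Which linear module the backend uses."""
        ...

    @abstractmethod
    def column_parallel_layer_norm_linear(self) -> Optional[type]:
        """Fused layernorm+linear module, or None when the backend has none."""
        ...


@final
class FusedBackend(SpecProvider):
    @override
    def linear(self) -> type[PlainLinear]:
        return PlainLinear

    @override
    def column_parallel_layer_norm_linear(self) -> type[FusedNormLinear]:
        return FusedNormLinear


@final
class LocalBackend(SpecProvider):
    @override
    def linear(self) -> type:
        raise NotImplementedError("this backend has no standalone linear module")

    @override
    def column_parallel_layer_norm_linear(self) -> None:
        return None
```

Dropping the fuse_layernorm_and_linear() boolean in favour of the Optional return leaves one source of truth: callers check whether the fused class exists instead of asking a second question that could drift out of sync with it.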
@@ -406,7 +406,9 @@ def _get_self_attention_module_spec(
if config.multi_latent_attention:
attn_spec.metainfo["fuse_input_layernorm"] = False
else:
attn_spec.metainfo["fuse_input_layernorm"] = backend.fuse_layernorm_and_linear()
attn_spec.metainfo["fuse_input_layernorm"] = (
backend.column_parallel_layer_norm_linear() is None
)

return attn_spec

@@ -425,7 +427,9 @@ def _get_dense_mlp_module_spec(
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec_for_backend

mlp_spec = get_mlp_module_spec_for_backend(backend=backend, num_experts=None)
mlp_spec.metainfo["fuse_pre_mlp_layernorm"] = backend.fuse_layernorm_and_linear()
mlp_spec.metainfo["fuse_pre_mlp_layernorm"] = (
backend.column_parallel_layer_norm_linear() is None
)

return mlp_spec

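The two hunks above are the caller side of that removal: the metainfo flag is now derived from the presence of the fused module rather than from fuse_layernorm_and_linear(). A rough sketch of the same derivation, with a hypothetical Spec stand-in for the real spec class (only the metainfo dict and the `is None` check come from the diff):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Spec:
    module: type
    metainfo: dict = field(default_factory=dict)


class Attention: ...


class FusedNormLinear: ...


class Backend:
    def column_parallel_layer_norm_linear(self) -> Optional[type]:
        return FusedNormLinear  # or None for a backend without a fused module


def get_self_attention_module_spec(backend: Backend, multi_latent_attention: bool) -> Spec:
    spec = Spec(module=Attention)
    if multi_latent_attention:
        spec.metainfo["fuse_input_layernorm"] = False
    else:
        # The flag follows the fused module's availability instead of a separate boolean query.
        spec.metainfo["fuse_input_layernorm"] = (
            backend.column_parallel_layer_norm_linear() is None
        )
    return spec


print(get_self_attention_module_spec(Backend(), multi_latent_attention=False).metainfo)
```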
10 changes: 3 additions & 7 deletions megatron/core/models/gpt/gpt_layer_specs.py
@@ -38,7 +38,7 @@
TransformerLayerSubmodules,
get_transformer_layer_offset,
)
from megatron.core.typed_torch import copy_signature
from megatron.core.typed_torch import copy_signature, not_none
from megatron.core.utils import is_te_min_version

if HAVE_TE:
@@ -313,7 +313,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=backend.column_parallel_layer_norm_linear(),
linear_qkv=not_none(backend.column_parallel_layer_norm_linear()),
core_attention=backend.core_attention(),
linear_proj=backend.row_parallel_linear(),
q_layernorm=(
Expand Down Expand Up @@ -525,11 +525,7 @@ def get_mlp_module_spec_for_backend(
if num_experts is None:
# Dense MLP w/ or w/o TE modules.
module = TEFusedMLP if use_te_op_fuser else MLP
if backend.fuse_layernorm_and_linear():
linear_fc1 = backend.column_parallel_layer_norm_linear()
assert linear_fc1 is not None
else:
linear_fc1 = backend.column_parallel_linear()
linear_fc1 = backend.column_parallel_layer_norm_linear() or backend.column_parallel_linear()
return ModuleSpec(
module=module,
submodules=MLPSubmodules(
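The gpt_layer_specs.py hunks lean on two small idioms: not_none() where the fused class is mandatory, and an `or` fallback where it is optional. A sketch of both, where the not_none body is an assumption about megatron.core.typed_torch.not_none (assert-and-narrow) rather than its actual implementation:

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def not_none(value: Optional[T]) -> T:
    # Assumed behaviour: fail loudly on None, otherwise narrow Optional[T] to T.
    assert value is not None
    return value


class FusedNormLinear: ...


class ColumnParallelLinear: ...


class Backend:
    def column_parallel_layer_norm_linear(self) -> Optional[type]:
        return None  # pretend this backend provides no fused module

    def column_parallel_linear(self) -> type:
        return ColumnParallelLinear


backend = Backend()

# Attention path: the fused class is required, so a missing one should fail fast.
# linear_qkv = not_none(backend.column_parallel_layer_norm_linear())  # would raise here

# Dense-MLP path: prefer the fused class, fall back to the plain column-parallel
# linear when the backend returns None; this replaces the old
# fuse_layernorm_and_linear() branch and its assert.
linear_fc1 = backend.column_parallel_layer_norm_linear() or backend.column_parallel_linear()
print(linear_fc1)  # <class '__main__.ColumnParallelLinear'>
```

The `or` fallback is safe here because the protocol returns either a class or None, so there is no falsy-but-meaningful value to swallow.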