Commit 9db13d6

Use Protocols to type-check linear_proj submodules of Attention

1 parent 28ccdaa commit 9db13d6

12 files changed

Lines changed: 141 additions & 73 deletions
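
In short: the submodule dataclasses previously typed linear_proj as `Union[ModuleSpec, type] = None`, which tells a static checker almost nothing. This commit types the field as a Protocol describing the builder callable, and types the built module with a second Protocol describing its forward(), so mypy/pyright can check both the construction call in Attention.__init__ and the later projection call. A minimal, self-contained sketch of that pattern, using toy hypothetical names (ProjLike, ProjBuilderLike, ToyRowParallelLinear) rather than the Megatron classes:

from dataclasses import dataclass
from typing import Protocol


class ProjLike(Protocol):
    """Structural interface: any object with this forward() qualifies."""

    def forward(self, hidden_states: list[float], /) -> tuple[list[float], float | None]: ...


class ProjBuilderLike(Protocol):
    """Structural interface for the callable (typically a class) that builds the module."""

    def __call__(self, input_size: int, output_size: int, *, bias: bool) -> ProjLike: ...


class ToyRowParallelLinear:
    """Concrete class; it satisfies both Protocols without inheriting from them."""

    def __init__(self, input_size: int, output_size: int, *, bias: bool) -> None:
        self.weight = [[0.1] * input_size for _ in range(output_size)]
        self.bias = 0.0 if bias else None

    def forward(self, hidden_states: list[float], /) -> tuple[list[float], float | None]:
        out = [sum(w * x for w, x in zip(row, hidden_states)) for row in self.weight]
        return out, self.bias


@dataclass
class ToySubmodules:
    # Typed as the builder Protocol instead of `Union[ModuleSpec, type]`, so passing a
    # class whose constructor or forward() does not match is flagged by the type checker.
    linear_proj: ProjBuilderLike


spec = ToySubmodules(linear_proj=ToyRowParallelLinear)
proj = spec.linear_proj(4, 2, bias=True)   # construction checked against ProjBuilderLike
print(proj.forward([1.0, 2.0, 3.0, 4.0]))  # call checked against ProjLike

The spec files below then only need to guarantee that the class they put into the field is not None, which is where the not_none(...) wrappers come in.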


examples/multimodal/layer_specs.py

Lines changed: 2 additions & 2 deletions
@@ -114,7 +114,7 @@ def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
         submodules=SelfAttentionSubmodules(
             linear_qkv=not_none(TELayerNormColumnParallelLinear),
             core_attention=not_none(TEDotProductAttention),
-            linear_proj=TERowParallelLinear,
+            linear_proj=not_none(TERowParallelLinear),
             q_layernorm=IdentityOp,
             k_layernorm=IdentityOp,
         ),
@@ -160,7 +160,7 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
         submodules=SelfAttentionSubmodules(
             linear_qkv=not_none(TELayerNormColumnParallelLinear),
             core_attention=not_none(TEDotProductAttention),
-            linear_proj=TERowParallelLinear,
+            linear_proj=not_none(TERowParallelLinear),
         ),
     ),
     self_attn_bda=get_bias_dropout_add,
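
Aside on the not_none(...) wrappers seen here and in the other spec files below: when Transformer Engine is not installed, names like TERowParallelLinear can be None at import time, so their static type is optional, and wrapping them narrows the type to the concrete class that the new LinearProjBuilder-typed field expects. A rough sketch of what such a helper looks like (an assumption for illustration; the real helper lives in megatron.core.typed_torch and may differ):

from typing import Optional, TypeVar

T = TypeVar("T")


def not_none(value: Optional[T]) -> T:
    """Narrow Optional[T] to T, failing loudly if the optional backend class is missing."""
    assert value is not None, "expected a module class, got None (is the backend installed?)"
    return value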

examples/multimodal/radio/radio_g.py

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ def get_radio_g_layer_spec_te() -> ModuleSpec:
         submodules=SelfAttentionSubmodules(
             linear_qkv=not_none(TELayerNormColumnParallelLinear),
             core_attention=not_none(TEDotProductAttention),
-            linear_proj=TERowParallelLinear,
+            linear_proj=not_none(TERowParallelLinear),
             q_layernorm=IdentityOp,
             k_layernorm=IdentityOp,
         ),

megatron/core/extensions/transformer_engine.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+from __future__ import annotations

 import dataclasses
 import enum
@@ -672,7 +673,7 @@ def will_execute_quantized(self, is_context_quantized: bool) -> bool:
             self.te_quant_params, self.training, is_context_quantized
         )

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
         """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch

megatron/core/extensions/transformer_engine_spec_provider.py

Lines changed: 6 additions & 1 deletion
@@ -18,6 +18,7 @@
 from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
 from megatron.core.models.backends import BackendSpecProvider
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+from megatron.core.transformer.attention import LinearProjBuilder
 from megatron.core.transformer.mlp import MLPSubmodules, TEActivationFunctionBuilder
 from megatron.core.transformer.moe.experts import (
     GroupedMLP,
@@ -40,7 +41,11 @@ def column_parallel_linear(self) -> type:
         """Which column parallel linear module TE backend uses"""
         return TEColumnParallelLinear

-    def row_parallel_linear(self) -> type:
+    def linear_proj(self) -> LinearProjBuilder:
+        """Which module the backend uses for the final linear projection in attention"""
+        return TERowParallelLinear
+
+    def row_parallel_linear(self) -> type[TERowParallelLinear]:
         """Which row parallel linear module TE backend uses"""
         return TERowParallelLinear
megatron/core/models/T5/t5_spec.py

Lines changed: 3 additions & 3 deletions
@@ -67,7 +67,7 @@ def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
         submodules=SelfAttentionSubmodules(
             linear_qkv=not_none(TELayerNormColumnParallelLinear),
             core_attention=not_none(TEDotProductAttention),
-            linear_proj=TERowParallelLinear,
+            linear_proj=not_none(TERowParallelLinear),
             q_layernorm=IdentityOp,
             k_layernorm=IdentityOp,
         ),
@@ -97,7 +97,7 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
         submodules=SelfAttentionSubmodules(
             linear_qkv=not_none(TELayerNormColumnParallelLinear),
             core_attention=not_none(TEDotProductAttention),
-            linear_proj=TERowParallelLinear,
+            linear_proj=not_none(TERowParallelLinear),
             q_layernorm=IdentityOp,
             k_layernorm=IdentityOp,
         ),
@@ -111,7 +111,7 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
             linear_q=not_none(TEColumnParallelLinear),
             linear_kv=not_none(TEColumnParallelLinear),
             core_attention=not_none(TEDotProductAttention),
-            linear_proj=TERowParallelLinear,
+            linear_proj=not_none(TERowParallelLinear),
         ),
     ),
     cross_attn_bda=get_bias_dropout_add,

megatron/core/models/backends.py

Lines changed: 17 additions & 3 deletions
@@ -6,6 +6,7 @@
 from typing import Optional, Protocol, cast

 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+from megatron.core.transformer.attention import LinearProjBuilder
 from megatron.core.transformer.dot_product_attention import DotProductAttention
 from megatron.core.transformer.mlp import MLPSubmodules, TEActivationFunctionBuilder
 from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLPSubmodules
@@ -47,6 +48,11 @@ def column_parallel_linear(self) -> type:
         """Which column parallel linear module the backend uses"""
         ...

+    @abstractmethod
+    def linear_proj(self) -> LinearProjBuilder:
+        """Which module the backend uses for the final linear projection in attention"""
+        ...
+
     @abstractmethod
     def row_parallel_linear(self) -> type:
         """Which row parallel linear module the backend uses"""
@@ -92,7 +98,11 @@ def column_parallel_linear(self) -> type:
         """Which column parallel linear module the backend uses"""
         return ColumnParallelLinear

-    def row_parallel_linear(self) -> type:
+    def linear_proj(self) -> LinearProjBuilder:
+        """Which module the backend uses for the final linear projection in attention"""
+        return RowParallelLinear
+
+    def row_parallel_linear(self) -> type[RowParallelLinear]:
         """Which row parallel linear module the backend uses"""
         return RowParallelLinear

@@ -148,8 +158,12 @@ def column_parallel_linear(self) -> type:
         """Which column parallel linear module TE backend uses"""
         return TEColumnParallelLinear

-    def row_parallel_linear(self) -> type:
-        """Which row parallel linear module TE backend uses"""
+    def linear_proj(self) -> LinearProjBuilder:
+        """Which module the backend uses for the final linear projection in attention"""
+        return InferenceRowParallelLinear
+
+    def row_parallel_linear(self) -> type[InferenceRowParallelLinear]:
+        """Which row parallel linear module Inference backend uses"""
         return InferenceRowParallelLinear

     def fuse_layernorm_and_linear(self) -> bool:
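
The backend-provider side of the same change: linear_proj() now declares the builder Protocol as its return type, so the checker verifies end to end that whatever class a backend returns can be constructed and called as the attention output projection. A toy, self-contained sketch of that shape (hypothetical names, not the Megatron API):

from typing import Protocol


class Built(Protocol):
    def forward(self, x: float, /) -> float: ...


class Builder(Protocol):
    def __call__(self, scale: float) -> Built: ...


class Provider(Protocol):
    def linear_proj(self) -> Builder: ...


class Scaler:
    def __init__(self, scale: float) -> None:
        self.scale = scale

    def forward(self, x: float, /) -> float:
        return self.scale * x


class ToyBackend:
    def linear_proj(self) -> Builder:
        return Scaler  # the class object is checked structurally against Builder


def build_proj(provider: Provider) -> Built:
    return provider.linear_proj()(2.0)


print(build_proj(ToyBackend()).forward(3.0))  # 6.0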

megatron/core/models/gpt/experimental_attention_variant_module_specs.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def get_dsa_module_spec_for_backend(
             linear_kv_down_proj=backend.linear(),
             linear_kv_up_proj=linear_kv_up_proj,
             core_attention=core_attention,
-            linear_proj=backend.row_parallel_linear(),
+            linear_proj=backend.linear_proj(),
             q_layernorm=IdentityOp,
             kv_layernorm=IdentityOp,
         ),

megatron/core/models/gpt/gpt_layer_specs.py

Lines changed: 6 additions & 6 deletions
@@ -119,7 +119,7 @@ def get_gpt_layer_with_inference_submodules(
             linear_kv_down_proj=backend.linear(),
             linear_kv_up_proj=linear_kv_up_proj,
             core_attention=backend.core_attention(),
-            linear_proj=backend.row_parallel_linear(),
+            linear_proj=backend.linear_proj(),
             q_layernorm=IdentityOp,
             kv_layernorm=IdentityOp,
         ),
@@ -138,7 +138,7 @@ def get_gpt_layer_with_inference_submodules(
         submodules=SelfAttentionSubmodules(
             linear_qkv=backend.column_parallel_layer_norm_linear(),
             core_attention=backend.core_attention(),
-            linear_proj=backend.row_parallel_linear(),
+            linear_proj=backend.linear_proj(),
             q_layernorm=(
                 L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
             ),
@@ -257,7 +257,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
             linear_kv_down_proj=backend.linear(),
             linear_kv_up_proj=linear_kv_up_proj,
             core_attention=backend.core_attention(),
-            linear_proj=backend.row_parallel_linear(),
+            linear_proj=backend.linear_proj(),
             q_layernorm=IdentityOp,
             kv_layernorm=IdentityOp,
         ),
@@ -276,7 +276,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
         submodules=SelfAttentionSubmodules(
             linear_qkv=backend.column_parallel_layer_norm_linear(),
             core_attention=backend.core_attention(),
-            linear_proj=backend.row_parallel_linear(),
+            linear_proj=backend.linear_proj(),
             q_layernorm=(
                 L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
             ),
@@ -383,7 +383,7 @@ def get_gpt_layer_local_submodules(
             linear_kv_down_proj=backend.column_parallel_linear(),
             linear_kv_up_proj=backend.column_parallel_linear(),
             core_attention=backend.core_attention(),
-            linear_proj=backend.row_parallel_linear(),
+            linear_proj=backend.linear_proj(),
             q_layernorm=qk_norm if qk_layernorm else IdentityOp,
             kv_layernorm=qk_norm if qk_layernorm else IdentityOp,
         ),
@@ -402,7 +402,7 @@ def get_gpt_layer_local_submodules(
         submodules=SelfAttentionSubmodules(
             linear_qkv=backend.column_parallel_linear(),
             core_attention=backend.core_attention(),
-            linear_proj=backend.row_parallel_linear(),
+            linear_proj=backend.linear_proj(),
             q_layernorm=(
                 L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
             ),

megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def _get_heterogenous_attention_spec(
                 not_none(TELayerNormColumnParallelLinear) if use_te else ColumnParallelLinear
             ),
             core_attention=not_none(TEDotProductAttention) if use_te else DotProductAttention,
-            linear_proj=TERowParallelLinear if use_te else RowParallelLinear,
+            linear_proj=not_none(TERowParallelLinear) if use_te else RowParallelLinear,
             q_layernorm=ln,
             k_layernorm=ln,
         ),

megatron/core/transformer/attention.py

Lines changed: 55 additions & 25 deletions
@@ -33,7 +33,6 @@
 from megatron.core.tensor_parallel.mappings import all_gather_last_dim_from_tensor_parallel_region
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.spec_utils import ModuleSpec, build_module
 from megatron.core.transformer.torch_norm import LayerNormBuilder
 from megatron.core.typed_torch import apply_module, not_none
 from megatron.core.utils import (
@@ -118,8 +117,8 @@
 HAVE_FUSED_QKV_ROPE = False


-class LinearQkv(Protocol):
-    """Protocol for linear_qkv modules."""
+class LinearQkvInterface(Protocol):
+    """Interface for linear_qkv modules."""

     def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
         """Applies linear_qkv."""
@@ -147,13 +146,13 @@ def __call__(
         is_expert: bool,
         tp_comm_buffer_name: str,
         tp_group: torch.distributed.ProcessGroup | None = None,
-    ) -> LinearQkv: ...
+    ) -> LinearQkvInterface: ...


-class LinearLayer(Protocol):
-    """Protocol for linear_q and linear_kv modules."""
+class LinearInterface(Protocol):
+    """Interface for linear_q and linear_kv modules."""

-    def forward(self, input: Tensor, /) -> Tuple[Tensor, object]:
+    def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
         """Applies linear_q/linear_kv."""
         ...

@@ -173,23 +172,23 @@ def __call__(
         bias: bool,
         skip_bias_add: bool,
         is_expert: bool,
-    ) -> LinearLayer: ...
+    ) -> LinearInterface: ...


-class CoreAttention(Protocol):
-    """Protocol for core_attention modules."""
+class CoreAttentionInterface(Protocol):
+    """Interface for core_attention modules."""

     def forward(
         self,
         query: Tensor,
         key: Tensor,
         value: Tensor,
-        attention_mask: Optional[Tensor],
+        attention_mask: Tensor | None,
         /,
         *,
         attn_mask_type: AttnMaskType,
-        attention_bias: Optional[Tensor],
-        packed_seq_params: Optional[PackedSeqParams],
+        attention_bias: Tensor | None = None,
+        packed_seq_params: PackedSeqParams | None,
     ) -> Tensor:
         """Applies dot product attention."""
         ...
@@ -205,10 +204,42 @@ def __call__(
         layer_number: int,
         attn_mask_type: AttnMaskType,
         attention_type: str,
-        cp_comm_type: Optional[str],
-        softmax_scale: Optional[float],
-        pg_collection: Optional[ProcessGroupCollection],
-    ) -> CoreAttention: ...
+        softmax_scale: float | None,
+        cp_comm_type: str | None,
+        pg_collection: ProcessGroupCollection | None,
+    ) -> CoreAttentionInterface: ...
+
+
+class LinearProjInterface(Protocol):
+    """Interface for linear_proj modules."""
+
+    def forward(self, hidden_states: Tensor, /) -> tuple[Tensor, Tensor | None]:
+        """Applies the linear projection to the output of the core attention."""
+        ...
+
+    def backward_dw(self) -> None:
+        """Computes weight gradients of output projection layer."""
+        ...
+
+
+class LinearProjBuilder(Protocol):
+    """Protocol for building linear_proj layers."""
+
+    def __call__(
+        self,
+        query_projection_size: int,
+        hidden_size: int,
+        /,
+        *,
+        config: TransformerConfig,
+        init_method: Callable[[torch.Tensor], None],
+        bias: bool,
+        input_is_parallel: bool,
+        skip_bias_add: bool,
+        is_expert: bool,
+        tp_comm_buffer_name: str,
+        tp_group: torch.distributed.ProcessGroup | None,
+    ) -> LinearProjInterface: ...


 @dataclass
@@ -219,7 +250,7 @@ class SelfAttentionSubmodules:

     linear_qkv: LinearQkvBuilder
     core_attention: CoreAttentionBuilder
-    linear_proj: Union[ModuleSpec, type] = None
+    linear_proj: LinearProjBuilder
     q_layernorm: LayerNormBuilder | None = None
     k_layernorm: LayerNormBuilder | None = None

@@ -233,7 +264,7 @@ class CrossAttentionSubmodules:
     linear_q: LinearLayerBuilder
     linear_kv: LinearLayerBuilder
     core_attention: CoreAttentionBuilder
-    linear_proj: Union[ModuleSpec, type] = None
+    linear_proj: LinearProjBuilder


 class Attention(MegatronModule, ABC):
@@ -347,12 +378,11 @@ def __init__(
         )

         # Output.
-        self.linear_proj = build_module(
-            submodules.linear_proj,
+        self.linear_proj = submodules.linear_proj(
             self.query_projection_size,
             self.config.hidden_size,
             config=self.config,
-            init_method=self.config.output_layer_init_method,
+            init_method=not_none(self.config.output_layer_init_method),
             bias=self.config.add_bias_linear,
             input_is_parallel=True,
             skip_bias_add=True,
@@ -888,7 +918,7 @@ def forward(
         sequence_len_offset: Optional[int] = None,
         *,
         inference_params: Optional[BaseInferenceContext] = None,
-    ) -> tuple[Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor | None]:
         """
         Perform a forward pass through the attention module.

@@ -1038,7 +1068,7 @@ def forward(
             )
             out = output.transpose(0, 1).contiguous()
             context_layer = out.view(out.size(0), out.size(1), -1)
-            output, bias = self.linear_proj(context_layer)
+            output, bias = apply_module(self.linear_proj)(context_layer)
             return output, bias

         if (
@@ -1206,7 +1236,7 @@ def forward(
         # =================
         nvtx_range_push(suffix="linear_proj")
         with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out:
-            output, bias = self.linear_proj(core_attn_out)
+            output, bias = apply_module(self.linear_proj)(core_attn_out)
             if self.offload_attn_proj:
                 output = off_interface.group_commit(
                     output, name="attn_proj", forced_released_tensors=[core_attn_out]
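
A note on the apply_module(self.linear_proj)(...) call sites above: the Protocol types the module's forward(), but nn.Module instances are normally invoked through __call__ (so hooks run), and nn.Module.__call__ is typed as returning Any. apply_module, imported from megatron.core.typed_torch, presumably bridges that gap; the sketch below is only a guess at such a helper (hypothetical name call_via_module, not the actual implementation):

from typing import Callable, Protocol

from torch import Tensor


class LinearProjLike(Protocol):
    """Stand-in for the LinearProjInterface Protocol added in this commit."""

    def forward(self, hidden_states: Tensor, /) -> tuple[Tensor, Tensor | None]: ...


def call_via_module(module: LinearProjLike) -> Callable[[Tensor], tuple[Tensor, Tensor | None]]:
    """Dispatch through nn.Module.__call__ (so hooks still run) while keeping the
    precise forward() return type visible to the type checker."""

    def _call(hidden_states: Tensor) -> tuple[Tensor, Tensor | None]:
        # At runtime the Protocol-typed object is an nn.Module; calling it runs forward().
        return module(hidden_states)  # type: ignore[operator]

    return _call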