
Commit 3f8feac

Use Protocols to type-check linear_proj submodules of Attention

1 parent: f8becec

12 files changed: 141 additions & 73 deletions

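For orientation: the commit replaces the loosely typed `linear_proj: Union[ModuleSpec, type] = None` fields on the attention submodule dataclasses with two typing.Protocol classes, LinearProjBuilder (how a linear_proj module is constructed) and LinearProjInterface (how it is called), and threads those types through the layer specs and backend providers. The sketch below is not code from the commit; it is a minimal, self-contained illustration of the structural-typing idea with heavily simplified signatures (the real LinearProjBuilder takes the full set of construction arguments shown in the attention.py diff, and the real LinearProjInterface also declares backward_dw). SimpleProj and build_proj are hypothetical names used only here.

from __future__ import annotations

from typing import Protocol

import torch
from torch import Tensor


class LinearProjInterface(Protocol):
    """Structural type: any object whose forward() returns (output, optional bias)."""

    def forward(self, hidden_states: Tensor, /) -> tuple[Tensor, Tensor | None]: ...


class LinearProjBuilder(Protocol):
    """Structural type for callables (typically classes) that construct a linear_proj."""

    def __call__(self, input_size: int, output_size: int, /, *, bias: bool) -> LinearProjInterface: ...


class SimpleProj(torch.nn.Module):
    """Hypothetical module; it satisfies both Protocols without inheriting from them."""

    def __init__(self, input_size: int, output_size: int, /, *, bias: bool) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(input_size, output_size, bias=False)
        self.extra_bias = torch.nn.Parameter(torch.zeros(output_size)) if bias else None

    def forward(self, hidden_states: Tensor, /) -> tuple[Tensor, Tensor | None]:
        # The bias is returned separately (skip_bias_add style) instead of being added here.
        return self.linear(hidden_states), self.extra_bias


def build_proj(builder: LinearProjBuilder) -> LinearProjInterface:
    # A static checker accepts SimpleProj as the argument because its __init__ and
    # forward signatures match the Protocols; no nominal inheritance is required.
    return builder(16, 16, bias=True)


proj = build_proj(SimpleProj)
output, bias = proj.forward(torch.randn(4, 16))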

examples/multimodal/layer_specs.py (2 additions & 2 deletions)

@@ -114,7 +114,7 @@ def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
     submodules=SelfAttentionSubmodules(
         linear_qkv=not_none(TELayerNormColumnParallelLinear),
         core_attention=not_none(TEDotProductAttention),
-        linear_proj=TERowParallelLinear,
+        linear_proj=not_none(TERowParallelLinear),
         q_layernorm=IdentityOp,
         k_layernorm=IdentityOp,
     ),
@@ -160,7 +160,7 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
     submodules=SelfAttentionSubmodules(
         linear_qkv=not_none(TELayerNormColumnParallelLinear),
         core_attention=not_none(TEDotProductAttention),
-        linear_proj=TERowParallelLinear,
+        linear_proj=not_none(TERowParallelLinear),
     ),
 ),
 self_attn_bda=get_bias_dropout_add,

examples/multimodal/radio/radio_g.py (1 addition & 1 deletion)

@@ -127,7 +127,7 @@ def get_radio_g_layer_spec_te() -> ModuleSpec:
     submodules=SelfAttentionSubmodules(
         linear_qkv=not_none(TELayerNormColumnParallelLinear),
         core_attention=not_none(TEDotProductAttention),
-        linear_proj=TERowParallelLinear,
+        linear_proj=not_none(TERowParallelLinear),
         q_layernorm=IdentityOp,
         k_layernorm=IdentityOp,
     ),

megatron/core/extensions/transformer_engine.py (2 additions & 1 deletion)

@@ -1,4 +1,5 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+from __future__ import annotations

 import dataclasses
 import enum
@@ -859,7 +860,7 @@ def will_execute_quantized(self, is_context_quantized: bool) -> bool:
         self.te_quant_params, self.training, is_context_quantized
     )

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
         """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch
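Two small points about this file: the new `from __future__ import annotations` line makes the `tuple[...]` and `X | None` annotation syntax safe on Python versions that do not evaluate it natively at runtime, and the return annotation added to forward() is what lets TERowParallelLinear satisfy the new LinearProjInterface Protocol structurally (its forward must be declared to return an (output, optional bias) pair). A rough illustration of the compatibility check a type checker performs, using a stand-in class rather than the real TE module and a simplified Protocol (the real one also declares backward_dw):

from __future__ import annotations

from typing import Protocol

import torch
from torch import Tensor


class LinearProjInterface(Protocol):
    def forward(self, hidden_states: Tensor, /) -> tuple[Tensor, Tensor | None]: ...


class FakeRowParallelLinear:
    """Stand-in for TERowParallelLinear, annotated the same way as in this commit."""

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
        return x, None


# Accepted by a static checker: the parameter is positional-only in the Protocol,
# so the differing name ('x' vs. 'hidden_states') does not matter, and the return
# types match. The annotation gives the checker a precise signature to match
# instead of an implicit Any.
proj: LinearProjInterface = FakeRowParallelLinear()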

megatron/core/extensions/transformer_engine_spec_provider.py (6 additions & 1 deletion)

@@ -18,6 +18,7 @@
 from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
 from megatron.core.models.backends import BackendSpecProvider
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+from megatron.core.transformer.attention import LinearProjBuilder
 from megatron.core.transformer.mlp import MLPSubmodules, TEActivationFunctionBuilder
 from megatron.core.transformer.moe.experts import (
     SequentialMLP,
@@ -46,7 +47,11 @@ def column_parallel_linear(self) -> type:
         """Which column parallel linear module TE backend uses"""
         return TEColumnParallelLinear

-    def row_parallel_linear(self) -> type:
+    def row_parallel_linear_proj(self) -> LinearProjBuilder:
+        """Which module the backend uses for the final linear projection in attention"""
+        return TERowParallelLinear
+
+    def row_parallel_linear(self) -> type[TERowParallelLinear]:
         """Which row parallel linear module TE backend uses"""
         return TERowParallelLinear
megatron/core/models/T5/t5_spec.py (3 additions & 3 deletions)

@@ -67,7 +67,7 @@ def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
     submodules=SelfAttentionSubmodules(
         linear_qkv=not_none(TELayerNormColumnParallelLinear),
         core_attention=not_none(TEDotProductAttention),
-        linear_proj=TERowParallelLinear,
+        linear_proj=not_none(TERowParallelLinear),
         q_layernorm=IdentityOp,
         k_layernorm=IdentityOp,
     ),
@@ -97,7 +97,7 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
     submodules=SelfAttentionSubmodules(
         linear_qkv=not_none(TELayerNormColumnParallelLinear),
         core_attention=not_none(TEDotProductAttention),
-        linear_proj=TERowParallelLinear,
+        linear_proj=not_none(TERowParallelLinear),
         q_layernorm=IdentityOp,
         k_layernorm=IdentityOp,
     ),
@@ -111,7 +111,7 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
         linear_q=not_none(TEColumnParallelLinear),
         linear_kv=not_none(TEColumnParallelLinear),
         core_attention=not_none(TEDotProductAttention),
-        linear_proj=TERowParallelLinear,
+        linear_proj=not_none(TERowParallelLinear),
     ),
 ),
 cross_attn_bda=get_bias_dropout_add,

megatron/core/models/backends.py (17 additions & 3 deletions)

@@ -10,6 +10,7 @@
     TERowParallelGroupedLinear,
 )
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+from megatron.core.transformer.attention import LinearProjBuilder
 from megatron.core.transformer.dot_product_attention import DotProductAttention
 from megatron.core.transformer.mlp import MLPSubmodules, TEActivationFunctionBuilder
 from megatron.core.transformer.moe.experts import (
@@ -56,6 +57,11 @@ def column_parallel_linear(self) -> type:
         """Which column parallel linear module the backend uses"""
         ...

+    @abstractmethod
+    def row_parallel_linear_proj(self) -> LinearProjBuilder:
+        """Which module the backend uses for the final linear projection in attention"""
+        ...
+
     @abstractmethod
     def row_parallel_linear(self) -> type:
         """Which row parallel linear module the backend uses"""
@@ -103,7 +109,11 @@ def column_parallel_linear(self) -> type:
         """Which column parallel linear module the backend uses"""
         return ColumnParallelLinear

-    def row_parallel_linear(self) -> type:
+    def row_parallel_linear_proj(self) -> LinearProjBuilder:
+        """Which module the backend uses for the final linear projection in attention"""
+        return RowParallelLinear
+
+    def row_parallel_linear(self) -> type[RowParallelLinear]:
         """Which row parallel linear module the backend uses"""
         return RowParallelLinear

@@ -154,8 +164,12 @@ def column_parallel_linear(self) -> type:
         """Which column parallel linear module TE backend uses"""
         return TEColumnParallelLinear

-    def row_parallel_linear(self) -> type:
-        """Which row parallel linear module TE backend uses"""
+    def row_parallel_linear_proj(self) -> LinearProjBuilder:
+        """Which module the backend uses for the final linear projection in attention"""
+        return InferenceRowParallelLinear
+
+    def row_parallel_linear(self) -> type[InferenceRowParallelLinear]:
+        """Which row parallel linear module Inference backend uses"""
         return InferenceRowParallelLinear

     def fuse_layernorm_and_linear(self) -> bool:
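The net effect in this file: every BackendSpecProvider now exposes row_parallel_linear_proj(), typed as returning a LinearProjBuilder, while row_parallel_linear() keeps returning a concrete class (with its annotation narrowed to type[...]). Code that only needs to construct the attention output projection can depend on the Protocol; code that needs the exact class keeps the precise type. As a hedged sketch, a hypothetical third-party backend could plug into the same contract like this (MyBackend is not part of the commit, and the remaining abstract methods are omitted, so it could not be instantiated as written):

from megatron.core.models.backends import BackendSpecProvider
from megatron.core.tensor_parallel.layers import RowParallelLinear
from megatron.core.transformer.attention import LinearProjBuilder


class MyBackend(BackendSpecProvider):
    """Hypothetical backend; only the two methods relevant here are shown."""

    def row_parallel_linear_proj(self) -> LinearProjBuilder:
        # Any class (or callable) whose construction and forward() signatures match
        # LinearProjBuilder / LinearProjInterface type-checks here, e.g. the local
        # RowParallelLinear used by the default local backend above.
        return RowParallelLinear

    def row_parallel_linear(self) -> type[RowParallelLinear]:
        return RowParallelLinear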

megatron/core/models/gpt/experimental_attention_variant_module_specs.py (1 addition & 1 deletion)

@@ -119,7 +119,7 @@ def get_dsa_module_spec_for_backend(
     linear_kv_down_proj=backend.linear(),
     linear_kv_up_proj=linear_kv_up_proj,
     core_attention=core_attention,
-    linear_proj=backend.row_parallel_linear(),
+    linear_proj=backend.row_parallel_linear_proj(),
     q_layernorm=IdentityOp,
     kv_layernorm=IdentityOp,
 ),

megatron/core/models/gpt/gpt_layer_specs.py (6 additions & 6 deletions)

@@ -122,7 +122,7 @@ def get_gpt_layer_with_inference_submodules(
     linear_kv_down_proj=backend.linear(),
     linear_kv_up_proj=linear_kv_up_proj,
     core_attention=backend.core_attention(),
-    linear_proj=backend.row_parallel_linear(),
+    linear_proj=backend.row_parallel_linear_proj(),
     q_layernorm=IdentityOp,
     kv_layernorm=IdentityOp,
 ),
@@ -141,7 +141,7 @@ def get_gpt_layer_with_inference_submodules(
 submodules=SelfAttentionSubmodules(
     linear_qkv=backend.column_parallel_layer_norm_linear(),
     core_attention=backend.core_attention(),
-    linear_proj=backend.row_parallel_linear(),
+    linear_proj=backend.row_parallel_linear_proj(),
     q_layernorm=(
         L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
     ),
@@ -299,7 +299,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
     linear_kv_down_proj=backend.linear(),
     linear_kv_up_proj=linear_kv_up_proj,
     core_attention=backend.core_attention(),
-    linear_proj=backend.row_parallel_linear(),
+    linear_proj=backend.row_parallel_linear_proj(),
     q_layernorm=IdentityOp,
     kv_layernorm=IdentityOp,
 ),
@@ -318,7 +318,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
 submodules=SelfAttentionSubmodules(
     linear_qkv=backend.column_parallel_layer_norm_linear(),
     core_attention=backend.core_attention(),
-    linear_proj=backend.row_parallel_linear(),
+    linear_proj=backend.row_parallel_linear_proj(),
     q_layernorm=(
         L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
     ),
@@ -419,7 +419,7 @@ def get_gpt_layer_local_submodules(
     linear_kv_down_proj=backend.column_parallel_linear(),
     linear_kv_up_proj=backend.column_parallel_linear(),
     core_attention=backend.core_attention(),
-    linear_proj=backend.row_parallel_linear(),
+    linear_proj=backend.row_parallel_linear_proj(),
     q_layernorm=qk_norm if qk_layernorm else IdentityOp,
     kv_layernorm=qk_norm if qk_layernorm else IdentityOp,
 ),
@@ -438,7 +438,7 @@ def get_gpt_layer_local_submodules(
 submodules=SelfAttentionSubmodules(
     linear_qkv=backend.column_parallel_linear(),
     core_attention=backend.core_attention(),
-    linear_proj=backend.row_parallel_linear(),
+    linear_proj=backend.row_parallel_linear_proj(),
     q_layernorm=(
         L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
     ),

megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py (1 addition & 1 deletion)

@@ -122,7 +122,7 @@ def _get_heterogenous_attention_spec(
         not_none(TELayerNormColumnParallelLinear) if use_te else ColumnParallelLinear
     ),
     core_attention=not_none(TEDotProductAttention) if use_te else DotProductAttention,
-    linear_proj=TERowParallelLinear if use_te else RowParallelLinear,
+    linear_proj=not_none(TERowParallelLinear) if use_te else RowParallelLinear,
     q_layernorm=ln,
     k_layernorm=ln,
 ),

megatron/core/transformer/attention.py (55 additions & 25 deletions)

@@ -33,7 +33,6 @@
 from megatron.core.tensor_parallel.mappings import all_gather_last_dim_from_tensor_parallel_region
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.spec_utils import ModuleSpec, build_module
 from megatron.core.transformer.torch_norm import LayerNormBuilder
 from megatron.core.typed_torch import apply_module, not_none
 from megatron.core.utils import (
@@ -118,8 +117,8 @@
 HAVE_FUSED_QKV_ROPE = False


-class LinearQkv(Protocol):
-    """Protocol for linear_qkv modules."""
+class LinearQkvInterface(Protocol):
+    """Interface for linear_qkv modules."""

     def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
         """Applies linear_qkv."""
@@ -147,13 +146,13 @@ def __call__(
         is_expert: bool,
         tp_comm_buffer_name: str,
         tp_group: torch.distributed.ProcessGroup | None = None,
-    ) -> LinearQkv: ...
+    ) -> LinearQkvInterface: ...


-class LinearLayer(Protocol):
-    """Protocol for linear_q and linear_kv modules."""
+class LinearInterface(Protocol):
+    """Interface for linear_q and linear_kv modules."""

-    def forward(self, input: Tensor, /) -> Tuple[Tensor, object]:
+    def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
         """Applies linear_q/linear_kv."""
         ...

@@ -173,23 +172,23 @@ def __call__(
         bias: bool,
         skip_bias_add: bool,
         is_expert: bool,
-    ) -> LinearLayer: ...
+    ) -> LinearInterface: ...


-class CoreAttention(Protocol):
-    """Protocol for core_attention modules."""
+class CoreAttentionInterface(Protocol):
+    """Interface for core_attention modules."""

     def forward(
         self,
         query: Tensor,
         key: Tensor,
         value: Tensor,
-        attention_mask: Optional[Tensor],
+        attention_mask: Tensor | None,
         /,
         *,
         attn_mask_type: AttnMaskType,
-        attention_bias: Optional[Tensor],
-        packed_seq_params: Optional[PackedSeqParams],
+        attention_bias: Tensor | None = None,
+        packed_seq_params: PackedSeqParams | None,
     ) -> Tensor:
         """Applies dot product attention."""
         ...
@@ -205,10 +204,42 @@ def __call__(
         layer_number: int,
         attn_mask_type: AttnMaskType,
         attention_type: str,
-        cp_comm_type: Optional[str],
-        softmax_scale: Optional[float],
-        pg_collection: Optional[ProcessGroupCollection],
-    ) -> CoreAttention: ...
+        softmax_scale: float | None,
+        cp_comm_type: str | None,
+        pg_collection: ProcessGroupCollection | None,
+    ) -> CoreAttentionInterface: ...
+
+
+class LinearProjInterface(Protocol):
+    """Interface for linear_proj modules."""
+
+    def forward(self, hidden_states: Tensor, /) -> tuple[Tensor, Tensor | None]:
+        """Applies the linear projection to the output of the core attention."""
+        ...
+
+    def backward_dw(self) -> None:
+        """Computes weight gradients of output projection layer."""
+        ...
+
+
+class LinearProjBuilder(Protocol):
+    """Protocol for building linear_proj layers."""
+
+    def __call__(
+        self,
+        query_projection_size: int,
+        hidden_size: int,
+        /,
+        *,
+        config: TransformerConfig,
+        init_method: Callable[[torch.Tensor], None],
+        bias: bool,
+        input_is_parallel: bool,
+        skip_bias_add: bool,
+        is_expert: bool,
+        tp_comm_buffer_name: str,
+        tp_group: torch.distributed.ProcessGroup | None,
+    ) -> LinearProjInterface: ...


 @dataclass
@@ -219,7 +250,7 @@ class SelfAttentionSubmodules:

     linear_qkv: LinearQkvBuilder
     core_attention: CoreAttentionBuilder
-    linear_proj: Union[ModuleSpec, type] = None
+    linear_proj: LinearProjBuilder
     q_layernorm: LayerNormBuilder | None = None
     k_layernorm: LayerNormBuilder | None = None

@@ -233,7 +264,7 @@ class CrossAttentionSubmodules:
     linear_q: LinearLayerBuilder
     linear_kv: LinearLayerBuilder
     core_attention: CoreAttentionBuilder
-    linear_proj: Union[ModuleSpec, type] = None
+    linear_proj: LinearProjBuilder


 class Attention(MegatronModule, ABC):
@@ -349,12 +380,11 @@ def __init__(
         )

         # Output.
-        self.linear_proj = build_module(
-            submodules.linear_proj,
+        self.linear_proj = submodules.linear_proj(
             self.query_projection_size,
             self.config.hidden_size,
             config=self.config,
-            init_method=self.config.output_layer_init_method,
+            init_method=not_none(self.config.output_layer_init_method),
             bias=self.config.add_bias_linear,
             input_is_parallel=True,
             skip_bias_add=True,
@@ -899,7 +929,7 @@ def forward(
         sequence_len_offset: Optional[int] = None,
         *,
         inference_params: Optional[BaseInferenceContext] = None,
-    ) -> tuple[Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor | None]:
         """
         Perform a forward pass through the attention module.

@@ -1049,7 +1079,7 @@ def forward(
             )
             out = output.transpose(0, 1).contiguous()
             context_layer = out.view(out.size(0), out.size(1), -1)
-            output, bias = self.linear_proj(context_layer)
+            output, bias = apply_module(self.linear_proj)(context_layer)
             return output, bias

         if (
@@ -1217,7 +1247,7 @@ def forward(
         # =================
         nvtx_range_push(suffix="linear_proj")
         with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out:
-            output, bias = self.linear_proj(core_attn_out)
+            output, bias = apply_module(self.linear_proj)(core_attn_out)
             if self.offload_attn_proj:
                 output = off_interface.group_commit(
                     output, name="attn_proj", forced_released_tensors=[core_attn_out]
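Taken together, the attention.py changes (a) rename the existing Protocols to *Interface, (b) add LinearProjInterface and LinearProjBuilder, (c) declare linear_proj on both submodule dataclasses as a LinearProjBuilder instead of Union[ModuleSpec, type] = None, so build_module is no longer needed and the builder is called directly in __init__, and (d) route the call sites through apply_module(self.linear_proj)(...) so they are checked against LinearProjInterface.forward. Below is a hedged sketch of what a custom linear_proj module would need to provide to slot into this contract; CustomRowParallelProj is hypothetical, the constructor keyword names come from the LinearProjBuilder signature in the diff above, and the tensor-parallel communication logic of the real modules is deliberately left out.

from __future__ import annotations

from typing import Callable

import torch
from torch import Tensor

from megatron.core.transformer.transformer_config import TransformerConfig


class CustomRowParallelProj(torch.nn.Module):
    """Hypothetical linear_proj; matches LinearProjBuilder/LinearProjInterface structurally."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: TransformerConfig,
        init_method: Callable[[torch.Tensor], None],
        bias: bool,
        input_is_parallel: bool,
        skip_bias_add: bool,
        is_expert: bool,
        tp_comm_buffer_name: str,
        tp_group: torch.distributed.ProcessGroup | None = None,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(output_size, input_size))
        init_method(self.weight)
        self.bias = torch.nn.Parameter(torch.zeros(output_size)) if bias else None
        self.skip_bias_add = skip_bias_add

    def forward(self, hidden_states: Tensor) -> tuple[Tensor, Tensor | None]:
        # Return the bias separately when skip_bias_add is set, as the Protocol expects.
        output = torch.nn.functional.linear(hidden_states, self.weight)
        if self.skip_bias_add:
            return output, self.bias
        return (output + self.bias if self.bias is not None else output), None

    def backward_dw(self) -> None:
        # No deferred weight-gradient computation in this toy module.
        pass

A spec could then reference such a class directly, e.g. SelfAttentionSubmodules(..., linear_proj=CustomRowParallelProj, ...), and both the construction in Attention.__init__ and the apply_module(self.linear_proj)(...) call sites above would type-check against it.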
