 from megatron.core.tensor_parallel.mappings import all_gather_last_dim_from_tensor_parallel_region
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.spec_utils import ModuleSpec, build_module
 from megatron.core.transformer.torch_norm import LayerNormBuilder
 from megatron.core.typed_torch import apply_module, not_none
 from megatron.core.utils import (
     HAVE_FUSED_QKV_ROPE = False


-class LinearQkv(Protocol):
-    """Protocol for linear_qkv modules."""
+class LinearQkvInterface(Protocol):
+    """Interface for linear_qkv modules."""

     def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
         """Applies linear_qkv."""
@@ -147,13 +146,13 @@ def __call__(
         is_expert: bool,
         tp_comm_buffer_name: str,
         tp_group: torch.distributed.ProcessGroup | None = None,
-    ) -> LinearQkv: ...
+    ) -> LinearQkvInterface: ...


-class LinearLayer(Protocol):
-    """Protocol for linear_q and linear_kv modules."""
+class LinearInterface(Protocol):
+    """Interface for linear_q and linear_kv modules."""

-    def forward(self, input: Tensor, /) -> Tuple[Tensor, object]:
+    def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
         """Applies linear_q/linear_kv."""
         ...

@@ -173,23 +172,23 @@ def __call__(
         bias: bool,
         skip_bias_add: bool,
         is_expert: bool,
-    ) -> LinearLayer: ...
+    ) -> LinearInterface: ...


-class CoreAttention(Protocol):
-    """Protocol for core_attention modules."""
+class CoreAttentionInterface(Protocol):
+    """Interface for core_attention modules."""

     def forward(
         self,
         query: Tensor,
         key: Tensor,
         value: Tensor,
-        attention_mask: Optional[Tensor],
+        attention_mask: Tensor | None,
         /,
         *,
         attn_mask_type: AttnMaskType,
-        attention_bias: Optional[Tensor],
-        packed_seq_params: Optional[PackedSeqParams],
+        attention_bias: Tensor | None = None,
+        packed_seq_params: PackedSeqParams | None,
     ) -> Tensor:
         """Applies dot product attention."""
         ...
@@ -205,10 +204,42 @@ def __call__(
         layer_number: int,
         attn_mask_type: AttnMaskType,
         attention_type: str,
-        cp_comm_type: Optional[str],
-        softmax_scale: Optional[float],
-        pg_collection: Optional[ProcessGroupCollection],
-    ) -> CoreAttention: ...
+        softmax_scale: float | None,
+        cp_comm_type: str | None,
+        pg_collection: ProcessGroupCollection | None,
+    ) -> CoreAttentionInterface: ...
+
+
+class LinearProjInterface(Protocol):
+    """Interface for linear_proj modules."""
+
+    def forward(self, hidden_states: Tensor, /) -> tuple[Tensor, Tensor | None]:
+        """Applies the linear projection to the output of the core attention."""
+        ...
+
+    def backward_dw(self) -> None:
+        """Computes weight gradients of output projection layer."""
+        ...
+
+
+class LinearProjBuilder(Protocol):
+    """Protocol for building linear_proj layers."""
+
+    def __call__(
+        self,
+        query_projection_size: int,
+        hidden_size: int,
+        /,
+        *,
+        config: TransformerConfig,
+        init_method: Callable[[torch.Tensor], None],
+        bias: bool,
+        input_is_parallel: bool,
+        skip_bias_add: bool,
+        is_expert: bool,
+        tp_comm_buffer_name: str,
+        tp_group: torch.distributed.ProcessGroup | None,
+    ) -> LinearProjInterface: ...


 @dataclass
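
The builder types introduced above rely on structural typing: with typing.Protocol, any class or plain function whose signature matches the protocol satisfies it, so submodules can be ordinary callables instead of ModuleSpec objects routed through build_module. The self-contained sketch below illustrates the idea with deliberately simplified, hypothetical names (ProjInterface, ProjBuilder, ToyProj); it is not Megatron-LM code and trims the real builder signature down to three parameters:

from typing import Protocol


class ProjInterface(Protocol):
    """Structural interface: anything with a matching forward/backward_dw qualifies."""

    def forward(self, hidden_states: float, /) -> tuple[float, float | None]: ...

    def backward_dw(self) -> None: ...


class ProjBuilder(Protocol):
    """Callback protocol: a plain function with this signature is a valid builder."""

    def __call__(self, in_size: int, out_size: int, /, *, bias: bool) -> ProjInterface: ...


class ToyProj:
    """Satisfies ProjInterface purely structurally; no inheritance or registration needed."""

    def __init__(self, in_size: int, out_size: int, bias: bool) -> None:
        self.shape = (in_size, out_size)
        self.bias = bias

    def forward(self, hidden_states: float, /) -> tuple[float, float | None]:
        return hidden_states, None  # placeholder math; only the call contract matters here

    def backward_dw(self) -> None:
        pass  # a real module would compute weight gradients here


def toy_proj_builder(in_size: int, out_size: int, /, *, bias: bool) -> ToyProj:
    return ToyProj(in_size, out_size, bias)


builder: ProjBuilder = toy_proj_builder  # accepted by mypy/pyright: the signatures match
proj = builder(1024, 512, bias=True)
print(proj.forward(3.0))  # -> (3.0, None)

Because the check is purely structural, existing layer classes can serve as builders as long as their constructor signatures line up with the protocol; nothing has to subclass or register against these types.
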
@@ -219,7 +250,7 @@ class SelfAttentionSubmodules:

     linear_qkv: LinearQkvBuilder
     core_attention: CoreAttentionBuilder
-    linear_proj: Union[ModuleSpec, type] = None
+    linear_proj: LinearProjBuilder
     q_layernorm: LayerNormBuilder | None = None
     k_layernorm: LayerNormBuilder | None = None

@@ -233,7 +264,7 @@ class CrossAttentionSubmodules:
     linear_q: LinearLayerBuilder
     linear_kv: LinearLayerBuilder
     core_attention: CoreAttentionBuilder
-    linear_proj: Union[ModuleSpec, type] = None
+    linear_proj: LinearProjBuilder


 class Attention(MegatronModule, ABC):
@@ -349,12 +380,11 @@ def __init__(
         )

         # Output.
-        self.linear_proj = build_module(
-            submodules.linear_proj,
+        self.linear_proj = submodules.linear_proj(
             self.query_projection_size,
             self.config.hidden_size,
             config=self.config,
-            init_method=self.config.output_layer_init_method,
+            init_method=not_none(self.config.output_layer_init_method),
             bias=self.config.add_bias_linear,
             input_is_parallel=True,
             skip_bias_add=True,
@@ -899,7 +929,7 @@ def forward(
         sequence_len_offset: Optional[int] = None,
         *,
         inference_params: Optional[BaseInferenceContext] = None,
-    ) -> tuple[Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor | None]:
         """
         Perform a forward pass through the attention module.

@@ -1049,7 +1079,7 @@ def forward(
             )
             out = output.transpose(0, 1).contiguous()
             context_layer = out.view(out.size(0), out.size(1), -1)
-            output, bias = self.linear_proj(context_layer)
+            output, bias = apply_module(self.linear_proj)(context_layer)
             return output, bias

         if (
@@ -1217,7 +1247,7 @@ def forward(
         # =================
         nvtx_range_push(suffix="linear_proj")
         with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out:
-            output, bias = self.linear_proj(core_attn_out)
+            output, bias = apply_module(self.linear_proj)(core_attn_out)
             if self.offload_attn_proj:
                 output = off_interface.group_commit(
                     output, name="attn_proj", forced_released_tensors=[core_attn_out]
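
The hunks above also depend on two helpers imported from megatron.core.typed_torch, not_none and apply_module, whose implementations are not part of this diff. Judging only from their call sites (init_method=not_none(self.config.output_layer_init_method) and apply_module(self.linear_proj)(core_attn_out)), minimal stand-ins might look like the sketch below; the real helpers are presumably typed more precisely (e.g. preserving the module's exact call signature) and may differ:

# Hypothetical stand-ins, inferred only from how the diff calls these helpers;
# NOT the actual implementations in megatron.core.typed_torch.
from __future__ import annotations

from typing import Callable, TypeVar

import torch

T = TypeVar("T")


def not_none(value: T | None) -> T:
    # Narrows an Optional value for the type checker and fails fast at runtime.
    if value is None:
        raise ValueError("expected a non-None value")
    return value


def apply_module(module: torch.nn.Module) -> Callable[..., object]:
    # Returns the module's __call__ (so forward hooks still run) behind a single,
    # explicitly typed call site such as apply_module(self.linear_proj)(x).
    return module.__call__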