@@ -985,6 +985,14 @@ def yarn_get_mscale(scale=1, mscale=1):
             is_neox=pos_embd_params.is_neox,
         )
 
+        self.llama_4_scaling = False
+        if hasattr(config.pretrained_config, 'llama_4_scaling'):
+            self.llama_4_scaling = True
+            self.floor_scale = getattr(config.pretrained_config.llama_4_scaling,
+                                       'original_max_position_embeddings', 8192)
+            self.attn_scale = getattr(config.pretrained_config.llama_4_scaling,
+                                      'beta', 0.1)
+
         if not config.skip_create_weights_in_init:
             self.create_weights()
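The constructor only enables Llama-4 style attention scaling when the pretrained config exposes a `llama_4_scaling` entry; `floor_scale` and `attn_scale` are then read from that entry, falling back to 8192 and 0.1. A minimal sketch of the same detection logic, using a hypothetical `SimpleNamespace` stand-in for `config.pretrained_config` rather than the real config class:

```python
from types import SimpleNamespace

# Hypothetical stand-in for config.pretrained_config; the attribute names
# ('llama_4_scaling', 'original_max_position_embeddings', 'beta') mirror the diff.
pretrained_config = SimpleNamespace(
    llama_4_scaling=SimpleNamespace(original_max_position_embeddings=8192, beta=0.1))

llama_4_scaling = hasattr(pretrained_config, 'llama_4_scaling')
floor_scale = attn_scale = None
if llama_4_scaling:
    floor_scale = getattr(pretrained_config.llama_4_scaling,
                          'original_max_position_embeddings', 8192)
    attn_scale = getattr(pretrained_config.llama_4_scaling, 'beta', 0.1)

print(llama_4_scaling, floor_scale, attn_scale)  # True 8192 0.1
```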
@@ -1127,6 +1135,18 @@ def create_output(self, hidden_states: torch.Tensor, num_contexts: int):
         return hidden_states.new_empty([num_tokens, hidden_size],
                                        dtype=hidden_states.dtype)
 
+    def _attention_scaling(self, q, position_ids):
+
+        def _get_attn_scale(position_ids: torch.Tensor) -> torch.Tensor:
+            positions = position_ids.view(-1)
+            floor = torch.floor((positions + 1.0) / self.floor_scale)
+            attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
+            return attn_scale.unsqueeze(-1)
+
+        attn_scale = _get_attn_scale(position_ids)
+        q = (q * attn_scale).to(q.dtype)
+        return q
+
     def forward_impl(self,
                      position_ids: Optional[torch.Tensor],
                      hidden_states: torch.Tensor,
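`_attention_scaling` multiplies each query row by log(floor((p + 1) / floor_scale) + 1) * attn_scale + 1 for token position p, so the scale stays at 1.0 below `floor_scale` and grows logarithmically beyond it. A standalone sketch of the same arithmetic, reimplemented outside the class with the default floor_scale=8192 and attn_scale=0.1 assumed above:

```python
import torch

floor_scale, attn_scale = 8192.0, 0.1  # defaults assumed in the diff


def get_attn_scale(position_ids: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as the nested _get_attn_scale helper.
    positions = position_ids.view(-1)
    floor = torch.floor((positions + 1.0) / floor_scale)
    return (torch.log(floor + 1.0) * attn_scale + 1.0).unsqueeze(-1)


positions = torch.tensor([0, 8190, 8192, 65536, 131072])
print(get_attn_scale(positions).squeeze(-1))
# ≈ tensor([1.0000, 1.0000, 1.0693, 1.2197, 1.2833])
```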
@@ -1197,6 +1217,10 @@ def forward_impl(self,
                 assert position_ids is not None
                 k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)
 
+            if self.llama_4_scaling:
+                q_ctx = self._attention_scaling(
+                    q_ctx, position_ids[..., :num_ctx_tokens])
+
             self.forward_context(
                 q_ctx,
                 compressed_kv_ctx,
@@ -1217,6 +1241,10 @@ def forward_impl(self,
                 assert position_ids is not None
                 k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)
 
+            if self.llama_4_scaling:
+                q_gen = self._attention_scaling(
+                    q_gen, position_ids[..., num_ctx_tokens:])
+
             self.forward_absorption_generation(
                 q_gen,
                 compressed_kv_gen,
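At both call sites the flattened `position_ids` tensor is split so context tokens (the first `num_ctx_tokens` entries) and generation tokens (the remainder) are scaled with their own positions, and `_attention_scaling` casts back with `.to(q.dtype)` because multiplying a reduced-precision query by the float32 scale promotes the result. A small shape/dtype sketch under the assumption that q is `[num_tokens, num_heads * head_dim]` and position_ids is `[1, num_tokens]`:

```python
import torch

num_ctx_tokens, num_gen_tokens = 3, 2
num_heads, head_dim = 4, 8

position_ids = torch.arange(num_ctx_tokens + num_gen_tokens).unsqueeze(0)  # [1, 5]
q = torch.randn(num_ctx_tokens + num_gen_tokens, num_heads * head_dim,
                dtype=torch.bfloat16)

# Context/generation split mirrors the slicing in the diff.
pos_ctx = position_ids[..., :num_ctx_tokens]   # [1, 3]
pos_gen = position_ids[..., num_ctx_tokens:]   # [1, 2]
q_ctx, q_gen = q[:num_ctx_tokens], q[num_ctx_tokens:]

# A float32 per-token scale of shape [num_tokens, 1] broadcasts over the feature
# dim but promotes the dtype, which is why the diff casts back to q.dtype.
attn_scale = torch.ones(pos_ctx.numel(), 1)    # stand-in for _get_attn_scale output
print((q_ctx * attn_scale).dtype)                  # torch.float32
print((q_ctx * attn_scale).to(q_ctx.dtype).dtype)  # torch.bfloat16
```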