Commit d6f961d

[None][feat] Add llama4 scaling (#9771)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
1 parent 1c4dacb commit d6f961d

File tree

1 file changed: +28 -0 lines changed

tensorrt_llm/_torch/modules/attention.py

Lines changed: 28 additions & 0 deletions
@@ -985,6 +985,14 @@ def yarn_get_mscale(scale=1, mscale=1):
             is_neox=pos_embd_params.is_neox,
         )

+        self.llama_4_scaling = False
+        if hasattr(config.pretrained_config, 'llama_4_scaling'):
+            self.llama_4_scaling = True
+            self.floor_scale = getattr(config.pretrained_config.llama_4_scaling,
+                                       'original_max_position_embeddings', 8192)
+            self.attn_scale = getattr(config.pretrained_config.llama_4_scaling,
+                                      'beta', 0.1)
+
         if not config.skip_create_weights_in_init:
             self.create_weights()
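
The new branch only activates when the pretrained config carries a llama_4_scaling entry. As a minimal illustration of that gating, the sketch below builds a hypothetical stand-in config object; the attribute names mirror the getattr calls in the hunk above and the values are the diff's defaults.

from types import SimpleNamespace

# Hypothetical pretrained config carrying a llama_4_scaling entry; field names
# mirror the getattr calls in the diff, values are the diff's defaults.
pretrained_config = SimpleNamespace(
    llama_4_scaling=SimpleNamespace(
        original_max_position_embeddings=8192,  # becomes self.floor_scale
        beta=0.1,                                # becomes self.attn_scale
    ))

# Same gating as the new __init__ code: scaling only turns on when the
# attribute exists on the pretrained config.
llama_4_scaling = hasattr(pretrained_config, 'llama_4_scaling')
if llama_4_scaling:
    floor_scale = getattr(pretrained_config.llama_4_scaling,
                          'original_max_position_embeddings', 8192)
    attn_scale = getattr(pretrained_config.llama_4_scaling, 'beta', 0.1)
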

@@ -1127,6 +1135,18 @@ def create_output(self, hidden_states: torch.Tensor, num_contexts: int):
         return hidden_states.new_empty([num_tokens, hidden_size],
                                        dtype=hidden_states.dtype)

+    def _attention_scaling(self, q, position_ids):
+
+        def _get_attn_scale(position_ids: torch.Tensor) -> torch.Tensor:
+            positions = position_ids.view(-1)
+            floor = torch.floor((positions + 1.0) / self.floor_scale)
+            attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
+            return attn_scale.unsqueeze(-1)
+
+        attn_scale = _get_attn_scale(position_ids)
+        q = (q * attn_scale).to(q.dtype)
+        return q
+
     def forward_impl(self,
                      position_ids: Optional[torch.Tensor],
                      hidden_states: torch.Tensor,
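
As a standalone sketch of what _attention_scaling computes, the snippet below reproduces the same formula outside the module, assuming floor_scale = 8192 and attn_scale = 0.1 (the diff's defaults): the query scale stays at 1.0 for short contexts and grows logarithmically with position.

import torch

def llama4_attn_scale(position_ids: torch.Tensor,
                      floor_scale: float = 8192.0,
                      attn_scale: float = 0.1) -> torch.Tensor:
    # Same math as _get_attn_scale above:
    #   scale(p) = 1 + attn_scale * log(floor((p + 1) / floor_scale) + 1)
    positions = position_ids.view(-1)
    floor = torch.floor((positions + 1.0) / floor_scale)
    return (torch.log(floor + 1.0) * attn_scale + 1.0).unsqueeze(-1)

positions = torch.tensor([0, 4096, 8192, 65535])
print(llama4_attn_scale(positions).squeeze(-1))
# roughly [1.0, 1.0, 1.069, 1.220] -- positions well below floor_scale keep
# scale 1.0, while longer contexts get a mild logarithmic boost to the query.
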
@@ -1197,6 +1217,10 @@ def forward_impl(self,
             assert position_ids is not None
             k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)

+            if self.llama_4_scaling:
+                q_ctx = self._attention_scaling(
+                    q_ctx, position_ids[..., :num_ctx_tokens])
+
         self.forward_context(
             q_ctx,
             compressed_kv_ctx,
@@ -1217,6 +1241,10 @@ def forward_impl(self,
             assert position_ids is not None
             k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)

+            if self.llama_4_scaling:
+                q_gen = self._attention_scaling(
+                    q_gen, position_ids[..., num_ctx_tokens:])
+
         self.forward_absorption_generation(
             q_gen,
             compressed_kv_gen,
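
Both call sites guard on self.llama_4_scaling and pass a slice of the shared position_ids tensor, which lays out the context-phase tokens first and the generation-phase tokens after them. Below is a minimal sketch of that split and of applying the scale to a query tensor, using a hypothetical num_ctx_tokens and toy shapes.

import torch

# Hypothetical batch layout: 6 context-phase tokens followed by 2 generation-
# phase tokens, matching the position_ids slicing at the two call sites above.
num_ctx_tokens = 6
position_ids = torch.arange(8).unsqueeze(0)    # shape [1, num_tokens]
ctx_ids = position_ids[..., :num_ctx_tokens]   # leading context tokens
gen_ids = position_ids[..., num_ctx_tokens:]   # trailing generation tokens

# Toy query tensor; real shapes come from the attention module.
q_ctx = torch.randn(num_ctx_tokens, 4)
floor_scale, attn_scale = 8192.0, 0.1          # the diff's defaults
scale = (torch.log(torch.floor((ctx_ids.view(-1) + 1.0) / floor_scale) + 1.0)
         * attn_scale + 1.0).unsqueeze(-1)     # shape [num_ctx_tokens, 1]
q_ctx = (q_ctx * scale).to(q_ctx.dtype)        # broadcasts over the last dim
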
