
Commit e3e0fe1

Added SWA (sliding-window attention) optimization to reduce MACCs by using less KV
Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com>
1 parent 0ff742c commit e3e0fe1
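
Note (not part of the commit): the MACC saving comes from the sliding-window branch of the new attention path, which slices K/V for each query block to the preceding 128 positions plus the block itself instead of the full context length. A rough, illustrative estimate of the per-layer attention matmul cost, with hypothetical shapes:

# Illustrative only: MACs for the Q @ K^T and softmax(QK^T) @ V matmuls of one layer,
# comparing full-KV Q-blocking with the sliding-window KV slice (window = 128,
# matching kv_start_idx = qi - 128 in opt_eager_attention_forward_blocked below).
def attn_macs(ctx_len, num_heads, head_dim, num_q_blocks, window=None):
    block_len = ctx_len // num_q_blocks
    total = 0
    for b in range(num_q_blocks):
        # Each query block attends to the full context, or only to its own
        # positions plus the trailing window when the KV slice is used.
        kv_len = ctx_len if window is None else (block_len if b == 0 else block_len + window)
        total += 2 * num_heads * block_len * kv_len * head_dim
    return total

# Hypothetical shapes for a single layer.
full = attn_macs(ctx_len=8192, num_heads=64, head_dim=64, num_q_blocks=8)
swa = attn_macs(ctx_len=8192, num_heads=64, head_dim=64, num_q_blocks=8, window=128)
print(f"full KV: {full:.2e} MACs, windowed KV: {swa:.2e} MACs")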

File tree

2 files changed: +73 -6 lines changed


QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 71 additions & 4 deletions
@@ -83,7 +83,7 @@ def forward(self, hidden: torch.Tensor):
         up = (hidden @ W_u) + b_u  # [T, I]

         # Apply GptOss activation with clamping
-        gate = gate.clamp(min=None, max=self.experts.limit)
+        gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
         up = up.clamp(min=-self.experts.limit, max=self.experts.limit)

         # GLU activation
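
Note (illustrative, not part of the commit): torch.finfo(torch.float16).min is the most negative finite float16 value, so the gate is now clamped to a finite range instead of being left unbounded below.

import torch

print(torch.finfo(torch.float16).min)  # -65504.0
# gate is therefore clamped into [-65504.0, self.experts.limit] rather than (-inf, self.experts.limit]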
@@ -584,11 +584,12 @@ def eager_attention_forward_blocked(
     value_states = repeat_kv(value, module.num_key_value_groups)

     BS, NH, CL, DH = query.shape
-    target_blocks = int(os.environ.get("NUM_BLOCKS", 1))
+    target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1))
     block_positions = []
     for j in range(target_blocks):
         block_positions.append(j * (CL // target_blocks))
     block_count = 0
+
     outs = []
     for block_idx in range(target_blocks):
         block_count += 1
@@ -621,6 +622,69 @@ def eager_attention_forward_blocked(
     return output, output


+def opt_eager_attention_forward_blocked(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    BS, NH, CL, DH = query.shape
+    target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1))
+    block_positions = []
+    for j in range(target_blocks):
+        block_positions.append(j * (CL // target_blocks))
+    block_count = 0
+    outs = []
+    for block_idx in range(target_blocks):
+        block_count += 1
+        qi = block_positions[block_idx]
+        # Calculate block size (last block should be handled with remainder)
+
+        if block_idx == target_blocks - 1:
+            real_q_len = CL - qi
+        else:
+            real_q_len = block_positions[block_idx + 1] - qi
+
+        if block_idx == 0:
+            kv_start_idx = 0
+        else:
+            kv_start_idx = qi - 128
+
+        q_block = query[:, :, qi : qi + real_q_len, :]
+        if kwargs.get("sliding_window"):
+            k_block = key_states[:, :, kv_start_idx : qi + real_q_len, :]
+            v_block = value_states[:, :, kv_start_idx : qi + real_q_len, :]
+            attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, kv_start_idx : qi + real_q_len]
+        else:
+            k_block = key_states
+            v_block = value_states
+            attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :]
+
+        scores = torch.matmul(q_block, k_block.transpose(2, 3)) * scaling
+        curr_attn_weights = torch.where(
+            attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores
+        )
+        sinks = module.sinks.reshape(1, -1, 1, 1).expand(
+            curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1
+        )
+        combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1)
+        combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+        curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32)
+        curr_attn_weights = curr_attn_weights[..., :-1]
+        out_block = torch.matmul(curr_attn_weights, v_block)
+        outs.append(out_block)
+    output = torch.cat(outs, dim=2)
+
+    output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous()
+    return output, output
+
+
 class QEffPrefillOnlyGptOssAttention(GptOssAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -667,7 +731,7 @@ def forward(
         read_idx = short_read_idx + torch.where(
             position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0
         )
-        # This is a trick to export with NUM_BLOCKS<seq_len<sliding_window_len, disabling it by default.
+        # This is a trick to export with seq_len<sliding_window_len
         read_idx = torch.where(read_idx > position_ids.max(), 0, read_idx)
         k_cache = key_states[:, :, read_idx, :]
         v_cache = value_states[:, :, read_idx, :]
@@ -680,7 +744,10 @@ def forward(
         else:
             attention_mask = attention_mask

-        attention_interface: Callable = eager_attention_forward_blocked
+        if os.environ.get("ENABLE_OPT_SWA", "0") == "1":
+            attention_interface: Callable = opt_eager_attention_forward_blocked
+        else:
+            attention_interface: Callable = eager_attention_forward_blocked
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
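
A minimal, self-contained sketch (not from the repository) of why the K/V slice in opt_eager_attention_forward_blocked is lossless: keys outside the sliding window are already masked to a large negative value and get ~zero softmax weight, so dropping them from K/V does not change the output. Sinks are omitted for brevity; shapes and the window size are assumed for the example.

import torch

torch.manual_seed(0)
BS, NH, CL, DH, window = 1, 2, 256, 16, 128
q = torch.randn(BS, NH, CL, DH)
k = torch.randn(BS, NH, CL, DH)
v = torch.randn(BS, NH, CL, DH)
scaling = DH ** -0.5
MIN_VAL = torch.finfo(torch.float32).min

# Boolean mask with True = masked out: causal plus a sliding window of `window` keys.
pos = torch.arange(CL)
mask = (pos[None, :] > pos[:, None]) | (pos[None, :] <= pos[:, None] - window)
mask = mask[None, None]  # [1, 1, CL, CL]

def masked_attn(qb, kb, vb, mb):
    scores = torch.matmul(qb, kb.transpose(2, 3)) * scaling
    scores = torch.where(mb, torch.tensor(MIN_VAL), scores)
    return torch.matmul(torch.softmax(scores, dim=-1), vb)

qi, q_len = 128, 64  # a query block starting mid-sequence
full = masked_attn(q[:, :, qi:qi + q_len], k, v, mask[:, :, qi:qi + q_len, :])

kv_start = qi - window  # mirrors kv_start_idx = qi - 128 for block_idx > 0
sliced = masked_attn(
    q[:, :, qi:qi + q_len],
    k[:, :, kv_start:qi + q_len],
    v[:, :, kv_start:qi + q_len],
    mask[:, :, qi:qi + q_len, kv_start:qi + q_len],
)
print(torch.allclose(full, sliced, atol=1e-6))  # expected: True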

QEfficient/transformers/models/modeling_auto.py

Lines changed: 2 additions & 2 deletions
@@ -2321,13 +2321,13 @@ def get_seq_len_and_handle_specialized_prefill_model(self, prefill_seq_len: Opti
         if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128:
             raise ValueError(
                 f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. "
-                f"Or set `NUM_BLOCKS` ENV variable"
+                f"Or set `NUM_Q_BLOCKS` ENV variable"
                 f"Received: prefill_seq_len={prefill_seq_len}"
             )

         num_q_blocks = prefill_seq_len // block_size
         logger.warning(
-            f"Setting NUM_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_BLOCKS` to override"
+            f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override"
         )
         os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks)
         num_q_blocks = int(num_q_blocks)
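
Taken together, a possible way to exercise the new path (values are illustrative; in the library block_size comes from the model's prefill specialization, and NUM_Q_BLOCKS is derived automatically as shown above):

import os

# Hypothetical prefill settings for illustration.
prefill_seq_len = 1024
block_size = 128  # assumed value; the real one comes from the model configuration

# Mirrors get_seq_len_and_handle_specialized_prefill_model: NUM_Q_BLOCKS defaults
# to prefill_seq_len // block_size unless the environment variable is set.
os.environ["NUM_Q_BLOCKS"] = str(prefill_seq_len // block_size)

# Opt into the reduced-KV sliding-window attention path added by this commit.
os.environ["ENABLE_OPT_SWA"] = "1"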
