@@ -83,7 +83,7 @@ def forward(self, hidden: torch.Tensor):
         up = (hidden @ W_u) + b_u  # [T, I]

         # Apply GptOss activation with clamping
-        gate = gate.clamp(min=None, max=self.experts.limit)
+        gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
         up = up.clamp(min=-self.experts.limit, max=self.experts.limit)

         # GLU activation
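The new lower bound is the most negative value representable in float16 (-65504.0), so the gate is no longer left unbounded below while staying safe for half-precision execution. A quick standalone check of the behaviour (illustrative only; limit stands in for self.experts.limit):

import torch

limit = 7.0                                    # placeholder for self.experts.limit
gate = torch.tensor([-1.0e6, -3.0, 0.5, 9.0])
fp16_min = torch.finfo(torch.float16).min      # -65504.0
# Values below -65504.0 are raised to -65504.0; values above limit are capped at limit
clamped = gate.clamp(min=fp16_min, max=limit)
print(clamped)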
@@ -584,11 +584,12 @@ def eager_attention_forward_blocked(
     value_states = repeat_kv(value, module.num_key_value_groups)

     BS, NH, CL, DH = query.shape
-    target_blocks = int(os.environ.get("NUM_BLOCKS", 1))
+    target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1))
     block_positions = []
     for j in range(target_blocks):
         block_positions.append(j * (CL // target_blocks))
     block_count = 0
+
     outs = []
     for block_idx in range(target_blocks):
         block_count += 1
@@ -621,6 +622,69 @@ def eager_attention_forward_blocked(
     return output, output


+def opt_eager_attention_forward_blocked(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    BS, NH, CL, DH = query.shape
+    target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1))
+    block_positions = []
+    for j in range(target_blocks):
+        block_positions.append(j * (CL // target_blocks))
+    block_count = 0
+    outs = []
+    for block_idx in range(target_blocks):
+        block_count += 1
+        qi = block_positions[block_idx]
+        # Calculate block size (last block should be handled with remainder)
+
+        if block_idx == target_blocks - 1:
+            real_q_len = CL - qi
+        else:
+            real_q_len = block_positions[block_idx + 1] - qi
+
+        if block_idx == 0:
+            kv_start_idx = 0
+        else:
+            kv_start_idx = qi - 128
+
+        q_block = query[:, :, qi : qi + real_q_len, :]
+        if kwargs.get("sliding_window"):
+            k_block = key_states[:, :, kv_start_idx : qi + real_q_len, :]
+            v_block = value_states[:, :, kv_start_idx : qi + real_q_len, :]
+            attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, kv_start_idx : qi + real_q_len]
+        else:
+            k_block = key_states
+            v_block = value_states
+            attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :]
+
+        scores = torch.matmul(q_block, k_block.transpose(2, 3)) * scaling
+        curr_attn_weights = torch.where(
+            attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores
+        )
+        sinks = module.sinks.reshape(1, -1, 1, 1).expand(
+            curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1
+        )
+        combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1)
+        combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+        curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32)
+        curr_attn_weights = curr_attn_weights[..., :-1]
+        out_block = torch.matmul(curr_attn_weights, v_block)
+        outs.append(out_block)
+    output = torch.cat(outs, dim=2)
+
+    output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous()
+    return output, output
+
+
 class QEffPrefillOnlyGptOssAttention(GptOssAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

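The block partitioning in opt_eager_attention_forward_blocked is easier to see in isolation. The sketch below (not part of the diff) replays the same index arithmetic for an assumed context length of 512 with NUM_Q_BLOCKS=4, including the fixed 128-position look-back used on the sliding-window path; it only prints the query/KV ranges each block would cover.

import os

os.environ.setdefault("NUM_Q_BLOCKS", "4")      # same env var the kernel reads

CL = 512                                        # assumed toy context length
target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1))
block_positions = [j * (CL // target_blocks) for j in range(target_blocks)]

for block_idx, qi in enumerate(block_positions):
    # Last block absorbs the remainder when CL is not divisible by target_blocks
    real_q_len = (CL - qi) if block_idx == target_blocks - 1 else block_positions[block_idx + 1] - qi
    # Sliding-window path: every block after the first re-reads the previous 128 keys/values
    kv_start_idx = 0 if block_idx == 0 else qi - 128
    print(f"block {block_idx}: q[{qi}:{qi + real_q_len}]  kv[{kv_start_idx}:{qi + real_q_len}]")

With these numbers each later block attends over its own 128 queries plus the 128 positions immediately before it, which is what keeps the kv_start_idx : qi + real_q_len slice aligned with a 128-wide sliding window.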
@@ -667,7 +731,7 @@ def forward(
         read_idx = short_read_idx + torch.where(
             position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0
         )
-        # This is a trick to export with NUM_BLOCKS < seq_len < sliding_window_len, disabling it by default.
+        # This is a trick to export with seq_len < sliding_window_len
         read_idx = torch.where(read_idx > position_ids.max(), 0, read_idx)
         k_cache = key_states[:, :, read_idx, :]
         v_cache = value_states[:, :, read_idx, :]
@@ -680,7 +744,10 @@ def forward(
         else:
             attention_mask = attention_mask

-        attention_interface: Callable = eager_attention_forward_blocked
+        if os.environ.get("ENABLE_OPT_SWA", "0") == "1":
+            attention_interface: Callable = opt_eager_attention_forward_blocked
+        else:
+            attention_interface: Callable = eager_attention_forward_blocked
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
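For completeness, both blocked kernels are selected and shaped purely through environment variables read at forward time. A minimal, hypothetical way to exercise the optimized sliding-window branch (assuming the rest of the GptOss prefill/export flow is already set up) would be:

import os

os.environ["NUM_Q_BLOCKS"] = "4"    # read by both blocked kernels: number of query blocks per prefill
os.environ["ENABLE_OPT_SWA"] = "1"  # "1" routes QEffPrefillOnlyGptOssAttention.forward to
                                    # opt_eager_attention_forward_blocked; any other value keeps
                                    # eager_attention_forward_blocked

# ...then run the usual export/prefill flow for the model.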