 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+import math
 from typing import Optional
-from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.attention.flex_attention import flex_attention
 
 from model.flex_mods import generate_tanh_softcap
 from model.utils import norm, Linear
@@ -61,6 +63,7 @@ def forward(
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         vi: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         l, d = x.size()  # batch size must be 1 for FlexAttention
         q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
@@ -98,32 +101,32 @@ def __init__(self, config):
         self.config = config
         self.n_tokens = config.num_att_tokens
         self.Wq = Linear(config.hidden_size, config.hidden_size)
-        self.Pk = nn.Parameter(torch.randn(1, self.n_tokens, config.hidden_size))
-        self.Pv = nn.Parameter(torch.randn(1, self.n_tokens, config.hidden_size))
-        self.sliding_window_size = config.sliding_window_size
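+        # Learned persistent key/value tokens, now without the leading batch dim
+        # (batch size is fixed at 1 throughout this module)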
+        self.Pk = nn.Parameter(torch.randn(self.n_tokens, config.hidden_size))
+        self.Pv = nn.Parameter(torch.randn(self.n_tokens, config.hidden_size))
 
-    def forward(
-        self,
-        x: torch.Tensor,
-        sliding_window_size: Optional[int] = None,
-    ) -> torch.Tensor:
+    def act(self, x: torch.Tensor) -> torch.Tensor:
+        o = x / (torch.norm(x, p=2, dim=-1, keepdim=True) + 1e-3) * math.sqrt(x.shape[-1])
+        o = F.gelu(o)
+        return o
+
+    def forward(self, x: torch.Tensor, last_eos: Optional[int] = None) -> torch.Tensor:
+        if last_eos is None:
+            last_eos = x.shape[0] - 1  # x is (Q_len, d); the sequence dim is 0
         Q_len, d = x.size()  # batch size must be 1 for FlexAttention
 
-        if sliding_window_size is None:
-            sliding_window_size = self.sliding_window_size
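+        # Zero out query rows from last_eos onwards so they do not read the
+        # persistent tokens (positions at or after the last EOS)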
+        attention_mask = torch.ones(Q_len, self.n_tokens, device=x.device)
+        attention_mask[last_eos:, :] = 0
 
-        def doc_mask_mod(b, h, q_idx, kv_idx):
-            bidirectional_sliding_window_mask = torch.abs(q_idx - kv_idx) < sliding_window_size
-            return bidirectional_sliding_window_mask
+        q = self.Wq(x)  # (Q_len, d)
+        k = self.Pk  # (n_tokens, d)
+        v = self.Pv  # (n_tokens, d)
 
-        KV_len = self.n_tokens
-        attention_mask = create_block_mask(doc_mask_mod, 1, 1, Q_len, KV_len)
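+        # Plain dot-product scores against the persistent tokens; the mask is
+        # applied to the raw scores before the act() nonlinearity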
+        attn_weight = q @ k.transpose(0, 1)  # (Q_len, n_tokens)
+        attn_weight *= attention_mask
+        attn_weight = self.act(attn_weight)
 
-        q = self.Wq(x).unsqueeze(0).unsqueeze(1)  # (1, 1, Q_len, d)
-        k = self.Pk.unsqueeze(1)  # (1, 1, n_tokens, d)
-        v = self.Pv.unsqueeze(1)  # (1, 1, n_tokens, d)
-        y = flex_attention(q, k, v, block_mask=attention_mask)  # (1, 1, Q_len, d)
-        return y.squeeze(1)  # (1, Q_len, d)
+        y = attn_weight @ v  # (Q_len, d)
+        return y.unsqueeze(0)  # (1, Q_len, d)
 
 
 class MultiHeadPAttention(nn.Module):
@@ -148,12 +151,14 @@ def forward(
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         vi: Optional[torch.Tensor] = None,
+        last_eos: Optional[int] = None,
+        **kwargs,
     ) -> torch.Tensor:
         # attention mask already prepped for sdpa shape (bs, 1, seq_len, seq_len)
         l, d = x.size()
-        q = self.Wq(x)  # (1, l, d)
-        k = self.Wk(x)  # (1, l, d)
-        v = self.Wv(x)  # (1, l, d)
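+        # The q/k/v projections now take last_eos, matching PAttention.forward,
+        # and each returns a (1, l, d) tensor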
+        q = self.Wq(x, last_eos)  # (1, l, d)
+        k = self.Wk(x, last_eos)  # (1, l, d)
+        v = self.Wv(x, last_eos)  # (1, l, d)
 
         if self.unet and vi is not None:
             # Reshape vi from (l, d) to (1, l, d) to match v's shape before applying it
@@ -175,3 +180,62 @@ def forward(
 
         y = y.contiguous().view(1, l, self.n_heads * self.d_head)  # (1, l, n_heads * d_head)
         return self.Wo(y).squeeze(0)  # (l, hidden_size)
+
+
+if __name__ == '__main__':
+    # test PAttention
+    # py -m model.attention
+
+    # Simple config for testing
+    class TestConfig:
+        def __init__(self):
+            self.hidden_size = 64
+            self.num_att_tokens = 8
+
+    config = TestConfig()
+    pattention = PAttention(config)
+
+    # Test input: sequence length 10, hidden size 64
+    seq_len = 10
+    x = torch.randn(seq_len, config.hidden_size)
+
+    # Test mask logic with different last_eos values
+    print("Testing PAttention mask logic...")
+
+    # Case 1: last_eos = 5 (positions 0-4 attend to the persistent tokens,
+    # positions 5-9 are masked, matching forward()'s attention_mask[last_eos:, :] = 0)
+    last_eos = 5
+    output = pattention(x, last_eos=last_eos)
+
+    # Manually check the mask logic
+    q = pattention.Wq(x)
+    k = pattention.Pk
+    attn_weight = q @ k.transpose(0, 1)
+
+    # Create expected mask with the same logic as forward()
+    expected_mask = torch.ones(seq_len, config.num_att_tokens)
+    expected_mask[last_eos:, :] = 0
+
+    # Apply mask
+    masked_attn = attn_weight * expected_mask
+
+    # Check that positions from last_eos onwards are zero
+    assert torch.allclose(masked_attn[last_eos:, :], torch.zeros(seq_len - last_eos, config.num_att_tokens)), \
+        "Attention weights from last_eos onwards should be zero"
+
+    # Check that positions before last_eos are non-zero (almost surely, for random input)
+    assert not torch.allclose(masked_attn[:last_eos, :], torch.zeros(last_eos, config.num_att_tokens)), \
+        "Attention weights before last_eos should be non-zero"
+
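+    # End-to-end check: replaying forward() by hand should reproduce the module
+    # output (act() maps all-zero rows to zero, so masked rows stay zero)
+    expected = pattention.act(masked_attn) @ pattention.Pv
+    assert torch.allclose(output.squeeze(0), expected), \
+        "Module output should match the manual computation"
+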
+    print(f"Test passed for last_eos={last_eos}")
+
+    # Case 2: last_eos = 0 (every query row is masked)
+    last_eos = 0
+    output = pattention(x, last_eos=last_eos)
+    print(f"Test passed for last_eos={last_eos}")
+
+    # Case 3: last_eos = seq_len - 1 (only the final position is masked)
+    last_eos = seq_len - 1
+    output = pattention(x, last_eos=last_eos)
+    print(f"Test passed for last_eos={last_eos}")
+
+    print("All PAttention mask tests passed!")