
Commit 4d205a2

Merge pull request #17 from Synthyra/fixing_pattention

Fixing pattention

2 parents 84c95b1 + b8bfa0c

5 files changed: 163 additions & 43 deletions


.gitignore

Lines changed: 1 addition & 0 deletions

@@ -4,3 +4,4 @@ __pycache__/
 *.bin
 /logs
 /experiments/*.yaml
+/.cache

Dockerfile

Lines changed: 2 additions & 3 deletions

@@ -1,6 +1,6 @@
 # sudo docker build -t speedrun_plm .
 # sudo docker run --gpus all --shm-size=128g -v ${PWD}:/workspace speedrun_plm torchrun --standalone --nproc_per_node=4 train.py
-
+# docker run --gpus all -v ${PWD}:/workspace speedrun_plm python train.py --bugfix
 # 1️⃣ CUDA / cuDNN base with no Python
 FROM nvidia/cuda:12.6.2-cudnn-devel-ubuntu24.04

@@ -34,8 +34,7 @@ WORKDIR /app
 COPY requirements.txt .

 RUN pip install --upgrade pip setuptools && \
-    # force-install torch built for CUDA 12.6
-    pip install --force-reinstall torch torchvision --index-url https://download.pytorch.org/whl/cu128 -U && \
+    pip install --force-reinstall torch torchvision --index-url https://download.pytorch.org/whl/cu128 && \
     pip install -r requirements.txt -U

 # 5️⃣ Copy the rest of the source

model/attention.py

Lines changed: 88 additions & 24 deletions

@@ -1,7 +1,9 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+import math
 from typing import Optional
-from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.attention.flex_attention import flex_attention

 from model.flex_mods import generate_tanh_softcap
 from model.utils import norm, Linear
@@ -61,6 +63,7 @@ def forward(
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         vi: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         l, d = x.size()  # batch size must be 1 for FlexAttention
         q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
@@ -98,32 +101,32 @@ def __init__(self, config):
98101
self.config = config
99102
self.n_tokens = config.num_att_tokens
100103
self.Wq = Linear(config.hidden_size, config.hidden_size)
101-
self.Pk = nn.Parameter(torch.randn(1, self.n_tokens, config.hidden_size))
102-
self.Pv = nn.Parameter(torch.randn(1, self.n_tokens, config.hidden_size))
103-
self.sliding_window_size = config.sliding_window_size
104+
self.Pk = nn.Parameter(torch.randn(self.n_tokens, config.hidden_size))
105+
self.Pv = nn.Parameter(torch.randn(self.n_tokens, config.hidden_size))
104106

105-
def forward(
106-
self,
107-
x: torch.Tensor,
108-
sliding_window_size: Optional[int] = None,
109-
) -> torch.Tensor:
107+
def act(self, x: torch.Tensor) -> torch.Tensor:
108+
o = x / (torch.norm(x, p=2, dim=-1, keepdim=True) + 1e-3) * math.sqrt(x.shape[-1])
109+
o = F.gelu(o)
110+
return o
111+
112+
def forward(self, x: torch.Tensor, last_eos: Optional[int] = None) -> torch.Tensor:
113+
if last_eos is None:
114+
last_eos = x.shape[1] - 1
110115
Q_len, d = x.size() # batch size must be 1 for FlexAttention
111116

112-
if sliding_window_size is None:
113-
sliding_window_size = self.sliding_window_size
117+
attention_mask = torch.ones(Q_len, self.n_tokens, device=x.device)
118+
attention_mask[last_eos:, :] = 0
114119

115-
def doc_mask_mod(b, h, q_idx, kv_idx):
116-
bidirectional_sliding_window_mask = torch.abs(q_idx - kv_idx) < sliding_window_size
117-
return bidirectional_sliding_window_mask
120+
q = self.Wq(x) # (Q_len, d)
121+
k = self.Pk # (n_tokens, d)
122+
v = self.Pv # (n_tokens, d)
118123

119-
KV_len = self.n_tokens
120-
attention_mask = create_block_mask(doc_mask_mod, 1, 1, Q_len, KV_len)
124+
attn_weight = q @ k.transpose(0, 1) # (Q_len, n_tokens)
125+
attn_weight *= attention_mask
126+
attn_weight = self.act(attn_weight)
121127

122-
q = self.Wq(x).unsqueeze(0).unsqueeze(1) # (1, 1, Q_len, d)
123-
k = self.Pk.unsqueeze(1) # (1, 1, n_tokens, d)
124-
v = self.Pv.unsqueeze(1) # (1, 1, n_tokens, d)
125-
y = flex_attention(q, k, v, block_mask=attention_mask) # (1, 1, Q_len, d)
126-
return y.squeeze(1) # (1, Q_len, d)
128+
y = attn_weight @ v # (Q_len, d)
129+
return y.unsqueeze(0) # (1, Q_len, d)
127130

128131

129132
class MultiHeadPAttention(nn.Module):
@@ -148,12 +151,14 @@ def forward(
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         vi: Optional[torch.Tensor] = None,
+        last_eos: Optional[int] = None,
+        **kwargs,
     ) -> torch.Tensor:
         # attention mask already prepped for sdpa shape (bs, 1, seq_len, seq_len)
         l, d = x.size()
-        q = self.Wq(x)  # (1, l, d)
-        k = self.Wk(x)  # (1, l, d)
-        v = self.Wv(x)  # (1, l, d)
+        q = self.Wq(x, last_eos)  # (1, l, d)
+        k = self.Wk(x, last_eos)  # (1, l, d)
+        v = self.Wv(x, last_eos)  # (1, l, d)

         if self.unet and vi is not None:
             # Reshape vi from (l, d) to (1, l, d) to match v's shape before applying it
@@ -175,3 +180,62 @@ def forward(

         y = y.contiguous().view(1, l, self.n_heads * self.d_head)  # (1, l, n_heads * d_head)
         return self.Wo(y).squeeze(0)  # (l, hidden_size)
+
+
+if __name__ == '__main__':
+    # test pattention
+    # py -m model.attention
+
+    # Simple config for testing
+    class TestConfig:
+        def __init__(self):
+            self.hidden_size = 64
+            self.num_att_tokens = 8
+
+    config = TestConfig()
+    pattention = PAttention(config)
+
+    # Test input: sequence length 10, hidden size 64
+    seq_len = 10
+    x = torch.randn(seq_len, config.hidden_size)
+
+    # Test mask logic with different last_eos values
+    print("Testing PAttention mask logic...")
+
+    # Case 1: last_eos = 5 (keep positions 0-4, mask positions 5-9)
+    last_eos = 5
+    output = pattention(x, last_eos=last_eos)
+
+    # Manually check the mask logic
+    q = pattention.Wq(x)
+    k = pattention.Pk
+    attn_weight = q @ k.transpose(0, 1)
+
+    # Create expected mask, matching forward: zero from last_eos onwards
+    expected_mask = torch.ones(seq_len, config.num_att_tokens)
+    expected_mask[last_eos:, :] = 0
+
+    # Apply mask
+    masked_attn = attn_weight * expected_mask
+
+    # Check that positions from last_eos onwards are zero
+    assert torch.allclose(masked_attn[last_eos:, :], torch.zeros(seq_len - last_eos, config.num_att_tokens)), \
+        "Attention weights from last_eos onwards should be zero"
+
+    # Check that positions before last_eos are non-zero (assuming non-zero input)
+    assert not torch.allclose(masked_attn[:last_eos, :], torch.zeros(last_eos, config.num_att_tokens)), \
+        "Attention weights before last_eos should be non-zero"
+
+    print(f"Test passed for last_eos={last_eos}")
+
+    # Case 2: last_eos = 0 (mask every position)
+    last_eos = 0
+    output = pattention(x, last_eos=last_eos)
+    print(f"Test passed for last_eos={last_eos}")
+
+    # Case 3: last_eos = seq_len - 1 (mask only the last position)
+    last_eos = seq_len - 1
+    output = pattention(x, last_eos=last_eos)
+    print(f"Test passed for last_eos={last_eos}")
+
+    print("All PAttention mask tests passed!")

model/model.py

Lines changed: 58 additions & 9 deletions

@@ -105,12 +105,25 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         vi: Optional[torch.Tensor] = None,
         x0: Optional[torch.Tensor] = None,
+        last_eos: Optional[int] = None,
+        **kwargs,
     ) -> torch.Tensor:
         if self.unet:
             x = self.lambdas[0] * x + self.lambdas[1] * x0
-            x = x + self.attn(norm(x), attention_mask, vi)
+            x = x + self.attn(
+                x=norm(x),
+                attention_mask=attention_mask,
+                vi=vi,
+                last_eos=last_eos,
+                **kwargs,
+            )
         else:
-            x = x + self.attn(norm(x), attention_mask)
+            x = x + self.attn(
+                x=norm(x),
+                attention_mask=attention_mask,
+                last_eos=last_eos,
+                **kwargs,
+            )
         x = x + self.mlp(norm(x))
         return x

@@ -120,9 +133,18 @@ def __init__(self, config: PLMConfig):
         super().__init__()
         self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])

-    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
         for layer in self.layers:
-            x = layer(x, attention_mask)
+            x = layer(
+                x=x,
+                attention_mask=attention_mask,
+                **kwargs,
+            )
         return x

@@ -137,17 +159,35 @@ def __init__(self, config: PLMConfig):

         self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])

-    def forward(self, x: torch.Tensor, ve: List[torch.Tensor], attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        ve: List[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
         x0 = x
         ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:]
         skip_connections = []
         for i in range(self.num_encoder_layers):
-            x = self.layers[i](x, attention_mask, ve_enc[i], x0)
+            x = self.layers[i](
+                x=x,
+                attention_mask=attention_mask,
+                vi=ve_enc[i],
+                x0=x0,
+                **kwargs,
+            )
             skip_connections.append(x)

         for i in range(self.num_decoder_layers):
             x = x + self.skip_weights[i] * skip_connections.pop()
-            x = self.layers[self.num_encoder_layers + i](x, attention_mask, ve_dec[i], x0)
+            x = self.layers[self.num_encoder_layers + i](
+                x=x,
+                attention_mask=attention_mask,
+                vi=ve_dec[i],
+                x0=x0,
+                **kwargs,
+            )
         return x

@@ -219,9 +259,18 @@ def doc_mask_mod(b, h, q_idx, kv_idx):
         x = norm(x)
         if self.unet:
             ve = self.value_embeds(input_ids)
-            x = self.transformer(x, ve, attention_mask)
+            x = self.transformer(
+                x=x,
+                ve=ve,
+                attention_mask=attention_mask,
+                last_eos=last_eos,
+            )
         else:
-            x = self.transformer(x, attention_mask)
+            x = self.transformer(
+                x=x,
+                attention_mask=attention_mask,
+                last_eos=last_eos,
+            )
         return x

     def get_vector_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
train.py

Lines changed: 14 additions & 7 deletions

@@ -88,9 +88,9 @@ def arg_parser():
     parser.add_argument("--bfloat16", action="store_true", help="Use bfloat16")

     # Data hyperparams
-    parser.add_argument("--input_bin", type=str, default='data/omgprot50/omgprot50_train_*.bin', help="Input training bin files pattern")
-    parser.add_argument("--input_valid_bin", type=str, default='data/omgprot50/omgprot50_valid_*.bin', help="Input validation bin files pattern")
-    parser.add_argument("--input_test_bin", type=str, default='data/omgprot50/omgprot50_test_*.bin', help="Input test bin files pattern")
+    parser.add_argument("--input_bin", type=str, default='data/omg_prot50/omg_prot50_train_*.bin', help="Input training bin files pattern")
+    parser.add_argument("--input_valid_bin", type=str, default='data/omg_prot50/omg_prot50_valid_*.bin', help="Input validation bin files pattern")
+    parser.add_argument("--input_test_bin", type=str, default='data/omg_prot50/omg_prot50_test_*.bin', help="Input test bin files pattern")
     parser.add_argument("--mlm", type=bool, default=False, help="Use masked language modeling")
     parser.add_argument("--mask_rate", type=float, default=0.2, help="Mask rate for masked language modeling")
     parser.add_argument("--starting_mask_rate", type=float, default=0.1, help="Starting mask rate for masked language modeling")

@@ -401,7 +401,7 @@ def init_schedulers(self):
         if self.args.mask_rate_schedule:
             mask_rate_scheduler = LerpFloat(
                 start_val=self.args.starting_mask_rate,
-                end_val=self.args.mask_rate + 0.01,
+                end_val=self.args.mask_rate,
                 precision=0.01
             )
         else:

@@ -561,10 +561,17 @@ def train(self):
             timed_steps = float('nan') if step <= 11 else (step - 10) + 1  # <= 11 to avoid bug in val

             frac_done = step / self.args.num_steps  # training progress
-            self.sliding_window_size = self.sliding_window_size_scheduler(frac_done)
+            if frac_done > 1:
+                self.sliding_window_size = self.args.max_length
+            else:
+                self.sliding_window_size = self.sliding_window_size_scheduler(frac_done)
+
             if self.mask_rate_scheduler:
                 frac_done_mask = step / self.args.mask_rate_steps
-                mask_rate = self.mask_rate_scheduler(frac_done_mask)
+                if frac_done_mask > 1:
+                    mask_rate = self.args.mask_rate
+                else:
+                    mask_rate = self.mask_rate_scheduler(frac_done_mask)
                 self.current_mask_rate = mask_rate
                 self.train_loader.set_mask_rate(mask_rate)

@@ -685,7 +692,7 @@ def train(self):
     args.num_att_tokens = 128
     args.expansion_ratio = 2.0
     args.soft_logit_cap = 16.0
-    args.p_attention = False
+    args.p_attention = True
     args.tie_embeddings = False
     args.unet = True
     args.batch_size = 2048
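The two scheduler guards above exist because step can exceed num_steps or mask_rate_steps, so the fraction handed to the scheduler can exceed 1; the loop now pins sliding_window_size to max_length and mask_rate to its final value instead. A toy illustration, assuming LerpFloat linearly interpolates between start_val and end_val (its implementation is outside this diff):

def lerp(start_val: float, end_val: float, frac: float) -> float:
    # hypothetical stand-in for LerpFloat's interpolation rule
    return start_val + (end_val - start_val) * frac

start_val, end_val = 0.1, 0.2  # the starting_mask_rate / mask_rate defaults
frac = 1.5  # e.g. step = 1.5 * mask_rate_steps
print(lerp(start_val, end_val, frac))                           # ~0.25: overshoots the target
print(end_val if frac > 1 else lerp(start_val, end_val, frac))  # 0.2: the clamped behavior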
