Skip to content

Commit 73eb7c6

Browse files
committed
Change to print softmax correction factor
1 parent 41fb9bc commit 73eb7c6

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1470,7 +1470,44 @@ def forward(
14701470
softmax_offset=softmax_offset,
14711471
fp8_output=fp8_output,
14721472
)
1473-
return self.fused_attention(
1473+
if os.getenv("SKIP_CORRECTION_STATS", "0") == "1" and not self.training:
1474+
print (f'TE fused attention query_layer.shape: {query_layer.shape} key_layer.shape: {key_layer.shape} value_layer.shape: {value_layer.shape}')
1475+
k_repeat = query_layer.size(2) // key_layer.size(2)
1476+
# repeat head dimension of key_layer
1477+
key_layer2 = key_layer.repeat_interleave(k_repeat, dim=2)
1478+
# Q [s, b, h, d] -> [s, b*h, d] -> [b*h, s, d]
1479+
query_layer_t = query_layer.view(query_layer.size(0), query_layer.size(1)*query_layer.size(2), query_layer.size(3)).transpose(0,1)
1480+
# K [s, b, h, d] -> [s, b*h, d] -> [b*h, s, d] -> [b*h, d, s]
1481+
key_layer_t = key_layer2.view(key_layer2.size(0), key_layer2.size(1)*key_layer2.size(2), key_layer2.size(3)).transpose(0,1).transpose(1,2)
1482+
# bmm1 [b*h, s, s]
1483+
bmm1 = query_layer_t.bmm(key_layer_t)
1484+
q_len = query_layer.size(0)
1485+
k_len = key_layer.size(0)
1486+
# causal mask [s, s]
1487+
causal_mask=torch.triu(torch.ones(q_len, k_len, device=bmm1.device), diagonal=1).bool()
1488+
assert attn_mask_type == 'causal'
1489+
masked_bmm1 = bmm1.masked_fill(causal_mask[None, :, :], float('-inf'))
1490+
br, bc = 128, 128
1491+
assert masked_bmm1.size(1) % br == 0
1492+
assert masked_bmm1.size(2) % bc == 0
1493+
num_block_rows = (q_len + br - 1) // br # 8192/128 = 64
1494+
num_block_cols = (k_len + bc - 1) // bc # 8192/128 = 64
1495+
# blocked_attn [b*h, 64, 128, 64, 128]
1496+
blocked_attn = masked_bmm1.view(masked_bmm1.size(0),num_block_rows,br,num_block_cols,bc)
1497+
# block_max [b*h, 64, 128, 64]
1498+
block_max = blocked_attn.max(dim=-1)[0]
1499+
block_max_cummax = block_max.cummax(dim=-1)[0]
1500+
block_max_larger = torch.ones_like(block_max)
1501+
# True indicates block max is larger than the previous blocks' max
1502+
# block_max_larger [b*h, 64, 128, 64]
1503+
block_max_larger[..., 1:] = (block_max[..., 1:] - block_max_cummax[..., :-1]) > 0
1504+
# block_max_any_larger [b*h, 64, 64]
1505+
block_max_any_larger = block_max_larger.any(dim=-2).float()
1506+
cf_change_ratio = float(torch.sum(block_max_larger.float()) / torch.numel(block_max_larger)) * 2 * num_block_rows / (num_block_rows + 1)
1507+
cf_block_change_ratio = float(torch.sum(block_max_any_larger) / torch.numel(block_max_any_larger)) * 2 * num_block_rows / (num_block_rows + 1)
1508+
print (f'query_layer {query_layer_t.shape}-{query_layer_t.dtype} key_layer {key_layer_t.shape}-{key_layer_t.dtype} cf_change_ratio {cf_change_ratio} cf_block_change_ratio {cf_block_change_ratio}')
1509+
1510+
attn_out = self.fused_attention(
14741511
query_layer,
14751512
key_layer,
14761513
value_layer,
@@ -1500,6 +1537,7 @@ def forward(
15001537
softmax_offset=softmax_offset,
15011538
fp8_output=fp8_output,
15021539
)
1540+
return attn_out
15031541

15041542
if use_unfused_attention:
15051543
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"

0 commit comments

Comments
 (0)