@@ -1470,7 +1470,44 @@ def forward(
                 softmax_offset=softmax_offset,
                 fp8_output=fp8_output,
             )
-            return self.fused_attention(
+            if os.getenv("SKIP_CORRECTION_STATS", "0") == "1" and not self.training:
+                print(f"TE fused attention query_layer.shape: {query_layer.shape} key_layer.shape: {key_layer.shape} value_layer.shape: {value_layer.shape}")
+                k_repeat = query_layer.size(2) // key_layer.size(2)
+                # Repeat the head dimension of key_layer so K has as many heads as Q (GQA/MQA)
+                key_layer2 = key_layer.repeat_interleave(k_repeat, dim=2)
+                # Q [s, b, h, d] -> [s, b*h, d] -> [b*h, s, d]
+                query_layer_t = query_layer.view(query_layer.size(0), query_layer.size(1) * query_layer.size(2), query_layer.size(3)).transpose(0, 1)
+                # K [s, b, h, d] -> [s, b*h, d] -> [b*h, s, d] -> [b*h, d, s]
+                key_layer_t = key_layer2.view(key_layer2.size(0), key_layer2.size(1) * key_layer2.size(2), key_layer2.size(3)).transpose(0, 1).transpose(1, 2)
+                # bmm1 [b*h, s, s]
+                bmm1 = query_layer_t.bmm(key_layer_t)
+                q_len = query_layer.size(0)
+                k_len = key_layer.size(0)
+                # Causal mask [s, s]: True above the diagonal marks positions to mask out
+                causal_mask = torch.triu(torch.ones(q_len, k_len, device=bmm1.device), diagonal=1).bool()
+                assert attn_mask_type == "causal"
+                masked_bmm1 = bmm1.masked_fill(causal_mask[None, :, :], float("-inf"))
+                br, bc = 128, 128
+                assert masked_bmm1.size(1) % br == 0
+                assert masked_bmm1.size(2) % bc == 0
+                num_block_rows = (q_len + br - 1) // br  # e.g. 8192 / 128 = 64
+                num_block_cols = (k_len + bc - 1) // bc  # e.g. 8192 / 128 = 64
+                # blocked_attn [b*h, num_block_rows, br, num_block_cols, bc]
+                blocked_attn = masked_bmm1.view(masked_bmm1.size(0), num_block_rows, br, num_block_cols, bc)
+                # block_max [b*h, num_block_rows, br, num_block_cols]: per-row max within each column block
+                block_max = blocked_attn.max(dim=-1)[0]
+                block_max_cummax = block_max.cummax(dim=-1)[0]
+                block_max_larger = torch.ones_like(block_max)
+                # True where a block's max exceeds the running max of the previous blocks,
+                # i.e. the softmax correction factor for that row would change; [b*h, num_block_rows, br, num_block_cols]
+                block_max_larger[..., 1:] = (block_max[..., 1:] - block_max_cummax[..., :-1]) > 0
+                # block_max_any_larger [b*h, num_block_rows, num_block_cols]: did any row in the block change
+                block_max_any_larger = block_max_larger.any(dim=-2).float()
+                cf_change_ratio = float(torch.sum(block_max_larger.float()) / torch.numel(block_max_larger)) * 2 * num_block_rows / (num_block_rows + 1)  # 2n/(n+1) undoes the causal mask's (n+1)/(2n) unmasked-block fraction
+                cf_block_change_ratio = float(torch.sum(block_max_any_larger) / torch.numel(block_max_any_larger)) * 2 * num_block_rows / (num_block_rows + 1)
+                print(f"query_layer {query_layer_t.shape}-{query_layer_t.dtype} key_layer {key_layer_t.shape}-{key_layer_t.dtype} cf_change_ratio {cf_change_ratio} cf_block_change_ratio {cf_block_change_ratio}")
+
+            attn_out = self.fused_attention(
                 query_layer,
                 key_layer,
                 value_layer,
@@ -1500,6 +1537,7 @@ def forward(
                 softmax_offset=softmax_offset,
                 fp8_output=fp8_output,
             )
+            return attn_out
 
         if use_unfused_attention:
             allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
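
For reference, the block-max statistic added above can be reproduced outside the attention module. The snippet below is a minimal standalone sketch, not part of the diff: it assumes q_len == k_len, a block size that divides the sequence length, and substitutes random logits (`scores`) for the real masked QK^T values.

import torch

# Toy stand-in for the masked QK^T logits: [batch*heads, seq, seq]
bh, seq, blk = 2, 512, 128
scores = torch.randn(bh, seq, seq)
causal = torch.triu(torch.ones(seq, seq), diagonal=1).bool()
scores = scores.masked_fill(causal, float("-inf"))

n = seq // blk                                    # blocks per side
blocked = scores.view(bh, n, blk, n, blk)
block_max = blocked.max(dim=-1)[0]                # [bh, n, blk, n]: per-row max in each column block
running_max = block_max.cummax(dim=-1)[0]         # running max a flash-attention kernel would carry
changed = torch.ones_like(block_max)
changed[..., 1:] = (block_max[..., 1:] - running_max[..., :-1]) > 0
any_changed = changed.bool().any(dim=-2).float()  # [bh, n, n]: did any row in the block change

# Rescale by 2n/(n+1): only (n+1)/(2n) of the blocks are unmasked under the causal mask.
cf_change_ratio = float(changed.float().mean()) * 2 * n / (n + 1)
cf_block_change_ratio = float(any_changed.mean()) * 2 * n / (n + 1)
print(cf_change_ratio, cf_block_change_ratio)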