@@ -371,7 +371,7 @@ def forward( # type: ignore
         window_left, window_right = window_size
         # Check if this is generation phase (sq = 1)
         sq = q.shape[1]
-        if q.dim() == 4 and sq == 1:
+        if q.dim() == 4 and sq == 1 and not page_table:
            # For gen case, we don't need to save tensors for backward
            ctx.is_gen = True
            out, _ = cutlass_blackwell_fmha_decode_forward(
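Note on the hunk above: the added `and not page_table` term keeps paged-KV calls off the single-token decode fast path, which only handles contiguous caches. A minimal sketch of the condition, assuming `page_table` is None for a contiguous cache (names and shapes below are illustrative, not from this PR):

import torch

# Illustrative only: q in the generation phase is [B, 1, H, D] (one new token
# per sequence); page_table is assumed to be None when the KV cache is contiguous.
B, H, D = 4, 8, 128
q_decode = torch.randn(B, 1, H, D)
page_table = None

sq = q_decode.shape[1]
take_decode_path = q_decode.dim() == 4 and sq == 1 and not page_table
print(take_decode_path)  # True here; with a page table present, the paged path is used instead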
mslk/attention/fmha/cutlass_blackwell.py (28 changes: 27 additions, 1 deletion)
@@ -25,6 +25,9 @@
     LowerTriangularFromBottomRightLocalAttentionMask,
     LowerTriangularFromBottomRightMask,
     LowerTriangularMask,
+    PagedBlockDiagonalCausalLocalPaddedKeysMask,
+    PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
+    PagedBlockDiagonalPaddedKeysMask,
 )
 from .common import AttentionBwOpBase, AttentionFwOpBase, Context, Gradients, Inputs
 from .utils.op_common import register_operator
@@ -79,6 +82,9 @@ def _convert_input_format(
             BlockDiagonalCausalWithOffsetGappyKeysMask,
             BlockDiagonalLocalAttentionPaddedKeysMask,
             BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+            PagedBlockDiagonalCausalLocalPaddedKeysMask,
+            PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
+            PagedBlockDiagonalPaddedKeysMask,
         ),
     ):
         assert attn_bias.k_seqinfo.seqstart.device == inp.query.device
@@ -132,6 +138,12 @@ def fold(x):
     key = fold(key)
     value = fold(value)
 
+    # Reshape KV to 4D for paged attention
+    if isinstance(attn_bias, PagedBlockDiagonalPaddedKeysMask):
+        num_pages = value.shape[0] // attn_bias.page_size
+        key = key.view(num_pages, attn_bias.page_size, *key.shape[1:])
+        value = value.view(num_pages, attn_bias.page_size, *value.shape[1:])
+
     new_inp = Inputs(
         query=query,
         key=key,
@@ -172,6 +184,8 @@ def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
             BlockDiagonalCausalLocalAttentionPaddedKeysMask,
             BlockDiagonalCausalWithOffsetGappyKeysMask,
             BlockDiagonalCausalWithOffsetPaddedKeysMask,
+            PagedBlockDiagonalCausalLocalPaddedKeysMask,
+            PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
         ),
     )
 
@@ -188,6 +202,8 @@ def _is_bottom_right(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
             BlockDiagonalLocalAttentionPaddedKeysMask,
             BlockDiagonalCausalWithOffsetGappyKeysMask,
             BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+            PagedBlockDiagonalCausalLocalPaddedKeysMask,
+            PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
         ),
     )
 
@@ -202,8 +218,9 @@ def _window_size(
         (
             BlockDiagonalCausalLocalAttentionMask,
             BlockDiagonalCausalLocalAttentionFromBottomRightMask,
-            LowerTriangularFromBottomRightLocalAttentionMask,
             BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+            LowerTriangularFromBottomRightLocalAttentionMask,
+            PagedBlockDiagonalCausalLocalPaddedKeysMask,
         ),
     ):
         win_left = attn_bias._window_size - 1
@@ -243,6 +260,9 @@ class FwOp(AttentionFwOpBase):
         LowerTriangularFromBottomRightLocalAttentionMask,
         BlockDiagonalCausalLocalAttentionMask,
         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+        PagedBlockDiagonalPaddedKeysMask,
+        PagedBlockDiagonalCausalLocalPaddedKeysMask,
+        PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
     )
     SUPPORTS_DROPOUT = False
     SUPPORTS_CUSTOM_SCALE = True
@@ -307,6 +327,11 @@ def apply(
         window_left, window_right = _window_size(inp.attn_bias)
 
         if inp.query.numel() > 0 and inp.key.numel() > 0:
+            block_table = (
+                inp.attn_bias.block_tables
+                if isinstance(inp.attn_bias, PagedBlockDiagonalPaddedKeysMask)
+                else None
+            )
             out, lse = cls.OPERATOR(
                 q=inp.query,
                 k=inp.key,
@@ -321,6 +346,7 @@
                 window_left=window_left,
                 window_right=window_right,
                 bottom_right=_is_bottom_right(inp.attn_bias),
+                page_table=block_table,
             )
         else:
             out = torch.zeros_like(inp.query)
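For orientation, the paged-KV handling added to `_convert_input_format` and `FwOp.apply` amounts to: view the packed KV cache as pages of `page_size` tokens, and pass the bias's `block_tables` through to the kernel as `page_table`. A rough sketch of that layout under assumed shapes (illustrative names, not code from this PR):

import torch

# Assumed packed KV layout: [total_kv_tokens, H_kv, D]; page_size matches the bias.
page_size, H_kv, D = 128, 2, 128
total_kv_tokens = 8 * page_size
key = torch.randn(total_kv_tokens, H_kv, D)
value = torch.randn(total_kv_tokens, H_kv, D)

# Same reshape as the new block in _convert_input_format: split the token axis
# into (num_pages, page_size) without copying data.
num_pages = value.shape[0] // page_size
key = key.view(num_pages, page_size, *key.shape[1:])      # [pages, page_size, H_kv, D]
value = value.view(num_pages, page_size, *value.shape[1:])

# The paged bias carries block_tables (integer page indices per sequence), which
# FwOp.apply forwards to the kernel as page_table. Hypothetical table for two
# sequences of up to four pages each:
block_tables = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int32)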
test/attention/fmha/test_mem_eff_attention.py (26 changes: 25 additions, 1 deletion)
@@ -8,7 +8,6 @@
 
 import logging
 import math
-import random
 from contextlib import nullcontext
 from typing import List, Optional, Tuple, Type
 
@@ -1679,6 +1678,31 @@ def test_paged_attention_flash3(
     paged_attention_run_inner(B, MAX_T, num_quant_groups, page_size, op, bench=False)
 
 
+@sm100_or_better_only
+@disable_on_rocm
+@pytest.mark.parametrize(
+    "op",
+    _filter_unsupported_ops(
+        [
+            fmha.cutlass_blackwell.FwOp,
+        ]
+    ),
+)
+@pytest.mark.parametrize("B", [1, 5, 128])
+@pytest.mark.parametrize("MAX_T", [64, 128, 2048, 4096, 8192])
+@pytest.mark.parametrize("page_size", [128])
+def test_paged_attention_cutlass_blackwell(
+    op: Type[AttentionFwOpBase], B: int, MAX_T: int, page_size: int
+):
+    if (
+        fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask
+        not in get_supported_attn_bias_types(op)
+    ):
+        pytest.skip("Not supported bias")
+    num_quant_groups = 0
+    paged_attention_run_inner(B, MAX_T, num_quant_groups, page_size, op, bench=False)
+
+
 def paged_attention_run_inner(
     B: int,
     MAX_T: int,
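The new test sweeps batch size and maximum sequence length at a fixed page size of 128. As a quick sanity check on those parameters (a hypothetical helper, not part of the test file), the number of KV-cache pages the paged mask must cover grows as B * ceil(MAX_T / page_size):

import math

def pages_needed(B: int, MAX_T: int, page_size: int) -> int:
    # Pages per sequence, rounded up, times the number of sequences in the batch.
    return B * math.ceil(MAX_T / page_size)

print(pages_needed(1, 64, 128))      # 1
print(pages_needed(128, 8192, 128))  # 8192, the largest configuration tested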