
Fix notation and clarification of Puzzle 9 #16

Open
alexzhang13 wants to merge 1 commit into gpu-mode:main from alexzhang13:main

Conversation

@alexzhang13 alexzhang13 commented Jul 17, 2024

The notation for the softmax in Puzzle 9 is both confusing and wrong. The current indexing does not represent the outer product, and the inclusion of an extra variable B1 is a bit ambiguous. I think the new description (a minor change) is clearer.

The new notation also makes clear the relationship between the k vector and the v vector, which is important for understanding how full flash attention is done.
@VachanVY

Hi @alexzhang13 ,

The function arguments are also incomplete... B1 is missing.

But I found it here:
https://github.com/SiriusNEO/Triton-Puzzles-Lite/blob/main/puzzles_ans.py#L614-L617

@triton.jit
def flashatt_kernel(
    q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr
):

@skimberk
Contributor

+1 that the current notation for Puzzle 9 is confusing/potentially incorrect, and that B1 is missing.

As @VachanVY mentioned, it looks like it's been fixed/improved in Triton-Puzzles-Lite: https://github.com/SiriusNEO/Triton-Puzzles-Lite/blob/main/puzzles.md#puzzle-9-simple-flashattention

$$z_{i} = \sum_{j=1}^{T} \text{softmax}(q_i k_1, \ldots, q_i k_T)_j v_{j} \text{ for } i = 1\ldots N_0$$
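For reference, the spec that this equation describes can be sketched in plain NumPy. This is a hypothetical reference implementation based only on the formula above (1-D q, k, v of length T, as in the puzzle's simple flash attention), not the repo's actual `flashatt_spec`:

```python
import numpy as np

def flashatt_spec(q, k, v):
    """z_i = sum_j softmax(q_i k_1, ..., q_i k_T)_j * v_j  (sketch, not the repo's code)."""
    x = q[:, None] * k[None, :]           # outer product q_i k_j, shape (T, T)
    x = x - x.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    p = np.exp(x)
    p = p / p.sum(axis=1, keepdims=True)  # row-wise softmax over j
    return p @ v                          # weighted sum of v_j for each i
```

Note that with this notation the softmax weights for row i and the values v_j share the same index j, which is exactly the k/v relationship the PR description calls out.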

The missing B1 has been fixed there too: https://github.com/SiriusNEO/Triton-Puzzles-Lite/blob/2990dc91ab0495c5d0306609806f0b455b0555f2/puzzles.py#L466

@triton.jit
def flashatt_kernel(
    q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr
):

and then calls it with:

test(
    flashatt_kernel,
    flashatt_spec,
    B={"B0": 64, "B1": 32},
    nelem={"N0": 200, "T": 200},
    # other lite specific params removed
)

@msaroufim
Member

This was fixed.

@msaroufim msaroufim closed this Mar 18, 2026
@msaroufim msaroufim reopened this Apr 1, 2026
