
perf(flash_cuda): head-parallel dK/dV + vectorized loads + cp.async #376

Open. WilliamYue37 wants to merge 1 commit into main from perf/flash-cuda-fa2

@WilliamYue37
Member

What this does

(⚡️ Performance) FlashAttention-2 follow-up to #358. The custom block-causal
flash_cuda kernel landed in #358 with a real memory win but ran much slower
than the eager/sdpa backends at the pi07_paligemma shape (head_dim=256, MQA).
This PR applies FA2 techniques to close most of that gap, benchmarked on a
local RTX 3090 (B2 S1024 D256 bf16, MQA):

                        baseline flash_cuda (main)   this PR             eager / sdpa
per-op fwd              3.14 ms                      2.19 ms (1.4×)      0.81 / 0.41
per-op fwd+bwd          32.4 ms                      7.13 ms (4.5×)      2.18 / 2.72
stacked 18-layer +ckpt  842.7 ms                     295.8 ms (2.85×)    ~97 / ~93
+ckpt peak mem          0.21 GB                      0.24 GB             0.50 / 0.28

Three changes, all in flash_blockmask.cu:

  1. dK/dV WMMA kernel parallelized over query heads (grid.y = H, was Hkv).
    It now writes per-head fp32 partials (H, B, Sk, D) that a new
    dkv_reduce_kernel sums over each GQA/MQA head group in a fixed order. This
    removes the serial for h in group loop that, under MQA (Hkv=1), launched a
    single key-head block looping over all 8 query heads, starving the GPU. No
    atomics, so bit-identical determinism is preserved (CLAUDE.md rule #3).
  2. Vectorized 128-bit (uint4) tile loads (load_tile_vec) across all WMMA
    kernels, replacing the scalar per-element loads that round-tripped every
    element bf16→float→bf16. The kernels were memory-bound, so this was the single
    biggest win.
  3. cp.async double-buffered streaming of Q/dO in the dominant dK/dV kernel
    (cpasync_tile, __pipeline_*), iterating a contiguous query-tile suffix
    derived from block-id monotonicity (skip-free pipeline).
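
The head-parallel scheme in item 1 can be modeled on the host. This is a minimal pure-Python sketch, not the CUDA kernel: `partials` stands in for the per-head fp32 buffers, `reduce_group` is a hypothetical stand-in for dkv_reduce_kernel, and the toy values are chosen only to make floating-point reassociation visible (Python floats are doubles, so the exact bits differ from fp32).

```python
def reduce_group(partials, group_size):
    """Sum per-head partials over each head group in ascending head order.

    partials[h] is a flat list of floats for query head h (one entry per
    (b, key, d) slot). Returns one summed list per key/value head.
    """
    num_heads = len(partials)
    assert num_heads % group_size == 0
    slots = len(partials[0])
    out = []
    for kv_head in range(num_heads // group_size):
        acc = [0.0] * slots
        first = kv_head * group_size
        for h in range(first, first + group_size):  # fixed order, no atomics
            for i in range(slots):
                acc[i] += partials[h][i]
        out.append(acc)
    return out


# Toy MQA case: 3 query heads share 1 key/value head.
partials = [[1e16, 0.5], [1.0, 0.25], [-1e16, 0.125]]
first_run = reduce_group(partials, 3)
second_run = reduce_group(partials, 3)
# Same addends, different order: a model of what an unordered (atomic) sum could do.
reordered = reduce_group([partials[0], partials[2], partials[1]], 3)
```

Because each slot is written by one producer and summed in a fixed head order, repeated runs agree exactly, while the reordered sum differs in this toy case. That is the property the no-atomics design relies on.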

Honest verdict: this is a 2.85–4.5× speedup over the previous flash_cuda
and the memory win is preserved (~2× under eager), but it still does not beat
eager/sdpa (~3× off) at head_dim=256. After vectorization the kernels are
compute/smem-bound, not load-bound, so cp.async only bought ~9% on the dK/dV
kernel; I deliberately did not force it into the dQ/forward kernels where it
would cost an occupancy drop against the 99 KB sm_86 shared-memory cap. Fully
closing the gap to cuDNN would require a register-resident mma.sync rewrite
(tracked as future work). Peak memory rises slightly (the per-head fp32 dK/dV
partials are transient gradient buffers that exist only during backward) but
stays well below eager.
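
For reference, the headline ratios can be recomputed from the benchmark numbers above. This is only arithmetic on the reported latencies; the dictionary names are hypothetical and the eager stacked figure uses the approximate ~97 ms value.

```python
# Latencies in ms, copied from the benchmark numbers in this description.
baseline_ms = {"fwd": 3.14, "fwd_bwd": 32.4, "stacked_ckpt": 842.7}
this_pr_ms = {"fwd": 2.19, "fwd_bwd": 7.13, "stacked_ckpt": 295.8}
eager_ms = {"fwd": 0.81, "fwd_bwd": 2.18, "stacked_ckpt": 97.0}  # ~97 is approximate

# Speedup of this PR over the previous flash_cuda kernel.
speedup = {k: baseline_ms[k] / this_pr_ms[k] for k in baseline_ms}
# Remaining slowdown of this PR relative to eager.
gap_to_eager = {k: this_pr_ms[k] / eager_ms[k] for k in this_pr_ms}

for key in baseline_ms:
    print(f"{key}: {speedup[key]:.2f}x faster than main, "
          f"{gap_to_eager[key]:.2f}x slower than eager")
```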

How it was tested

  • All 134 GPU tests in tests/policies/test_flash_attn_cuda.py pass on a
    local RTX 3090 (fp32 vs dense ref ~1e-4; bf16/fp16 WMMA vs sdpa ~2e-2; WMMA
    backward vs fp32 autograd ~5e-2; non-square cross-attention; GQA/MQA
    Hkv ∈ {1,2}; padding; every mask pattern).
  • Determinism verified: same-input fwd+bwd is bit-identical across 8 repeats
    on 4 configs (output + dQ/dK/dV), confirming no cp.async race and no
    atomic nondeterminism.
  • Benchmarks above via python -m opentau.scripts.benchmark_flash_attn
    (extended to report checkpointed fwd+bwd latency for eager/sdpa/flash).
  • Pending: full nightly regression suite + A100 benchmarking (the A100 node
    is currently unavailable; will validate there once free, as with #358).

How to checkout & try? (for the reviewer)

pytest -m "gpu" -n 0 tests/policies/test_flash_attn_cuda.py
python -m opentau.scripts.benchmark_flash_attn

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.
      • Note: ran the flash_attn_cuda GPU tests (134 pass) + determinism on a local RTX 3090. Full nightly regression + A100 benchmarking are still pending (A100 node unavailable).

Note: Before submitting this PR, please read the contributor guideline.

FlashAttention-2 follow-up to #358. On head_dim=256 / MQA the custom
block-causal kernel was much slower than eager despite the memory win;
this closes most of that gap (3090, B2 S1024 D256 bf16):

  per-op fwd+bwd      32.4 -> 7.13 ms  (4.5x)
  stacked-18L +ckpt  842.7 -> 295.8 ms (2.85x)
  peak mem +ckpt     0.21  -> 0.24 GB  (still ~2x under eager)

Changes (all in flash_blockmask.cu):
- dK/dV WMMA kernel parallelized over query heads (grid.y=H, was Hkv):
  per-head fp32 partials (H,B,Sk,D) reduced by a new dkv_reduce_kernel.
  No atomics -> bit-identical determinism preserved. Removes the serial
  head-group loop that starved the GPU under MQA (Hkv=1).
- Vectorized 128-bit (uint4) tile loads across all WMMA kernels,
  replacing scalar per-element loads with a bf16->float->bf16 round-trip.
- cp.async double-buffered streaming of Q/dO in the dK/dV kernel over a
  contiguous query-tile suffix.

Still ~3x slower than eager/sdpa at head_dim=256 (closing that needs a
full mma.sync rewrite). benchmark_flash_attn.py now reports checkpointed
fwd+bwd latency for eager/sdpa/flash. Benched on local RTX 3090; A100
benchmarking + nightly regression still pending.
// the WMMA element type (bf16/fp16), so this is a raw byte copy — no per-element
// dtype round-trip. Threads [tid, tid+nthreads, ...) of the calling group
// participate. Alignment: D is a multiple of 32 (checked host-side) and every row
// base is a multiple of D, so all uint4 accesses are 16-byte aligned.

suggestion — The vectorized uint4 loads add a 16-byte alignment requirement on the tensor base that the old scalar loads did not have. q.contiguous() (line 1150/1215) returns an already-contiguous input unchanged, preserving its storage_offset; a contiguous view whose storage_offset is not a multiple of 8 bf16 elements would yield a data_ptr() that is not 16-byte aligned and trigger a CUDA misaligned-address fault. In the normal flash call path q/k/v are fresh projections (offset 0), so this is low-risk, but the alignment assumption documented here only covers D/row-stride, not the base offset. Consider asserting reinterpret_cast<uintptr_t>(q.data_ptr()) % 16 == 0 host-side (or .contiguous() after a copy that guarantees offset 0) so a bad input fails with an actionable message instead of an opaque misaligned-address error.
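
The alignment arithmetic behind this suggestion, as a minimal sketch. It assumes the allocation base itself is 16-byte aligned and models only the storage_offset contribution; `base_is_16_byte_aligned` and the example address are hypothetical, not part of the extension.

```python
def base_is_16_byte_aligned(alloc_base_addr, storage_offset_elems, itemsize=2):
    """Return True if the first element's address is a multiple of 16 bytes.

    alloc_base_addr: address of the underlying storage (assumed aligned).
    storage_offset_elems: the view's storage offset, in elements.
    itemsize: bytes per element (2 for bf16/fp16).
    """
    return (alloc_base_addr + storage_offset_elems * itemsize) % 16 == 0


ALLOC_BASE = 0x7F0000000000  # illustrative, 16-byte-aligned address

# An offset of 8 bf16 elements is 16 bytes, so uint4 loads stay aligned.
ok = base_is_16_byte_aligned(ALLOC_BASE, 8)
# An offset of 3 bf16 elements is 6 bytes, so a uint4 load would be misaligned.
bad = base_is_16_byte_aligned(ALLOC_BASE, 3)
```

A host-side check of this shape on data_ptr() would turn the misaligned case into a clear error before launch.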

// block computes ONE query head's full contribution to dK/dV for its key slab
// and writes it to per-head fp32 partial buffers dKp/dVp of shape (H, B, Sk, D).
// Because every (h, b, key, d) is written by exactly one block, there is no
// cross-block accumulation -> no atomics -> bit-identical determinism. A

suggestion — Determinism here is now load-bearing by construction (no atomics + fixed-order group reduce + correctly-drained cp.async double buffer), and the PR says it was verified manually across 8 repeats. But there's no automated determinism test in tests/policies/test_flash_attn_cuda.py — the existing tests only check numeric agreement vs sdpa/autograd. A future edit (e.g. extending cp.async to the dQ/forward kernels, or switching the reduce to atomics) could silently reintroduce nondeterminism that the current suite wouldn't catch. Per CLAUDE.md rule #3, a same-input fwd+bwd bit-equality test (assert torch.equal on output + dQ/dK/dV across repeats) would lock this in cheaply.
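
The shape of such a test, as a framework-free sketch. Everything here is a stand-in: `stand_in_fwd_bwd` is a hypothetical placeholder for a real same-input fwd+bwd call returning (output, dQ, dK, dV), and the byte comparison plays the role the suggested torch.equal assertions would play in the actual suite.

```python
import struct


def float32_bits(values):
    """Pack values as little-endian float32 bytes so comparison is bitwise."""
    return struct.pack(f"<{len(values)}f", *values)


def assert_bitwise_repeatable(run_once, repeats=8):
    """Call run_once() `repeats` times and require identical bytes each time."""
    reference = tuple(float32_bits(t) for t in run_once())
    for attempt in range(1, repeats):
        current = tuple(float32_bits(t) for t in run_once())
        assert current == reference, f"run {attempt} differs from run 0"
    return reference


def stand_in_fwd_bwd():
    """Hypothetical placeholder returning (output, dQ, dK, dV) as flat lists."""
    out = [0.1 * i for i in range(4)]
    dq = [x * 2.0 for x in out]
    dk = [x - 1.0 for x in out]
    dv = [-x for x in out]
    return out, dq, dk, dv


assert_bitwise_repeatable(stand_in_fwd_bwd)
```

The default of 8 repeats mirrors the manual verification described in the PR; a committed test would call the real kernel on each of the verified configs.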

dim3 grid((Sk + nw * BR_W - 1) / (nw * BR_W), Hkv, B);
// Per-head fp32 partials (H, B, Sk, D); each (h,b,key,d) written by exactly
// one block (deterministic), then summed over the head group below.
auto dKp = at::empty({H, B, Sk, D}, q.options().dtype(at::kFloat));

nit: The fp32 per-head partials are (H, B, Sk, D) in float = 4·H·B·Sk·D bytes each, i.e. ~2·(H/Hkv)× the bytes of the bf16 (B,Sk,Hkv,D) gradient, or 2H× under MQA (and they exist transiently alongside it). At the benchmarked shape this is tiny, but on a long-sequence config (large Sk) with H=8 this is a real backward-time spike (e.g. B=8,Sk=8192,H=8,D=256 → ~1 GB transient for dKp+dVp). The PR description acknowledges this; just flagging that the spike scales with Sk, so it's worth keeping in mind if a config OOMs in dK/dV backward where main did not.
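
The sizes quoted above follow from the buffer shapes. A small sketch of the arithmetic, with hypothetical helper names:

```python
def partial_buffer_bytes(H, B, Sk, D, fp32_bytes=4):
    """Bytes of ONE per-head fp32 partial buffer of shape (H, B, Sk, D)."""
    return H * B * Sk * D * fp32_bytes


def final_grad_bytes(B, Sk, Hkv, D, bf16_bytes=2):
    """Bytes of ONE final bf16 dK or dV tensor of shape (B, Sk, Hkv, D)."""
    return B * Sk * Hkv * D * bf16_bytes


# Benchmarked shape: B=2, Sk=1024, H=8, D=256 (dKp + dVp together).
bench_total = 2 * partial_buffer_bytes(H=8, B=2, Sk=1024, D=256)

# Long-sequence example from the comment: B=8, Sk=8192, H=8, D=256.
long_total = 2 * partial_buffer_bytes(H=8, B=8, Sk=8192, D=256)

# Ratio of one partial buffer to one final gradient under MQA (Hkv=1).
ratio = (partial_buffer_bytes(H=8, B=8, Sk=8192, D=256)
         // final_grad_bytes(B=8, Sk=8192, Hkv=1, D=256))

print(bench_total / 2**20, "MiB at the benchmarked shape")
print(long_total / 2**30, "GiB at the long-sequence shape")
```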

@claude
Contributor

claude Bot commented Jun 2, 2026

[claude-review] summary for commit d2dfc39

Reviewed the FA2 follow-up (head-parallel dK/dV + uint4 loads + cp.async). The kernel logic checks out: load offsets match idx_qkv, bwd_dkv_smem sizing matches the kernel's pointer arithmetic (incl. the new DKV_NBUF=2 double buffer), the cp.async pipeline is drained correctly (wait_prior(has_next ? 1 : 0) keeps exactly one batch in flight and the end-of-loop __syncthreads prevents a slot-overwrite race), the contiguous-suffix skip is safe even if q_blk isn't monotonic (the per-element mask still zeroes non-attending pairs), dKp/dVp are fully written before the reduce reads them, and determinism is preserved by the no-atomics + fixed-order group sum. No blocking issues found.

  • suggestion (flash_blockmask.cu:84): new uint4 loads require a 16-byte-aligned base; .contiguous() keeps an already-contiguous view's storage_offset, so a non-8-element-aligned offset would fault. Low risk in practice; consider a host-side alignment assert.
  • suggestion (flash_blockmask.cu:872): determinism is now load-bearing by construction but has no automated regression test; a same-input fwd+bwd bit-equality test would lock in CLAUDE.md rule #3.
  • nit (flash_blockmask.cu:1295): fp32 per-head partials are ~2·(H/Hkv)× the bf16 gradient and scale with Sk; potential backward-time OOM on long-sequence configs (acknowledged in the PR).

Note: GPU tests are gpu-marked and I can't run them here, so the "134 pass + determinism verified" claims are taken on the author's word.

@shuheng-liu shuheng-liu self-requested a review June 2, 2026 19:08
Member

@shuheng-liu shuheng-liu left a comment


Approving. Traced the rewrite statically and it holds up; the speedup is real and the PR is honest that it's still ~3× off eager/sdpa at head_dim=256. GPU-suite/A100 caveat at the bottom.

What I verified

  • Head-parallel dK/dV is equivalent to the old in-block group accumulation — per-head fp32 partials (H,B,Sk,D) written once each (no atomics), summed in fixed order by dkv_reduce_kernel. Partial write/read indexing and the reduce's output decomposition line up with idx_qkv, and the *scale-on-dK (not dV) placement is preserved.
  • cp.async double-buffer is race-free — ≤2 batches outstanding (matches DKV_NBUF=2), the prefetch slot (t+1)&1 always differs from the consumed slot t&1, and __pipeline_wait_prior(has_next ? 1 : 0) + __syncthreads() gate each tile's landing and slot reuse correctly. n_proc ∈ {0,1} edge cases hold, and the contiguous-suffix qt_start processes exactly the old per-tile-skip set given q_blk/k_blk monotonicity.
  • Shared-memory accounting matches the pointer layout field-for-field; at D=256, nw=1 ≈ 83.5 KB and nw=2 ≈ 134.6 KB, so pick_nw lands on nw=1 (sm_86) / nw=2 (A100) and both fit.
  • Determinism preserved (no atomics, fixed-order reduce, same-stream ordering between the dkv kernel and the reduce); fp32 path untouched; vectorized-load 16-byte alignment is guaranteed by the host-side D % 32 check.
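
The double-buffer argument above can be checked with a host-side model. This is a sketch under an assumed loop order (prologue issues tile 0, then each iteration issues the next tile, waits, and consumes); `simulate_double_buffer` is a hypothetical name and the model tracks only slot indices and outstanding batch counts, not real cp.async behavior.

```python
def simulate_double_buffer(n_tiles):
    """Model slot usage of a 2-slot prefetch pipeline over n_tiles tiles.

    Returns (trace, peak_in_flight), where each trace entry is
    (consumed_slot, prefetch_slot_or_None, wait_prior_arg).
    """
    trace = []
    in_flight = 1 if n_tiles > 0 else 0  # prologue: tile 0 goes into slot 0
    peak = in_flight
    for t in range(n_tiles):
        has_next = t + 1 < n_tiles
        prefetch_slot = (t + 1) & 1 if has_next else None
        if has_next:
            in_flight += 1  # issue the copy for tile t+1
        peak = max(peak, in_flight)
        allowed = 1 if has_next else 0
        # Model of wait_prior(allowed): at most `allowed` batches stay pending,
        # so the batch carrying tile t has landed.
        in_flight = min(in_flight, allowed)
        consumed_slot = t & 1
        assert prefetch_slot != consumed_slot  # never overwrite the live slot
        trace.append((consumed_slot, prefetch_slot, allowed))
    assert in_flight == 0  # pipeline fully drained at loop exit
    return trace, peak


for n in (0, 1, 3):
    print(n, simulate_double_buffer(n))
```

In this model the peak number of outstanding batches never exceeds 2, which is the DKV_NBUF=2 bound, and the empty and single-tile cases drain cleanly.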

Non-blocking suggestions (for a follow-up, not gating this merge)

  1. (memory scaling) dKp/dVp are O(H·B·Sk·D) fp32 globals — 2·(H/Hkv)× the bytes of the final dK/dV per tensor (16× each under MQA Hkv=1, H=8). Tiny at the benchmarked shape (~33 MB, still the memory winner with +ckpt) but grows linearly with B·Sk; the prior kernel had no global temp here, so a one-line note in the comment about the scaling would set expectations.
  2. (robustness) The backward dkv/dq WMMA launches use pick_nw + cudaFuncSetAttribute but lack the actionable TORCH_CHECK(smem <= max_smem, ...) the forward path has. If even nw=1 exceeds the device opt-in cap, pick_nw still returns 1 and you get an opaque launch failure instead of a clear message. Pre-existing, and the targeted arches fit nw=1 at D=256 — but since this PR raises the dkv footprint, mirroring the forward's check is a cheap win.
  3. (test) Consider committing a bitwise-determinism regression test (run flash_bwd twice, assert dQ/dK/dV bit-identical). The correctness story for the rewrite rests on determinism; you verified it ad hoc but it isn't locked in. Existing tests already exercise the new paths well (GQA hkv=2 / MQA hkv=1 group reductions, multi-tile pipelines via sq=40/50, non-square cross-attention), so this is purely additive.
  4. (nit) Backward correctness tests use H=4; the production pi07_paligemma shape and the benchmark use H=8, Hkv=1 (group=8). The reduce is generic over group size so correctness is unaffected, but a group=8 case would be marginally more representative.
  5. (nit, style) DKV_PREFETCH is a large multi-statement macro — justified by the ~15 captured locals, but a __device__ __forceinline__ helper (or a small pointer-bundle struct) would be more type-safe and debugger-friendly. Optional.
  6. (numerics note) dK/dV will differ in the last bits from #358's kernel — the group sum is now per-head-partial-then-reduce rather than one running in-block accumulator (a reassociation), and scale is applied to fp32 partials before the sum. Both are fine and within your test tolerances; just worth knowing if anyone bisects numerics against #358.
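
Suggestion 2 can be illustrated with a small host-side model. This is a sketch, not the extension's pick_nw: the function below is hypothetical and raises when nothing fits, which is the behavior the suggestion asks for. The footprints are the approximate figures quoted in this review; the 99 KB cap for sm_86 comes from the PR description, and the 163 KB cap for A100 is an assumption.

```python
KB = 1024

# Approximate dkv shared-memory footprints at D=256, per the review above.
SMEM_BYTES = {1: 83.5 * KB, 2: 134.6 * KB}


def pick_nw_checked(max_optin_smem_bytes, candidates=(2, 1)):
    """Return the largest nw whose footprint fits, or raise a clear error."""
    for nw in candidates:
        if SMEM_BYTES[nw] <= max_optin_smem_bytes:
            return nw
    smallest = min(SMEM_BYTES.values())
    raise ValueError(
        f"dkv kernel needs at least {smallest / KB:.1f} KB of shared memory, "
        f"but the device opt-in cap is {max_optin_smem_bytes / KB:.1f} KB"
    )


nw_sm86 = pick_nw_checked(99 * KB)    # RTX 3090 cap from the PR description
nw_a100 = pick_nw_checked(163 * KB)   # assumed A100 opt-in cap
```

With these inputs the model lands on nw=1 for sm_86 and nw=2 for A100, matching the accounting above, and a device with a smaller cap fails with an actionable message instead of an opaque launch error.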

Caveat

GPU tests / A100 numbers — I could not run the GPU suite locally (CPU-only box), so the 134-pass / determinism claims rest on your reported runs. A100 benchmarking + full nightly regression are still pending per the PR, matching the #358 process. CPU CI is unaffected (the .cu isn't built on CPU runners; the benchmark script is import-only).
