perf(flash_cuda): head-parallel dK/dV + vectorized loads + cp.async#376
perf(flash_cuda): head-parallel dK/dV + vectorized loads + cp.async#376WilliamYue37 wants to merge 1 commit into
Conversation
FlashAttention-2 follow-up to #358. On head_dim=256 / MQA the custom block-causal kernel was much slower than eager despite the memory win; this closes most of that gap (3090, B2 S1024 D256 bf16): per-op fwd+bwd 32.4 -> 7.13 ms (4.5x) stacked-18L +ckpt 842.7 -> 295.8 ms (2.85x) peak mem +ckpt 0.21 -> 0.24 GB (still ~2x under eager) Changes (all in flash_blockmask.cu): - dK/dV WMMA kernel parallelized over query heads (grid.y=H, was Hkv): per-head fp32 partials (H,B,Sk,D) reduced by a new dkv_reduce_kernel. No atomics -> bit-identical determinism preserved. Removes the serial head-group loop that starved the GPU under MQA (Hkv=1). - Vectorized 128-bit (uint4) tile loads across all WMMA kernels, replacing scalar per-element loads with a bf16->float->bf16 round-trip. - cp.async double-buffered streaming of Q/dO in the dK/dV kernel over a contiguous query-tile suffix. Still ~3x slower than eager/sdpa at head_dim=256 (closing that needs a full mma.sync rewrite). benchmark_flash_attn.py now reports checkpointed fwd+bwd latency for eager/sdpa/flash. Benched on local RTX 3090; A100 benchmarking + nightly regression still pending.
| // the WMMA element type (bf16/fp16), so this is a raw byte copy — no per-element | ||
| // dtype round-trip. Threads [tid, tid+nthreads, ...) of the calling group | ||
| // participate. Alignment: D is a multiple of 32 (checked host-side) and every row | ||
| // base is a multiple of D, so all uint4 accesses are 16-byte aligned. |
There was a problem hiding this comment.
suggestion — The vectorized uint4 loads add a 16-byte alignment requirement on the tensor base that the old scalar loads did not have. q.contiguous() (line 1150/1215) returns an already-contiguous input unchanged, preserving its storage_offset; a contiguous view whose storage_offset is not a multiple of 8 bf16 elements would yield a data_ptr() that is not 16-byte aligned and trigger a CUDA misaligned-address fault. In the normal flash call path q/k/v are fresh projections (offset 0), so this is low-risk, but the alignment assumption documented here only covers D/row-stride, not the base offset. Consider asserting reinterpret_cast<uintptr_t>(q.data_ptr()) % 16 == 0 host-side (or .contiguous() after a copy that guarantees offset 0) so a bad input fails with an actionable message instead of an opaque misaligned-address error.
| // block computes ONE query head's full contribution to dK/dV for its key slab | ||
| // and writes it to per-head fp32 partial buffers dKp/dVp of shape (H, B, Sk, D). | ||
| // Because every (h, b, key, d) is written by exactly one block, there is no | ||
| // cross-block accumulation -> no atomics -> bit-identical determinism. A |
There was a problem hiding this comment.
suggestion — Determinism here is now load-bearing by construction (no atomics + fixed-order group reduce + correctly-drained cp.async double buffer), and the PR says it was verified manually across 8 repeats. But there's no automated determinism test in tests/policies/test_flash_attn_cuda.py — the existing tests only check numeric agreement vs sdpa/autograd. A future edit (e.g. extending cp.async to the dQ/forward kernels, or switching the reduce to atomics) could silently reintroduce nondeterminism that the current suite wouldn't catch. Per CLAUDE.md rule #3, a same-input fwd+bwd bit-equality test (assert torch.equal on output + dQ/dK/dV across repeats) would lock this in cheaply.
| dim3 grid((Sk + nw * BR_W - 1) / (nw * BR_W), Hkv, B); | ||
| // Per-head fp32 partials (H, B, Sk, D); each (h,b,key,d) written by exactly | ||
| // one block (deterministic), then summed over the head group below. | ||
| auto dKp = at::empty({H, B, Sk, D}, q.options().dtype(at::kFloat)); |
There was a problem hiding this comment.
nit — The fp32 per-head partials are (H, B, Sk, D) in float = 4·H·B·Sk·D bytes, i.e. ~2H· the bf16 (B,Sk,Hkv,D) gradient (and they exist transiently alongside it). At the benchmarked shape this is tiny, but on a long-sequence config (large Sk) with H=8 this is a real backward-time spike (e.g. B=8,Sk=8192,H=8,D=256 → ~1 GB transient for dKp+dVp). The PR description acknowledges this; just flagging that the spike scales with Sk, so it's worth keeping in mind if a config OOMs in dK/dV backward where main did not.
|
[claude-review] summary for commit d2dfc39 Reviewed the FA2 follow-up (head-parallel dK/dV +
Note: GPU tests are |
shuheng-liu
left a comment
There was a problem hiding this comment.
Approving. Traced the rewrite statically and it holds up; the speedup is real and the PR is honest that it's still ~3× off eager/sdpa at head_dim=256. GPU-suite/A100 caveat at the bottom.
What I verified
- Head-parallel dK/dV is equivalent to the old in-block group accumulation — per-head fp32 partials
(H,B,Sk,D)written once each (no atomics), summed in fixed order bydkv_reduce_kernel. Partial write/read indexing and the reduce's output decomposition line up withidx_qkv, and the*scale-on-dK (not dV) placement is preserved. - cp.async double-buffer is race-free — ≤2 batches outstanding (matches
DKV_NBUF=2), the prefetch slot(t+1)&1always differs from the consumed slott&1, and__pipeline_wait_prior(has_next ? 1 : 0)+__syncthreads()gate each tile's landing and slot reuse correctly.n_proc ∈ {0,1}edge cases hold, and the contiguous-suffixqt_startprocesses exactly the old per-tile-skip set givenq_blk/k_blkmonotonicity. - Shared-memory accounting matches the pointer layout field-for-field; at
D=256,nw=1≈ 83.5 KB andnw=2≈ 134.6 KB, sopick_nwlands on nw=1 (sm_86) / nw=2 (A100) and both fit. - Determinism preserved (no atomics, fixed-order reduce, same-stream ordering between the dkv kernel and the reduce); fp32 path untouched; vectorized-load 16-byte alignment is guaranteed by the host-side
D % 32check.
Non-blocking suggestions (for a follow-up, not gating this merge)
- (memory scaling)
dKp/dVpareO(H·B·Sk·D)fp32 globals —2·(H/Hkv)×the bytes of the final dK/dV per tensor (16× each under MQAHkv=1, H=8). Tiny at the benchmarked shape (~33 MB, still the memory winner with +ckpt) but grows linearly withB·Sk; the prior kernel had no global temp here, so a one-line note in the comment about the scaling would set expectations. - (robustness) The backward dkv/dq WMMA launches use
pick_nw+cudaFuncSetAttributebut lack the actionableTORCH_CHECK(smem <= max_smem, ...)the forward path has. If evennw=1exceeds the device opt-in cap,pick_nwstill returns 1 and you get an opaque launch failure instead of a clear message. Pre-existing, and the targeted arches fitnw=1atD=256— but since this PR raises the dkv footprint, mirroring the forward's check is a cheap win. - (test) Consider committing a bitwise-determinism regression test (run
flash_bwdtwice, assert dQ/dK/dV bit-identical). The correctness story for the rewrite rests on determinism; you verified it ad hoc but it isn't locked in. Existing tests already exercise the new paths well (GQAhkv=2/ MQAhkv=1group reductions, multi-tile pipelines viasq=40/50, non-square cross-attention), so this is purely additive. - (nit) Backward correctness tests use
H=4; the production pi07_paligemma shape and the benchmark useH=8, Hkv=1(group=8). The reduce is generic over group size so correctness is unaffected, but a group=8 case would be marginally more representative. - (nit, style)
DKV_PREFETCHis a large multi-statement macro — justified by the ~15 captured locals, but a__device__ __forceinline__helper (or a small pointer-bundle struct) would be more type-safe and debugger-friendly. Optional. - (numerics note) dK/dV will differ in the last bits from #358's kernel — the group sum is now per-head-partial-then-reduce rather than one running in-block accumulator (a reassociation), and
scaleis applied to fp32 partials before the sum. Both are fine and within your test tolerances; just worth knowing if anyone bisects numerics against #358.
Caveat
GPU tests / A100 numbers — I could not run the GPU suite locally (CPU-only box), so the 134-pass / determinism claims rest on your reported runs. A100 benchmarking + full nightly regression are still pending per the PR, matching the #358 process. CPU CI is unaffected (the .cu isn't built on CPU runners; the benchmark script is import-only).
What this does
(⚡️ Performance) FlashAttention-2 follow-up to #358. The custom block-causal
flash_cudakernel landed in #358 with a real memory win but ran much slowerthan the
eager/sdpabackends at the pi07_paligemma shape (head_dim=256, MQA).This PR applies FA2 techniques to close most of that gap, benchmarked on a
local RTX 3090 (B2 S1024 D256 bf16, MQA):
flash_cuda(main)Three changes, all in
flash_blockmask.cu:grid.y = H, wasHkv).It now writes per-head fp32 partials
(H, B, Sk, D)that a newdkv_reduce_kernelsums over each GQA/MQA head group in a fixed order. Thisremoves the serial
for h in grouploop that, under MQA (Hkv=1), launched asingle key-head block looping all 8 query heads — starving the GPU. No
atomics, so bit-identical determinism is preserved (CLAUDE.md rule Fixing reward normalizer #3).
uint4) tile loads (load_tile_vec) across all WMMAkernels, replacing the scalar per-element loads that round-tripped every
element bf16→float→bf16. The kernels were memory-bound, so this was the single
biggest win.
cp.asyncdouble-buffered streaming of Q/dO in the dominant dK/dV kernel(
cpasync_tile,__pipeline_*), iterating a contiguous query-tile suffixderived from block-id monotonicity (skip-free pipeline).
Honest verdict: this is a 2.85–4.5× speedup over the previous
flash_cudaand the memory win is preserved (~2× under eager), but it still does not beat
eager/sdpa(~3× off) at head_dim=256. After vectorization the kernels arecompute/smem-bound, not load-bound, so
cp.asynconly bought ~9% on the dK/dVkernel; I deliberately did not force it into the dQ/forward kernels where it
would cost an occupancy drop against the 99 KB sm_86 shared-memory cap. Fully
closing the gap to cuDNN would require a register-resident
mma.syncrewrite(tracked as future work). The memory cost rises slightly (per-head dK/dV
partials are
H×the gradient, transient during backward) but stays well beloweager.
How it was tested
tests/policies/test_flash_attn_cuda.pypass on alocal RTX 3090 (fp32 vs dense ref ~1e-4; bf16/fp16 WMMA vs
sdpa~2e-2; WMMAbackward vs fp32 autograd ~5e-2; non-square cross-attention; GQA/MQA
Hkv ∈ {1,2}; padding; every mask pattern).on 4 configs (output + dQ/dK/dV), confirming no
cp.asyncrace and noatomic nondeterminism.
python -m opentau.scripts.benchmark_flash_attn(extended to report checkpointed fwd+bwd latency for eager/sdpa/flash).
is currently unavailable; will validate there once free, as with feat(pi07_paligemma): custom CUDA block-causal flash attention #358).
How to checkout & try? (for the reviewer)
pytest -m "gpu" -n 0 tests/policies/test_flash_attn_cuda.pyChecklist
flash_attn_cudaGPU tests (134 pass) + determinism on a local RTX 3090. Full nightly regression + A100 benchmarking are still pending (A100 node unavailable).Note: Before submitting this PR, please read the contributor guideline.