From d2dfc396a2d3c21dfa78c0824b80a70d507650b7 Mon Sep 17 00:00:00 2001
From: William Yue
Date: Tue, 2 Jun 2026 11:36:21 -0700
Subject: [PATCH] perf(flash_cuda): head-parallel dK/dV + vectorized loads + cp.async

FlashAttention-2 follow-up to #358. On head_dim=256 / MQA the custom
block-causal kernel was much slower than eager despite the memory win; this
closes most of that gap (3090, B2 S1024 D256 bf16):

  per-op fwd+bwd        32.4 -> 7.13 ms    (4.5x)
  stacked-18L +ckpt     842.7 -> 295.8 ms  (2.85x)
  peak mem +ckpt        0.21 -> 0.24 GB    (still ~2x under eager)

Changes (all in flash_blockmask.cu):
- dK/dV WMMA kernel parallelized over query heads (grid.y=H, was Hkv):
  per-head fp32 partials (H,B,Sk,D) reduced by a new dkv_reduce_kernel. No
  atomics -> bit-identical determinism preserved. Removes the serial
  head-group loop that starved the GPU under MQA (Hkv=1).
- Vectorized 128-bit (uint4) tile loads across all WMMA kernels, replacing
  the scalar per-element loads that did a bf16->float->bf16 round-trip.
- cp.async double-buffered streaming of Q/dO in the dK/dV kernel over a
  contiguous query-tile suffix.

Still ~3x slower than eager/sdpa at head_dim=256 (closing that needs a full
mma.sync rewrite). benchmark_flash_attn.py now reports checkpointed fwd+bwd
latency for eager/sdpa/flash. Benched on local RTX 3090; A100 benchmarking +
nightly regression still pending.
--- .../flash_attn_cuda/_csrc/flash_blockmask.cu | 415 ++++++++++++------ src/opentau/scripts/benchmark_flash_attn.py | 61 ++- 2 files changed, 322 insertions(+), 154 deletions(-) diff --git a/src/opentau/policies/flash_attn_cuda/_csrc/flash_blockmask.cu b/src/opentau/policies/flash_attn_cuda/_csrc/flash_blockmask.cu index 66b710c..5dc8c27 100644 --- a/src/opentau/policies/flash_attn_cuda/_csrc/flash_blockmask.cu +++ b/src/opentau/policies/flash_attn_cuda/_csrc/flash_blockmask.cu @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -74,6 +75,57 @@ __device__ __forceinline__ long idx_lse(long b, long h, long s, long H, long S) return (b * H + h) * S + s; } +// Vectorized (128-bit / 8-element) cooperative tile load: copy `rows` rows of D +// contiguous elements each from a strided global source into a packed smem tile, +// 16 bytes per transaction. Rows >= max_row are zero-filled. Both src and dst are +// the WMMA element type (bf16/fp16), so this is a raw byte copy — no per-element +// dtype round-trip. Threads [tid, tid+nthreads, ...) of the calling group +// participate. Alignment: D is a multiple of 32 (checked host-side) and every row +// base is a multiple of D, so all uint4 accesses are 16-byte aligned. 
+// global row r lives at src_base + (first_row + r) * row_stride
+// smem row r lives at dst + r * D
+template <typename wt>
+__device__ __forceinline__ void load_tile_vec(
+    wt* __restrict__ dst, const wt* __restrict__ src_base, long row_stride,
+    int rows, int D, long first_row, long max_row, int tid, int nthreads) {
+    constexpr int VEC = 16 / sizeof(wt); // 8 for bf16/fp16
+    const int dvec = D / VEC;
+    const uint4 zero = make_uint4(0u, 0u, 0u, 0u);
+    for (int c = tid; c < rows * dvec; c += nthreads) {
+        const int r = c / dvec, dv = c % dvec;
+        const long gr = first_row + r;
+        uint4* d4 = reinterpret_cast<uint4*>(dst + (long)r * D) + dv;
+        if (gr < max_row) {
+            *d4 = *(reinterpret_cast<const uint4*>(src_base + gr * row_stride) + dv);
+        } else {
+            *d4 = zero;
+        }
+    }
+}
+
+// Asynchronous (cp.async) variant of load_tile_vec: issues 16-byte global->smem
+// copies that the warp can overlap with compute. The caller must __pipeline_commit()
+// after issuing a tile's copies and __pipeline_wait_prior(...) before reading them.
+// Out-of-range rows are zero-filled synchronously (rare; only the last tile).
+template <typename wt>
+__device__ __forceinline__ void cpasync_tile(
+    wt* __restrict__ dst, const wt* __restrict__ src_base, long row_stride,
+    int rows, int D, long first_row, long max_row, int tid, int nthreads) {
+    constexpr int VEC = 16 / sizeof(wt); // 8 for bf16/fp16
+    const int dvec = D / VEC;
+    for (int c = tid; c < rows * dvec; c += nthreads) {
+        const int r = c / dvec, dv = c % dvec;
+        const long gr = first_row + r;
+        uint4* d4 = reinterpret_cast<uint4*>(dst + (long)r * D) + dv;
+        if (gr < max_row) {
+            const uint4* s4 = reinterpret_cast<const uint4*>(src_base + gr * row_stride) + dv;
+            __pipeline_memcpy_async(d4, s4, sizeof(uint4));
+        } else {
+            *d4 = make_uint4(0u, 0u, 0u, 0u);
+        }
+    }
+}
+
 // =========================================================================
 // Forward: O = softmax(scale * Q K^T + mask) V, plus L = logsumexp per row.
 // One warp per query row. Dynamic smem holds an fp32 K/V tile + block metadata.
@@ -475,6 +527,10 @@ __global__ void flash_fwd_wmma_kernel( using namespace nvcuda; using wt = typename WmmaTraits::wt; auto f2w = WmmaTraits::from_float; + // Raw (bit-compatible) views for vectorized 16-byte tile loads. + const wt* Qw = reinterpret_cast(Q); + const wt* Kw = reinterpret_cast(K); + const wt* Vw = reinterpret_cast(V); // Shared layout: per-warp slabs for Q/O/P/S/stats, block-shared K/V tile. extern __shared__ char smem_raw[]; @@ -509,14 +565,10 @@ __global__ void flash_fwd_wmma_kernel( int* qblkW = sQblk + warp * BR_W; char* qvalidW = sQvalid + warp * BR_W; - // Init this warp's accumulator + load its query slab. + // Init this warp's accumulator + load its query slab (vectorized). for (int idx = lane; idx < BR_W * D; idx += WARP) oW[idx] = 0.f; for (int i = lane; i < BR_W; i += WARP) { mW[i] = -INFINITY; lW[i] = 0.f; } - for (int idx = lane; idx < BR_W * D; idx += WARP) { - const int i = idx / D, d = idx % D; - const long qi = q0 + i; - qW[idx] = (qi < Sq) ? f2w(to_f(Q[idx_qkv(b, qi, h, d, Sq, H, D)])) : f2w(0.f); - } + load_tile_vec(qW, Qw + ((long)b * Sq * H + h) * D, (long)H * D, BR_W, D, q0, Sq, lane, WARP); for (int i = lane; i < BR_W; i += WARP) { const long qi = q0 + i; qblkW[i] = (qi < Sq) ? q_blk[b * Sq + qi] : INT_MIN; @@ -532,15 +584,11 @@ __global__ void flash_fwd_wmma_kernel( const int n_tiles = (Sk + BC_W - 1) / BC_W; for (int kt = 0; kt < n_tiles; ++kt) { const long kc0 = (long)kt * BC_W; - // Cooperative K/V tile load by ALL warps (overlaps global-load latency). - for (int idx = threadIdx.x; idx < BC_W * D; idx += blockDim.x) { - const int j = idx / D, d = idx % D; - const long kj = kc0 + j; - if (kj < Sk) { - Ks[idx] = f2w(to_f(K[idx_qkv(b, kj, hk, d, Sk, Hkv, D)])); - Vs[idx] = f2w(to_f(V[idx_qkv(b, kj, hk, d, Sk, Hkv, D)])); - } else { Ks[idx] = f2w(0.f); Vs[idx] = f2w(0.f); } - } + // Cooperative, vectorized K/V tile load by ALL warps. 
+ const wt* Kbase = Kw + ((long)b * Sk * Hkv + hk) * D; + const wt* Vbase = Vw + ((long)b * Sk * Hkv + hk) * D; + load_tile_vec(Ks, Kbase, (long)Hkv * D, BC_W, D, kc0, Sk, threadIdx.x, blockDim.x); + load_tile_vec(Vs, Vbase, (long)Hkv * D, BC_W, D, kc0, Sk, threadIdx.x, blockDim.x); for (int j = threadIdx.x; j < BC_W; j += blockDim.x) { const long kj = kc0 + j; sKblk[j] = (kj < Sk) ? k_blk[b * Sk + kj] : INT_MAX; @@ -658,11 +706,17 @@ __host__ inline size_t bwd_dq_smem(int D, int nw) { return bf16 * 2 + f32 * 4 + i32 * 4 + c8; } +// Streamed Q/dO (and their per-query metadata) are double-buffered so cp.async +// can prefetch the next query tile while the current one is consumed. +constexpr int DKV_NBUF = 2; + __host__ inline size_t bwd_dkv_smem(int D, int nw) { - size_t bf16 = (size_t)nw * (2 * BR_W * D + 2 * BR_W * BR_W) + (size_t)2 * BR_W * D; // Ks,Vs,Ps,dSs;Qs,dOs - size_t f32 = (size_t)nw * (2 * BR_W * D + 2 * BR_W * BR_W) + (size_t)2 * BR_W; // dKacc,dVacc,ss,dp;Li,delta - size_t i32 = (size_t)nw * BR_W + BR_W; - size_t c8 = (size_t)nw * BR_W + BR_W; + // owned (nw): Ks,Vs,Ps,dSs ; streamed (×DKV_NBUF): Qs,dOs + size_t bf16 = (size_t)nw * (2 * BR_W * D + 2 * BR_W * BR_W) + (size_t)DKV_NBUF * 2 * BR_W * D; + // owned (nw): dKacc,dVacc,ss,dp ; streamed (×DKV_NBUF): Li,delta + size_t f32 = (size_t)nw * (2 * BR_W * D + 2 * BR_W * BR_W) + (size_t)DKV_NBUF * 2 * BR_W; + size_t i32 = (size_t)nw * BR_W + (size_t)DKV_NBUF * BR_W; // Kblk ; Qblk×NBUF + size_t c8 = (size_t)nw * BR_W + (size_t)DKV_NBUF * BR_W; // Kvalid ; Qvalid×NBUF return bf16 * 2 + f32 * 4 + i32 * 4 + c8; } @@ -676,6 +730,10 @@ __global__ void flash_bwd_dq_wmma_kernel( using namespace nvcuda; using wt = typename WmmaTraits::wt; auto f2w = WmmaTraits::from_float; + const wt* Qw = reinterpret_cast(Q); + const wt* Kw = reinterpret_cast(K); + const wt* Vw = reinterpret_cast(V); + const wt* dOw = reinterpret_cast(dO); extern __shared__ char smem_raw[]; wt* Qs = reinterpret_cast(smem_raw); // nw*BR*D @@ 
-711,14 +769,8 @@ __global__ void flash_bwd_dq_wmma_kernel( char* qvalidW = sQvalid + (long)warp * BR_W; for (int idx = lane; idx < BR_W * D; idx += WARP) dqW[idx] = 0.f; - for (int idx = lane; idx < BR_W * D; idx += WARP) { - const int i = idx / D, d = idx % D; - const long qi = q0 + i; - if (qi < Sq) { - qW[idx] = f2w(to_f(Q[idx_qkv(b, qi, h, d, Sq, H, D)])); - doW[idx] = f2w(to_f(dO[idx_qkv(b, qi, h, d, Sq, H, D)])); - } else { qW[idx] = f2w(0.f); doW[idx] = f2w(0.f); } - } + load_tile_vec(qW, Qw + ((long)b * Sq * H + h) * D, (long)H * D, BR_W, D, q0, Sq, lane, WARP); + load_tile_vec(doW, dOw + ((long)b * Sq * H + h) * D, (long)H * D, BR_W, D, q0, Sq, lane, WARP); for (int i = lane; i < BR_W; i += WARP) { const long qi = q0 + i; if (qi < Sq) { @@ -735,14 +787,10 @@ __global__ void flash_bwd_dq_wmma_kernel( const int n_tiles = (Sk + BC_W - 1) / BC_W; for (int kt = 0; kt < n_tiles; ++kt) { const long kc0 = (long)kt * BC_W; - for (int idx = threadIdx.x; idx < BC_W * D; idx += blockDim.x) { - const int j = idx / D, d = idx % D; - const long kj = kc0 + j; - if (kj < Sk) { - Ks[idx] = f2w(to_f(K[idx_qkv(b, kj, hk, d, Sk, Hkv, D)])); - Vs[idx] = f2w(to_f(V[idx_qkv(b, kj, hk, d, Sk, Hkv, D)])); - } else { Ks[idx] = f2w(0.f); Vs[idx] = f2w(0.f); } - } + const wt* Kbase = Kw + ((long)b * Sk * Hkv + hk) * D; + const wt* Vbase = Vw + ((long)b * Sk * Hkv + hk) * D; + load_tile_vec(Ks, Kbase, (long)Hkv * D, BC_W, D, kc0, Sk, threadIdx.x, blockDim.x); + load_tile_vec(Vs, Vbase, (long)Hkv * D, BC_W, D, kc0, Sk, threadIdx.x, blockDim.x); for (int j = threadIdx.x; j < BC_W; j += blockDim.x) { const long kj = kc0 + j; sKblk[j] = (kj < Sk) ? k_blk[b * Sk + kj] : INT_MAX; @@ -817,40 +865,52 @@ __global__ void flash_bwd_dq_wmma_kernel( } } +// dK/dV WMMA kernel, parallelized over query heads (grid.y == H, NOT Hkv). Each +// block computes ONE query head's full contribution to dK/dV for its key slab +// and writes it to per-head fp32 partial buffers dKp/dVp of shape (H, B, Sk, D). 
+// Because every (h, b, key, d) is written by exactly one block, there is no +// cross-block accumulation -> no atomics -> bit-identical determinism. A +// separate reduction (dkv_reduce_kernel) sums the GQA/MQA head group into the +// final (B, Sk, Hkv, D) dK/dV. This removes the serial `for h in group` loop the +// old single-Hkv-block design ran, which starved the GPU under MQA (Hkv=1). template __global__ void flash_bwd_dkv_wmma_kernel( const scalar_t* __restrict__ Q, const scalar_t* __restrict__ K, const scalar_t* __restrict__ V, const scalar_t* __restrict__ dO, const float* __restrict__ L, const float* __restrict__ delta, const int* __restrict__ q_blk, const int* __restrict__ k_blk, const bool* __restrict__ q_valid, const bool* __restrict__ k_valid, - scalar_t* __restrict__ dK, scalar_t* __restrict__ dV, + float* __restrict__ dKp, float* __restrict__ dVp, int B, int Sq, int Sk, int H, int Hkv, int D, float scale, int nw) { using namespace nvcuda; using wt = typename WmmaTraits::wt; auto f2w = WmmaTraits::from_float; + const wt* Qw = reinterpret_cast(Q); + const wt* Kw = reinterpret_cast(K); + const wt* Vw = reinterpret_cast(V); + const wt* dOw = reinterpret_cast(dO); extern __shared__ char smem_raw[]; wt* Ks = reinterpret_cast(smem_raw); // nw*BR*D (owned key slab) wt* Vs = Ks + (long)nw * BR_W * D; // nw*BR*D wt* Ps = Vs + (long)nw * BR_W * D; // nw*BR*BR wt* dSs = Ps + (long)nw * BR_W * BR_W; // nw*BR*BR - wt* Qs = dSs + (long)nw * BR_W * BR_W; // BR*D (streamed) - wt* dOs = Qs + (long)BR_W * D; // BR*D - float* dKacc = reinterpret_cast(dOs + (long)BR_W * D); // nw*BR*D + wt* Qs = dSs + (long)nw * BR_W * BR_W; // DKV_NBUF*BR*D (streamed, double-buffered) + wt* dOs = Qs + (long)DKV_NBUF * BR_W * D; // DKV_NBUF*BR*D + float* dKacc = reinterpret_cast(dOs + (long)DKV_NBUF * BR_W * D); // nw*BR*D float* dVacc = dKacc + (long)nw * BR_W * D; // nw*BR*D float* Ss = dVacc + (long)nw * BR_W * D; // nw*BR*BR float* dPs = Ss + (long)nw * BR_W * BR_W; // nw*BR*BR - 
float* sLi = dPs + (long)nw * BR_W * BR_W; // BR - float* sdl = sLi + (long)BR_W; // BR - int* sKblk = reinterpret_cast(sdl + (long)BR_W); // nw*BR (owned) - int* sQblk = sKblk + (long)nw * BR_W; // BR (streamed) - char* sKvalid = reinterpret_cast(sQblk + (long)BR_W); // nw*BR - char* sQvalid = sKvalid + (long)nw * BR_W; // BR + float* sLi = dPs + (long)nw * BR_W * BR_W; // DKV_NBUF*BR + float* sdl = sLi + (long)DKV_NBUF * BR_W; // DKV_NBUF*BR + int* sKblk = reinterpret_cast(sdl + (long)DKV_NBUF * BR_W); // nw*BR (owned) + int* sQblk = sKblk + (long)nw * BR_W; // DKV_NBUF*BR (streamed) + char* sKvalid = reinterpret_cast(sQblk + (long)DKV_NBUF * BR_W); // nw*BR + char* sQvalid = sKvalid + (long)nw * BR_W; // DKV_NBUF*BR const int warp = threadIdx.x / WARP; const int lane = threadIdx.x % WARP; - const int b = blockIdx.z, hk = blockIdx.y; - const int group = H / Hkv; + const int b = blockIdx.z, h = blockIdx.y; // query head (grid.y == H) + const int hk = h / (H / Hkv); // mapped kv head const long k0 = (long)blockIdx.x * (nw * BR_W) + (long)warp * BR_W; // this warp's key slab wt* ksW = Ks + (long)warp * BR_W * D; @@ -865,14 +925,10 @@ __global__ void flash_bwd_dkv_wmma_kernel( char* kvalidW = sKvalid + (long)warp * BR_W; for (int idx = lane; idx < BR_W * D; idx += WARP) { dkW[idx] = 0.f; dvW[idx] = 0.f; } - for (int idx = lane; idx < BR_W * D; idx += WARP) { - const int j = idx / D, d = idx % D; - const long kj = k0 + j; - if (kj < Sk) { - ksW[idx] = f2w(to_f(K[idx_qkv(b, kj, hk, d, Sk, Hkv, D)])); - vsW[idx] = f2w(to_f(V[idx_qkv(b, kj, hk, d, Sk, Hkv, D)])); - } else { ksW[idx] = f2w(0.f); vsW[idx] = f2w(0.f); } - } + const wt* Kbase = Kw + ((long)b * Sk * Hkv + hk) * D; + const wt* Vbase = Vw + ((long)b * Sk * Hkv + hk) * D; + load_tile_vec(ksW, Kbase, (long)Hkv * D, BR_W, D, k0, Sk, lane, WARP); + load_tile_vec(vsW, Vbase, (long)Hkv * D, BR_W, D, k0, Sk, lane, WARP); for (int j = lane; j < BR_W; j += WARP) { const long kj = k0 + j; kblkW[j] = (kj < Sk) ? 
k_blk[b * Sk + kj] : INT_MAX; @@ -885,98 +941,173 @@ __global__ void flash_bwd_dkv_wmma_kernel( if (sKvalid[j]) block_min_kblk = min(block_min_kblk, sKblk[j]); const int n_qtiles = (Sq + BR_W - 1) / BR_W; - for (int h = hk * group; h < (hk + 1) * group; ++h) { - for (int qt = 0; qt < n_qtiles; ++qt) { - const long qc0 = (long)qt * BR_W; - const long last_q = min((long)qc0 + BR_W - 1, (long)Sq - 1); - if (q_blk[b * Sq + last_q] < block_min_kblk) continue; // uniform skip - - for (int idx = threadIdx.x; idx < BR_W * D; idx += blockDim.x) { - const int q = idx / D, d = idx % D; - const long qi = qc0 + q; - if (qi < Sq) { - Qs[idx] = f2w(to_f(Q[idx_qkv(b, qi, h, d, Sq, H, D)])); - dOs[idx] = f2w(to_f(dO[idx_qkv(b, qi, h, d, Sq, H, D)])); - } else { Qs[idx] = f2w(0.f); dOs[idx] = f2w(0.f); } - } - for (int q = threadIdx.x; q < BR_W; q += blockDim.x) { - const long qi = qc0 + q; - if (qi < Sq) { - sLi[q] = L[idx_lse(b, h, qi, H, Sq)]; sdl[q] = delta[idx_lse(b, h, qi, H, Sq)]; - sQblk[q] = q_blk[b * Sq + qi]; sQvalid[q] = q_valid[b * Sq + qi] ? 1 : 0; - } else { sLi[q] = 0.f; sdl[q] = 0.f; sQblk[q] = INT_MIN; sQvalid[q] = 0; } - } - __syncthreads(); - // S = Q K^T and dP = dO V^T (rows = streamed queries, cols = owned keys). - wmma::fragment accS, accP; - wmma::fill_fragment(accS, 0.f); - wmma::fill_fragment(accP, 0.f); - for (int kk = 0; kk < D / WK; ++kk) { - wmma::fragment aq, ado; - wmma::fragment bk, bv; - wmma::load_matrix_sync(aq, Qs + kk * WK, D); - wmma::load_matrix_sync(bk, ksW + kk * WK, D); - wmma::mma_sync(accS, aq, bk, accS); - wmma::load_matrix_sync(ado, dOs + kk * WK, D); - wmma::load_matrix_sync(bv, vsW + kk * WK, D); - wmma::mma_sync(accP, ado, bv, accP); - } - wmma::store_matrix_sync(ssW, accS, BR_W, wmma::mem_row_major); - wmma::store_matrix_sync(dpW, accP, BR_W, wmma::mem_row_major); - __syncwarp(); + // q_blk and k_blk are both non-decreasing, so the query tiles that can reach + // this key slab form a contiguous suffix [qt_start, n_qtiles). 
Find its start + // once, then stream the suffix with a skip-free cp.async double-buffered + // pipeline (prefetch the next query tile while consuming the current one). + __shared__ int s_qt_start; + if (threadIdx.x == 0) { + int qs = n_qtiles; + for (int qt = 0; qt < n_qtiles; ++qt) { + const long last_q = min((long)qt * BR_W + BR_W - 1, (long)Sq - 1); + if (q_blk[b * Sq + last_q] >= block_min_kblk) { qs = qt; break; } + } + s_qt_start = qs; + } + __syncthreads(); + const int qt_start = s_qt_start; + const int n_proc = n_qtiles - qt_start; + + const wt* Qbase = Qw + ((long)b * Sq * H + h) * D; + const wt* dObase = dOw + ((long)b * Sq * H + h) * D; + +// Issue cp.async prefetch + (synchronous) metadata load for suffix-tile T into +// double-buffer slot SLOT. +#define DKV_PREFETCH(T, SLOT) \ + do { \ + const long qc0 = (long)(qt_start + (T)) * BR_W; \ + cpasync_tile(Qs + (long)(SLOT) * BR_W * D, Qbase, (long)H * D, BR_W, D, qc0, Sq, \ + threadIdx.x, blockDim.x); \ + cpasync_tile(dOs + (long)(SLOT) * BR_W * D, dObase, (long)H * D, BR_W, D, qc0, Sq, \ + threadIdx.x, blockDim.x); \ + __pipeline_commit(); \ + float* sLiS = sLi + (long)(SLOT) * BR_W; \ + float* sdlS = sdl + (long)(SLOT) * BR_W; \ + int* sQblkS = sQblk + (long)(SLOT) * BR_W; \ + char* sQvalidS = sQvalid + (long)(SLOT) * BR_W; \ + for (int q = threadIdx.x; q < BR_W; q += blockDim.x) { \ + const long qi = qc0 + q; \ + if (qi < Sq) { \ + sLiS[q] = L[idx_lse(b, h, qi, H, Sq)]; \ + sdlS[q] = delta[idx_lse(b, h, qi, H, Sq)]; \ + sQblkS[q] = q_blk[b * Sq + qi]; \ + sQvalidS[q] = q_valid[b * Sq + qi] ? 1 : 0; \ + } else { \ + sLiS[q] = 0.f; sdlS[q] = 0.f; sQblkS[q] = INT_MIN; sQvalidS[q] = 0; \ + } \ + } \ + } while (0) + + if (n_proc > 0) DKV_PREFETCH(0, 0); + for (int t = 0; t < n_proc; ++t) { + const int cur = t & 1; + const bool has_next = (t + 1 < n_proc); + if (has_next) DKV_PREFETCH(t + 1, (t + 1) & 1); + __pipeline_wait_prior(has_next ? 
1 : 0); // ensure the current tile has landed + __syncthreads(); - // P (=softmax prob) and dS, with masking (lane q owns streamed query q). - if (lane < BR_W) { - const int q = lane; - const int qb = sQblk[q]; - const float Li = sLi[q], di = sdl[q]; - const bool qv = sQvalid[q]; - for (int j = 0; j < BR_W; ++j) { - const long kj = k0 + j; - const bool att = qv && (kj < Sk) && kvalidW[j] && (kblkW[j] <= qb); - float p = 0.f, ds = 0.f; - if (att) { - p = __expf(ssW[q * BR_W + j] * scale - Li); - ds = p * (dpW[q * BR_W + j] - di); - } - psW[q * BR_W + j] = f2w(p); - dsW[q * BR_W + j] = f2w(ds); + wt* Qcur = Qs + (long)cur * BR_W * D; + wt* dObuf = dOs + (long)cur * BR_W * D; + float* sLicur = sLi + (long)cur * BR_W; + float* sdlcur = sdl + (long)cur * BR_W; + int* sQblkcur = sQblk + (long)cur * BR_W; + char* sQvalidcur = sQvalid + (long)cur * BR_W; + + // S = Q K^T and dP = dO V^T (rows = streamed queries, cols = owned keys). + wmma::fragment accS, accP; + wmma::fill_fragment(accS, 0.f); + wmma::fill_fragment(accP, 0.f); + for (int kk = 0; kk < D / WK; ++kk) { + wmma::fragment aq, ado; + wmma::fragment bk, bv; + wmma::load_matrix_sync(aq, Qcur + kk * WK, D); + wmma::load_matrix_sync(bk, ksW + kk * WK, D); + wmma::mma_sync(accS, aq, bk, accS); + wmma::load_matrix_sync(ado, dObuf + kk * WK, D); + wmma::load_matrix_sync(bv, vsW + kk * WK, D); + wmma::mma_sync(accP, ado, bv, accP); + } + wmma::store_matrix_sync(ssW, accS, BR_W, wmma::mem_row_major); + wmma::store_matrix_sync(dpW, accP, BR_W, wmma::mem_row_major); + __syncwarp(); + + // P (=softmax prob) and dS, with masking (lane q owns streamed query q). 
+ if (lane < BR_W) { + const int q = lane; + const int qb = sQblkcur[q]; + const float Li = sLicur[q], di = sdlcur[q]; + const bool qv = sQvalidcur[q]; + for (int j = 0; j < BR_W; ++j) { + const long kj = k0 + j; + const bool att = qv && (kj < Sk) && kvalidW[j] && (kblkW[j] <= qb); + float p = 0.f, ds = 0.f; + if (att) { + p = __expf(ssW[q * BR_W + j] * scale - Li); + ds = p * (dpW[q * BR_W + j] - di); } + psW[q * BR_W + j] = f2w(p); + dsW[q * BR_W + j] = f2w(ds); } - __syncwarp(); - - // dV += P^T @ dO ; dK += dS^T @ Q (P,dS as col_major -> transpose). - for (int dn = 0; dn < D / WN; ++dn) { - wmma::fragment accv, acck; - wmma::load_matrix_sync(accv, dvW + dn * WN, D, wmma::mem_row_major); - wmma::load_matrix_sync(acck, dkW + dn * WN, D, wmma::mem_row_major); - wmma::fragment ap, ads; - wmma::fragment bdo, bq; - wmma::load_matrix_sync(ap, psW, BR_W); - wmma::load_matrix_sync(bdo, dOs + dn * WN, D); - wmma::mma_sync(accv, ap, bdo, accv); - wmma::load_matrix_sync(ads, dsW, BR_W); - wmma::load_matrix_sync(bq, Qs + dn * WN, D); - wmma::mma_sync(acck, ads, bq, acck); - wmma::store_matrix_sync(dvW + dn * WN, accv, D, wmma::mem_row_major); - wmma::store_matrix_sync(dkW + dn * WN, acck, D, wmma::mem_row_major); - } - __syncthreads(); // streamed Qs/dOs reused next tile } + __syncwarp(); + + // dV += P^T @ dO ; dK += dS^T @ Q (P,dS as col_major -> transpose). 
+ for (int dn = 0; dn < D / WN; ++dn) { + wmma::fragment accv, acck; + wmma::load_matrix_sync(accv, dvW + dn * WN, D, wmma::mem_row_major); + wmma::load_matrix_sync(acck, dkW + dn * WN, D, wmma::mem_row_major); + wmma::fragment ap, ads; + wmma::fragment bdo, bq; + wmma::load_matrix_sync(ap, psW, BR_W); + wmma::load_matrix_sync(bdo, dObuf + dn * WN, D); + wmma::mma_sync(accv, ap, bdo, accv); + wmma::load_matrix_sync(ads, dsW, BR_W); + wmma::load_matrix_sync(bq, Qcur + dn * WN, D); + wmma::mma_sync(acck, ads, bq, acck); + wmma::store_matrix_sync(dvW + dn * WN, accv, D, wmma::mem_row_major); + wmma::store_matrix_sync(dkW + dn * WN, acck, D, wmma::mem_row_major); + } + __syncthreads(); // done with this slot's Qcur/dObuf before it is refilled } +#undef DKV_PREFETCH + // Write this query head's contribution to the per-head fp32 partials + // (H, B, Sk, D); the reduction below sums the head group. dK carries the + // softmax `scale`; dV does not (matches the chain rule for O = P V). if (lane < BR_W) { const int j = lane; const long kj = k0 + j; - if (kj < Sk) + if (kj < Sk) { + const long base = (((long)h * B + b) * Sk + kj) * D; for (int d = 0; d < D; ++d) { - dK[idx_qkv(b, kj, hk, d, Sk, Hkv, D)] = from_f(dkW[j * D + d] * scale); - dV[idx_qkv(b, kj, hk, d, Sk, Hkv, D)] = from_f(dvW[j * D + d]); + dKp[base + d] = dkW[j * D + d] * scale; + dVp[base + d] = dvW[j * D + d]; } + } } } +// Sum the per-head dK/dV partials (H, B, Sk, D) over each GQA/MQA head group +// into the final (B, Sk, Hkv, D) gradients, casting to the I/O dtype. One thread +// per output element; the group sum runs in a fixed order so it is deterministic. 
+template +__global__ void dkv_reduce_kernel( + const float* __restrict__ dKp, const float* __restrict__ dVp, + scalar_t* __restrict__ dK, scalar_t* __restrict__ dV, + int B, int Sk, int H, int Hkv, int D) { + const long total = (long)B * Sk * Hkv * D; + const long idx = (long)blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + const int d = idx % D; + long t = idx / D; + const int hk = t % Hkv; + t /= Hkv; + const long kj = t % Sk; + const int b = t / Sk; + const int group = H / Hkv; + float ak = 0.f, av = 0.f; +#pragma unroll 1 + for (int g = 0; g < group; ++g) { + const int h = hk * group + g; + const long o = (((long)h * B + b) * Sk + kj) * D + d; + ak += dKp[o]; + av += dVp[o]; + } + dK[idx_qkv(b, kj, hk, d, Sk, Hkv, D)] = from_f(ak); + dV[idx_qkv(b, kj, hk, d, Sk, Hkv, D)] = from_f(av); +} + // ---- host launchers ------------------------------------------------------ // Largest nw in [1, nw_max] whose kernel shared memory fits the device opt-in cap. @@ -1159,7 +1290,11 @@ std::vector flash_bwd( } else { const int nw = pick_nw(bwd_dkv_smem, D, 2, max_smem); const size_t smem = bwd_dkv_smem(D, nw); - dim3 grid((Sk + nw * BR_W - 1) / (nw * BR_W), Hkv, B); + // Per-head fp32 partials (H, B, Sk, D); each (h,b,key,d) written by exactly + // one block (deterministic), then summed over the head group below. 
+ auto dKp = at::empty({H, B, Sk, D}, q.options().dtype(at::kFloat)); + auto dVp = at::empty({H, B, Sk, D}, q.options().dtype(at::kFloat)); + dim3 grid((Sk + nw * BR_W - 1) / (nw * BR_W), H, B); // parallel over query heads dim3 block(nw * WARP); DISPATCH_FLOAT(q.scalar_type(), "flash_bwd_dkv_wmma", [&] { if constexpr (!std::is_same_v) { @@ -1169,10 +1304,22 @@ std::vector flash_bwd( q.data_ptr(), k.data_ptr(), v.data_ptr(), dO.data_ptr(), L.data_ptr(), delta.data_ptr(), q_blk.data_ptr(), k_blk.data_ptr(), q_valid.data_ptr(), - k_valid.data_ptr(), dK.data_ptr(), dV.data_ptr(), + k_valid.data_ptr(), dKp.data_ptr(), dVp.data_ptr(), B, Sq, Sk, H, Hkv, D, (float)scale, nw); } }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Reduce per-head partials -> (B, Sk, Hkv, D) dK/dV (deterministic group sum). + const long total = (long)B * Sk * Hkv * D; + const int threads = 256; + const unsigned blocks = (unsigned)((total + threads - 1) / threads); + DISPATCH_FLOAT(q.scalar_type(), "dkv_reduce", [&] { + if constexpr (!std::is_same_v) { + dkv_reduce_kernel<<>>( + dKp.data_ptr(), dVp.data_ptr(), dK.data_ptr(), + dV.data_ptr(), B, Sk, H, Hkv, D); + } + }); } C10_CUDA_KERNEL_LAUNCH_CHECK(); diff --git a/src/opentau/scripts/benchmark_flash_attn.py b/src/opentau/scripts/benchmark_flash_attn.py index e1e622e..806bcfc 100644 --- a/src/opentau/scripts/benchmark_flash_attn.py +++ b/src/opentau/scripts/benchmark_flash_attn.py @@ -195,14 +195,15 @@ def oom_demo(): def stacked_memory_comparison(b=2, s=1024, h=8, hkv=1, d=256, layers=18): - """Fair full-model-style memory comparison: a stack of attention layers run - fwd+bwd, with and without ``torch.utils.checkpoint``, for eager vs flash. - - This is the decision-relevant memory comparison (the per-op tables above - compare a single attention call with no checkpointing). 
Realistic - pi07_paligemma training uses ``gradient_checkpointing=True``, which already - frees eager's per-layer (B,H,S,S) scores by recomputing them in backward; - so the honest apples-to-apples is "eager+ckpt vs flash+ckpt". + """Fair full-model-style comparison: a stack of attention layers run fwd+bwd, + with and without ``torch.utils.checkpoint``, for eager / sdpa / flash. + + This is the decision-relevant comparison (the per-op tables above compare a + single attention call with no checkpointing). Realistic pi07_paligemma + training uses ``gradient_checkpointing=True``, which already frees eager's + per-layer (B,H,S,S) scores by recomputing them in backward; so the honest + apples-to-apples for training is the "+ckpt" rows. Reports both peak memory + and fwd+bwd wall-clock latency for each backend. """ import torch.utils.checkpoint as cp @@ -218,30 +219,50 @@ def eager_layer(x): aw = torch.where(dense[:, None], torch.matmul(q, k.transpose(-1, -2)) * scale, BIG_NEG) return torch.matmul(torch.softmax(aw, -1), v).permute(0, 2, 1, 3).to(x.dtype) + def sdpa_layer(x): + g = h // hkv + q = x.permute(0, 2, 1, 3) + k = x[:, :, :hkv].permute(0, 2, 1, 3).repeat_interleave(g, 1) + v = x[:, :, :hkv].permute(0, 2, 1, 3).repeat_interleave(g, 1) + o = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=dense[:, None], scale=scale) + return o.permute(0, 2, 1, 3) + def flash_layer(x): return flash_attn_blockmask( x, x[:, :, :hkv].contiguous(), x[:, :, :hkv].contiguous(), q_blk, k_blk, q_valid, k_valid, scale ) - def measure(layer, ckpt): + def fwd_bwd(layer, ckpt): + x = torch.randn(b, s, h, d, device=device, dtype=torch.bfloat16, requires_grad=True) + for _ in range(layers): + x = (cp.checkpoint(layer, x, use_reentrant=False) if ckpt else layer(x)) + 1.0 + x.sum().backward() + + def measure(layer, ckpt, iters=10): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() try: - x = torch.randn(b, s, h, d, device=device, dtype=torch.bfloat16, 
requires_grad=True) - for _ in range(layers): - x = (cp.checkpoint(layer, x, use_reentrant=False) if ckpt else layer(x)) + 1.0 - x.sum().backward() + for _ in range(2): # warmup + fwd_bwd(layer, ckpt) torch.cuda.synchronize() - return f"{torch.cuda.max_memory_allocated() / 1e9:.2f} GB" + mem = torch.cuda.max_memory_allocated() / 1e9 + s0, e0 = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + s0.record() + for _ in range(iters): + fwd_bwd(layer, ckpt) + e0.record() + torch.cuda.synchronize() + return f"{s0.elapsed_time(e0) / iters:7.2f} ms / {mem:5.2f} GB" except torch.cuda.OutOfMemoryError: torch.cuda.empty_cache() - return "OOM" + return " OOM " - print(f"{layers} stacked attention layers, B{b} S{s} H{h} Hkv{hkv} D{d} bf16, fwd+bwd peak:") - print(f" eager, no-ckpt : {measure(eager_layer, False)}") - print(f" eager, +ckpt : {measure(eager_layer, True)} <- realistic training baseline") - print(f" flash, no-ckpt : {measure(flash_layer, False)}") - print(f" flash, +ckpt : {measure(flash_layer, True)}") + print( + f"{layers} stacked attention layers, B{b} S{s} H{h} Hkv{hkv} D{d} bf16, fwd+bwd (latency / peak mem):" + ) + print(f" {'backend':6s} {'no-ckpt':>22s} {'+ckpt (realistic training)':>26s}") + for name, layer in [("eager", eager_layer), ("sdpa", sdpa_layer), ("flash", flash_layer)]: + print(f" {name:6s} {measure(layer, False):>22s} {measure(layer, True):>26s}") def sdpa_backend_probe(b=1, s=4096, h=8, hkv=1, d=256):