From 58535e26fe1953fd3bb343fa5b38bcd39f515d82 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Thu, 7 May 2026 10:51:37 +0800 Subject: [PATCH 01/14] feat(kda/sm100): support GVA (HV > HQK) in fwd intra/recomp kernels Follow the GVA pattern used in the SM90 KDA (and in gated_delta_rule GVA) so that the SM100 KDA forward pass can handle num_v_heads > num_qk_heads. C++ changes: - tile_scheduler: Params now carries heads_per_group; decode_tile_coord enumerates tiles in v-head space and returns both v_head_idx and qk_head_idx (= v_head_idx / heads_per_group). When HV == HQK this degenerates to the previous behaviour. - kda_config: KDA_fwd_intra_params / KDA_fwd_recomp_w_u_params split h into h_qk and h_v and cache heads_per_group; Akk and w/u/kg/qg layouts now live in v-head space. - intra kernel/mainloop: Q/K TMA descriptors use shape_QK (total, d, h_qk); g TMA uses shape_VG (total, d, h_v). Load warp slices Q/K with qk_head_idx and g with v_head_idx; Aqk row stride and beta stride now use params.h_v. - recomp_w_u kernel/mainloop: K/Q TMA descriptors use shape_QK; V/g TMA use shape_VG; Akk TMA uses shape_Akk (total, BT, h_v). Load warp slices K/Q with qk_head_idx and V/g/Akk with v_head_idx; w/u/kg/qg write stride and beta stride now use params.h_v. API / Python: - kda_sm100.cu: derive h_qk from Q/K and h_v from V/g; validate HV % HQK == 0 and beta/qg_out shapes. - cula/kda/chunk_intra.py: infer HQK from k.shape[2] and HV from v.shape[2]; allocate Aqk, Akk, w, kg, qg in v-head space; add shape assertions. Backward compatible: when HV == HQK, heads_per_group == 1 and qk_head_idx == v_head_idx, and all shapes/strides reduce to the pre-GVA layout. --- csrc/api/kda_sm100.cu | 88 +++++++++++++++++-- csrc/kda/sm100/kda_config.hpp | 50 ++++++----- csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp | 29 ++++-- .../sm100/kda_fwd_intra_mainloop_sm100.hpp | 51 +++++++---- .../sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp | 33 ++++--- .../kda_fwd_recomp_w_u_mainloop_sm100.hpp | 57 ++++++++---- csrc/kda/sm100/tile_scheduler.hpp | 29 ++++-- cula/kda/chunk_intra.py | 29 ++++-- 8 files changed, 268 insertions(+), 98 deletions(-) diff --git a/csrc/api/kda_sm100.cu b/csrc/api/kda_sm100.cu index ac32411..ff89887 100644 --- a/csrc/api/kda_sm100.cu +++ b/csrc/api/kda_sm100.cu @@ -37,7 +37,33 @@ ChunkKDAFwdIntra( KDA_fwd_intra_params params; params.total_q_len = q.size(0) * q.size(1); params.b = cu_seqlens.size(0) - 1; - params.h = q.size(2); + // GVA: Q/K are in h_qk head space (from q.size(2)); g/beta/Aqk/Akk are in h_v head + // space (from g.size(2)). When HV == HQK, heads_per_group == 1 and behaviour matches + // the pre-GVA path. + params.h_qk = q.size(2); + params.h_v = g.size(2); + TORCH_CHECK( + k.size(2) == params.h_qk, + "ChunkKDAFwdIntra: k.size(2) (", + k.size(2), + ") must match q.size(2) (", + params.h_qk, + ") under GVA (Q/K share h_qk)."); + TORCH_CHECK( + beta.size(-1) == params.h_v, + "ChunkKDAFwdIntra: beta.size(-1) (", + beta.size(-1), + ") must equal h_v (", + params.h_v, + ")."); + TORCH_CHECK( + params.h_qk > 0 && params.h_v > 0 && params.h_v % params.h_qk == 0, + "ChunkKDAFwdIntra: h_v (", + params.h_v, + ") must be a positive multiple of h_qk (", + params.h_qk, + ")."); + params.heads_per_group = params.h_v / params.h_qk; params.d = q.size(3); params.chunk_size = chunk_size; params.scale = scale; @@ -56,13 +82,15 @@ ChunkKDAFwdIntra( params.chunk_indices_ptr = chunk_indices.data_ptr(); params.Aqk_out_ptr = Aqk_out.data_ptr(); params.Akk_out_ptr = Akk_out.data_ptr(); - params.shape_Akk = cute::make_shape(params.total_q_len, params.chunk_size, params.h); - params.stride_Akk = cute::make_stride(params.chunk_size * params.h, cute::_1{}, params.chunk_size); + // Akk is laid out per v-head: (total_len, chunk_size, h_v). + params.shape_Akk = cute::make_shape(params.total_q_len, params.chunk_size, params.h_v); + params.stride_Akk = cute::make_stride(params.chunk_size * params.h_v, cute::_1{}, params.chunk_size); int tile_num = chunk_indices.size(0); auto device_prop = at::cuda::getCurrentDeviceProperties(); params.num_sm = device_prop->multiProcessorCount; - params.tile_scheduler_params = - StaticPersistentTileScheduler::Params{tile_num, params.h, params.num_sm, (int*)tile_counter.data_ptr()}; + // Tiles are enumerated in v-head space. + params.tile_scheduler_params = StaticPersistentTileScheduler::Params{ + tile_num, params.h_v, params.heads_per_group, params.num_sm, (int*)tile_counter.data_ptr()}; kda::sm100::run_kda_fwd_intra_sm100(params, at::cuda::getCurrentCUDAStream()); } @@ -85,7 +113,31 @@ ChunkKDAFwdRecompWU( KDA_fwd_recomp_w_u_params params; params.total_len = k.size(0) * k.size(1); params.b = cu_seqlens.size(0) - 1; - params.h = k.size(2); + // GVA: K (and optional Q) live in h_qk space; V/G/beta/A/w/u/kg/qg live in h_v space. + params.h_qk = k.size(2); + params.h_v = v.size(2); + TORCH_CHECK( + g.size(2) == params.h_v, + "ChunkKDAFwdRecompWU: g.size(2) (", + g.size(2), + ") must equal v.size(2) (", + params.h_v, + ")."); + TORCH_CHECK( + beta.size(-1) == params.h_v, + "ChunkKDAFwdRecompWU: beta.size(-1) (", + beta.size(-1), + ") must equal h_v (", + params.h_v, + ")."); + TORCH_CHECK( + params.h_qk > 0 && params.h_v > 0 && params.h_v % params.h_qk == 0, + "ChunkKDAFwdRecompWU: h_v (", + params.h_v, + ") must be a positive multiple of h_qk (", + params.h_qk, + ")."); + params.heads_per_group = params.h_v / params.h_qk; params.d = k.size(3); params.chunk_size = chunk_size; TORCH_CHECK( @@ -108,14 +160,32 @@ ChunkKDAFwdRecompWU( TORCH_CHECK( has_q == has_qg_out, "ChunkKDAFwdRecompWU: q and qg_out must either both be provided or both be omitted."); params.store_qg = has_q && has_qg_out; + if (params.store_qg) { + TORCH_CHECK( + q->size(2) == params.h_qk, + "ChunkKDAFwdRecompWU: q.size(2) (", + q->size(2), + ") must equal h_qk (", + params.h_qk, + ")."); + TORCH_CHECK( + qg_out->size(2) == params.h_v, + "ChunkKDAFwdRecompWU: qg_out.size(2) (", + qg_out->size(2), + ") must equal h_v (", + params.h_v, + ")."); + } params.q_ptr = params.store_qg ? q->data_ptr() : nullptr; params.qg_out_ptr = params.store_qg ? qg_out->data_ptr() : nullptr; - params.shape_wukg = cute::make_shape(params.total_len, params.d, params.h); - params.stride_wukg = cute::make_stride(params.d * params.h, cute::_1{}, params.d); + // w/u/kg/qg are per v-head: (total_len, d, h_v). + params.shape_wukg = cute::make_shape(params.total_len, params.d, params.h_v); + params.stride_wukg = cute::make_stride(params.d * params.h_v, cute::_1{}, params.d); int tile_num = chunk_indices.size(0); auto device_prop = at::cuda::getCurrentDeviceProperties(); params.num_sm = device_prop->multiProcessorCount; - params.tile_scheduler_params = StaticPersistentTileScheduler::Params{tile_num, params.h, params.num_sm, nullptr}; + params.tile_scheduler_params = StaticPersistentTileScheduler::Params{ + tile_num, params.h_v, params.heads_per_group, params.num_sm, nullptr}; kda::sm100::run_kda_fwd_recomp_w_u_sm100(params, at::cuda::getCurrentCUDAStream()); } \ No newline at end of file diff --git a/csrc/kda/sm100/kda_config.hpp b/csrc/kda/sm100/kda_config.hpp index 6f96529..67b496a 100644 --- a/csrc/kda/sm100/kda_config.hpp +++ b/csrc/kda/sm100/kda_config.hpp @@ -17,12 +17,18 @@ #include "kda/sm100/tile_scheduler.hpp" struct KDA_fwd_intra_params { - using GmemShapeAkk = cute::Shape; // (seqlen_kv, seqlen_kv, h) + // Akk shape is (total_seqlen, chunk_size, num_v_heads). Under GVA (num_v_heads > num_qk_heads), + // Aqk and Akk are produced per v-head because g/beta/Akk scaling all live in v-head space. + using GmemShapeAkk = cute::Shape; // (seqlen_kv, chunk_size, h_v) using GmemStrideAkk = cute::Stride; int total_q_len; int b; - int h; + // GVA: Q/K are sized by num_qk_heads; V, g, beta are sized by num_v_heads; Aqk/Akk are per v-head. + // When num_v_heads == num_qk_heads, heads_per_group == 1 and behaviour matches the pre-GVA path. + int h_qk; + int h_v; + int heads_per_group; // = h_v / h_qk, precomputed on host int d; int chunk_size; float scale; @@ -30,12 +36,12 @@ struct KDA_fwd_intra_params { bool unified_gref; bool is_beta_bf16; - void* __restrict__ q_ptr; //[b, t, h, d] - void* __restrict__ k_ptr; //[b, t, h, d] - void* __restrict__ g_ptr; //[b, t, h, d] - void* __restrict__ beta_ptr; //[b, t, h] - void* __restrict__ Aqk_out_ptr; //[b, t, h, BT] - void* __restrict__ Akk_out_ptr; //[b, t, h, BT] + void* __restrict__ q_ptr; //[b, t, h_qk, d] + void* __restrict__ k_ptr; //[b, t, h_qk, d] + void* __restrict__ g_ptr; //[b, t, h_v, d] + void* __restrict__ beta_ptr; //[b, t, h_v] + void* __restrict__ Aqk_out_ptr; //[b, t, h_v, BT] + void* __restrict__ Akk_out_ptr; //[b, t, h_v, BT] void* __restrict__ cu_seqlens_ptr; //[b + 1] void* __restrict__ chunk_indices_ptr; //[(b * t) / chunk_size, 2] @@ -48,28 +54,32 @@ struct KDA_fwd_intra_params { }; struct KDA_fwd_recomp_w_u_params { - using GmemShapeWUKg = cute::Shape; // (seqlen_kv, seqlen_kv, h) + // w/u/kg/qg all have shape (total_seqlen, d, num_v_heads) under GVA. + using GmemShapeWUKg = cute::Shape; // (seqlen_kv, d, h_v) using GmemStrideWUKg = cute::Stride; int total_len; int b; - int h; + // GVA: K and (optional) Q are sized by num_qk_heads; V/G/beta/Akk/w/u/kg/qg are per v-head. + int h_qk; + int h_v; + int heads_per_group; // = h_v / h_qk, precomputed on host int d; int chunk_size; bool is_beta_bf16; - void* __restrict__ k_ptr; //[b, t, h, d] - void* __restrict__ v_ptr; //[b, t, h, d] - void* __restrict__ q_ptr; //[b, t, h, d] (optional, for StoreQG) - void* __restrict__ beta_ptr; //[b, t, h] - void* __restrict__ A_ptr; //[b. t, h, BT] - void* __restrict__ g_ptr; //[b, t, h, d] + void* __restrict__ k_ptr; //[b, t, h_qk, d] + void* __restrict__ v_ptr; //[b, t, h_v, d] + void* __restrict__ q_ptr; //[b, t, h_qk, d] (optional, for StoreQG) + void* __restrict__ beta_ptr; //[b, t, h_v] + void* __restrict__ A_ptr; //[b, t, h_v, BT] + void* __restrict__ g_ptr; //[b, t, h_v, d] void* __restrict__ cu_seqlens_ptr; //[b + 1] void* __restrict__ chunk_indices_ptr; //[(b * t) / chunk_size, 2] - void* __restrict__ w_out_ptr; //[b, t, h, d] - void* __restrict__ u_out_ptr; //[b, t, h, d] - void* __restrict__ kg_out_ptr; //[b, t, h, d] - void* __restrict__ qg_out_ptr; //[b, t, h, d] (optional, for StoreQG) + void* __restrict__ w_out_ptr; //[b, t, h_v, d] + void* __restrict__ u_out_ptr; //[b, t, h_v, d] + void* __restrict__ kg_out_ptr; //[b, t, h_v, d] + void* __restrict__ qg_out_ptr; //[b, t, h_v, d] (optional, for StoreQG) bool store_qg; diff --git a/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp b/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp index 60dc4b3..021bec8 100644 --- a/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp @@ -53,8 +53,8 @@ struct KdaChunkFwdIntraKernelSm100 { using SmemLayoutInputFP32 = typename Mainloop::SmemLayoutInputFP32; // TMA params (for host launcher) - template - using TmaParams = typename Mainloop::template TmaParams; + template + using TmaParams = typename Mainloop::template TmaParams; // Pipeline types (for construction in operator()) using PipelineQKG = typename Mainloop::PipelineQKG; @@ -321,29 +321,40 @@ __launch_bounds__(512, 1, 1) kda_fwd_intra_sm100_kernel_entry( template inline void run_kda_fwd_intra_sm100_impl_dispatch(KDA_fwd_intra_params& params, cudaStream_t stream) { - auto shape_QKG = make_shape(params.total_q_len, params.d, params.h); - auto stride_QKG = make_stride(params.h * params.d, _1{}, params.d); + // GVA: Q/K are sized by `h_qk`; G is sized by `h_v`. When HV == HQK + // (heads_per_group == 1), shape_QK and shape_VG coincide with the + // pre-GVA shape_QKG and behaviour is unchanged. + auto shape_QK = make_shape(params.total_q_len, params.d, params.h_qk); + auto stride_QK = make_stride(params.h_qk * params.d, _1{}, params.d); + auto shape_VG = make_shape(params.total_q_len, params.d, params.h_v); + auto stride_VG = make_stride(params.h_v * params.d, _1{}, params.d); // --- Build TMA descriptors --- auto tma_Q = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((ku::bf16*)params.q_ptr), make_layout(shape_QKG, stride_QKG)), + make_tensor(make_gmem_ptr((ku::bf16*)params.q_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); auto tma_K = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((ku::bf16*)params.k_ptr), make_layout(shape_QKG, stride_QKG)), + make_tensor(make_gmem_ptr((ku::bf16*)params.k_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); auto tma_G = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_QKG, stride_QKG)), + make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_VG, stride_VG)), typename Kernel::SmemLayoutInputFP32{}); // --- Pack TMA params --- - typename Kernel::template TmaParams + typename Kernel::template TmaParams< + decltype(shape_QK), + decltype(shape_VG), + decltype(tma_Q), + decltype(tma_K), + decltype(tma_G)> tma_params = { - shape_QKG, + shape_QK, + shape_VG, tma_Q, tma_K, tma_G, diff --git a/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp b/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp index 3aa2746..68f6baa 100644 --- a/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp @@ -227,9 +227,13 @@ struct KdaChunkFwdIntraMainloopSm100 { }; // ===================== TMA Params ===================== - template + // GVA: Q/K live in h_qk head space (shape_qk), while G lives in h_v + // head space (shape_vg). When h_v == h_qk both shapes coincide and the + // TMA descriptors degrade to the pre-GVA behaviour. + template struct TmaParams { - ShapeQKG shape_qkg; + ShapeQK shape_qk; + ShapeVG shape_vg; TMA_Q tma_q; TMA_K tma_k; TMA_G tma_g; @@ -318,7 +322,10 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // head_idx here is the v-head index (Aqk/Akk/beta/g live in v-head space). + // qk_head_idx is only consumed by the TMA load warp for Q/K slicing. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -502,7 +509,8 @@ struct KdaChunkFwdIntraMainloopSm100 { int token_offset = cu_seqlens_ptr[batch_idx]; int row = idx_in_warpgroup % 64; int BT = TileT; - int H = params.h; + // Aqk is laid out per v-head: row-stride is h_v * BT, head slot offset is head_idx * BT. + int H = params.h_v; __nv_bfloat16* Aqk_base = reinterpret_cast<__nv_bfloat16*>(params.Aqk_out_ptr); __nv_bfloat16* qk_out_row = Aqk_base + static_cast(token_offset + tile_idx * TileT + row) * H * BT + head_idx * BT; @@ -568,7 +576,10 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // MMA loop does not actually consume head_idx, but we decode to advance the + // same tile space as the other warps (num_blocks * num_v_heads). + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -703,21 +714,24 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - // Decode tile coordinates - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Decode tile coordinates. head_idx is the v-head index (used for G), + // and qk_head_idx is the companion Q/K head (computed from heads_per_group). + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); - int head_idx = get<1>(blk_coord); + int head_idx = get<1>(blk_coord); // v-head index int tile_idx = get<2>(blk_coord); + int qk_head_idx = get<3>(blk_coord); // == head_idx / heads_per_group int token_offset = cu_seqlens_ptr[batch_idx]; int seq_len = cu_seqlens_ptr[batch_idx + 1] - cu_seqlens_ptr[batch_idx]; int sub_seq_len = min(TileT, seq_len - tile_idx * TileT); Tensor mQ = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_q.get_tma_tensor(tma_params.shape_qkg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_q.get_tma_tensor(tma_params.shape_qk)); Tensor mK = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_qkg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_qk)); Tensor mG = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_qkg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_vg)); // TMA load body (Q, K, G — unified pipeline, single barrier per stage) CUTE_NO_UNROLL @@ -727,12 +741,13 @@ struct KdaChunkFwdIntraMainloopSm100 { Tensor sK = make_tensor(make_smem_ptr(shared_plan->k[buf_idx].data()), SmemLayoutInputBF16{}); Tensor sG = make_tensor(make_smem_ptr(shared_plan->g[buf_idx].data()), SmemLayoutInputFP32{}); + // GVA: K and Q are sliced by qk_head_idx; G is sliced by head_idx (v-head). Tensor gK = local_tile( - mK(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); + mK(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); Tensor gG = local_tile( mG(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); Tensor gQ = local_tile( - mQ(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); + mQ(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); // Single acquire for all three TMA copies qkg_load_pipeline.producer_acquire(qkg_load_pipe_state_write); @@ -769,7 +784,9 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Akk is laid out per v-head (params.shape_Akk uses h_v), so we index by head_idx. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -882,7 +899,9 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // beta is per v-head: layout (total_seqlen, h_v), row stride = h_v. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -896,7 +915,7 @@ struct KdaChunkFwdIntraMainloopSm100 { shared_plan->beta_smem[beta_pipe_state_write.index()][thread_idx] = (thread_idx < sub_seq_len) ? float(reinterpret_cast( - params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]) + params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h_v + head_idx]) : float(0); } fence_view_async_shared(); diff --git a/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp b/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp index 6bfa78c..a0a0318 100644 --- a/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp @@ -41,14 +41,16 @@ struct KdaChunkFwdRecompWUKernelSm100 { // TMA params (for host launcher) template < - typename ShapeKVG, + typename ShapeQK, + typename ShapeVG, typename ShapeAkk, typename TMA_V, typename TMA_K, typename TMA_G, typename TMA_Akk, typename TMA_Q = int> - using TmaParams = typename Mainloop::template TmaParams; + using TmaParams = + typename Mainloop::template TmaParams; // Pipeline types (for construction in operator()) using PipelineA = typename Mainloop::PipelineA; @@ -429,25 +431,29 @@ __launch_bounds__(384, 1, 1) kda_fwd_recomp_w_u_sm100_kernel_entry( template inline void run_kda_fwd_recomp_w_u_sm100_impl_dispatch(KDA_fwd_recomp_w_u_params& params, cudaStream_t stream) { - auto shape_KVG = make_shape(params.total_len, params.d, params.h); - auto stride_KVG = make_stride(params.h * params.d, _1{}, params.d); - auto shape_Akk = make_shape(params.total_len, params.chunk_size, params.h); - auto stride_Akk = make_stride(params.h * params.chunk_size, _1{}, params.chunk_size); + // GVA: K and (optional) Q are sized by h_qk; V and G are sized by h_v. + // Akk lives in v-head space (BT x BT per v-head). + auto shape_QK = make_shape(params.total_len, params.d, params.h_qk); + auto stride_QK = make_stride(params.h_qk * params.d, _1{}, params.d); + auto shape_VG = make_shape(params.total_len, params.d, params.h_v); + auto stride_VG = make_stride(params.h_v * params.d, _1{}, params.d); + auto shape_Akk = make_shape(params.total_len, params.chunk_size, params.h_v); + auto stride_Akk = make_stride(params.h_v * params.chunk_size, _1{}, params.chunk_size); // --- Build TMA descriptors --- auto tma_V = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((bf16*)params.v_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((bf16*)params.v_ptr), make_layout(shape_VG, stride_VG)), typename Kernel::SmemLayoutInputBF16{}); auto tma_K = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((bf16*)params.k_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((bf16*)params.k_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); auto tma_G = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_VG, stride_VG)), typename Kernel::SmemLayoutInputFP32{}); auto tma_Akk = cute::make_tma_copy( @@ -455,12 +461,12 @@ run_kda_fwd_recomp_w_u_sm100_impl_dispatch(KDA_fwd_recomp_w_u_params& params, cu make_tensor(make_gmem_ptr((bf16*)params.A_ptr), make_layout(shape_Akk, stride_Akk)), typename Kernel::SmemLayoutInputAkkBF16{}); - // Q TMA descriptor (only meaningful when StoreQG=true) + // Q TMA descriptor (only meaningful when StoreQG=true). Q lives in h_qk head space. auto tma_Q = [&]() { if constexpr (Kernel::StoreQG) { return cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((bf16*)params.q_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((bf16*)params.q_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); } else { return 0; // placeholder, not used @@ -469,14 +475,15 @@ run_kda_fwd_recomp_w_u_sm100_impl_dispatch(KDA_fwd_recomp_w_u_params& params, cu // --- Pack TMA params --- typename Kernel::template TmaParams< - decltype(shape_KVG), + decltype(shape_QK), + decltype(shape_VG), decltype(shape_Akk), decltype(tma_V), decltype(tma_K), decltype(tma_G), decltype(tma_Akk), decltype(tma_Q)> - tma_params = {shape_KVG, shape_Akk, tma_V, tma_K, tma_G, tma_Akk, tma_Q}; + tma_params = {shape_QK, shape_VG, shape_Akk, tma_V, tma_K, tma_G, tma_Akk, tma_Q}; // --- Launch config --- auto kernel_fn = &kda_fwd_recomp_w_u_sm100_kernel_entry; diff --git a/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp b/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp index 718e075..07bcf26 100644 --- a/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp @@ -187,8 +187,11 @@ struct KdaChunkFwdRecompWUMainloopSm100 { }; // ===================== TMA Params ===================== + // GVA: K and (optional) Q live in h_qk head space (shape_qk), while V + // and G live in h_v head space (shape_vg). Akk is per v-head. template < - typename ShapeKVG, + typename ShapeQK, + typename ShapeVG, typename ShapeAkk, typename TMA_V, typename TMA_K, @@ -196,7 +199,8 @@ struct KdaChunkFwdRecompWUMainloopSm100 { typename TMA_Akk, typename TMA_Q = int> struct TmaParams { - ShapeKVG shape_kvg; + ShapeQK shape_qk; + ShapeVG shape_vg; ShapeAkk shape_Akk; TMA_V tma_v; TMA_K tma_k; @@ -252,7 +256,10 @@ struct KdaChunkFwdRecompWUMainloopSm100 { CUTE_NO_UNROLL for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Prologue touches K (h_qk) and G (h_v) + beta (h_v) + optional Q (h_qk). + // head_idx is the v-head index; qk_head_idx is derived via heads_per_group. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -632,7 +639,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { CUTE_NO_UNROLL for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Epilogue consumes V/beta (both h_v) and writes w/u/kg/qg (all h_v). + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -732,9 +741,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { // each thread processes one row of W/U (TileK columns) int row = (idx_in_wg / 32) * 16 + (idx_in_wg % 16); - // GMEM output address: layout [total_len, d, h], stride [d*h, 1, d] + // GMEM output address: layout [total_len, d, h_v], stride [d*h_v, 1, d] __nv_bfloat16* out_row_base = - out_ptr_base + (token_offset_cur + row) * params.d * params.h + head_idx * params.d; + out_ptr_base + (token_offset_cur + row) * params.d * params.h_v + head_idx * params.d; constexpr int QuarK = TileK / 4; @@ -796,7 +805,8 @@ struct KdaChunkFwdRecompWUMainloopSm100 { CUTE_NO_UNROLL for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { // int tid = tile_scheduler.get_current_tile_id(); - // auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h_v, params.heads_per_group, + // chunk_indices_ptr, cu_seqlens_ptr); // ============================================================ // Once per WU: Wait for Akk in SMEM (from Load warp) @@ -876,31 +886,36 @@ struct KdaChunkFwdRecompWUMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - // Decode tile coordinates - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Decode tile coordinates. head_idx is the v-head (used for V/G/Akk + // TMA loads); qk_head_idx (= head_idx / heads_per_group) is used for + // K/Q TMA loads under GVA. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); - int head_idx = get<1>(blk_coord); + int head_idx = get<1>(blk_coord); // v-head int tile_idx = get<2>(blk_coord); + int qk_head_idx = get<3>(blk_coord); // qk-head int token_offset = cu_seqlens_ptr[batch_idx]; int seq_len = cu_seqlens_ptr[batch_idx + 1] - cu_seqlens_ptr[batch_idx]; int sub_seq_len = min(TileT, seq_len - tile_idx * TileT); // Build GMEM tensor views (with domain offset for batch) + // K and Q live in h_qk head space (shape_qk); V, G and Akk live in h_v space. Tensor mK = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_kvg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_qk)); Tensor mV = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_v.get_tma_tensor(tma_params.shape_kvg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_v.get_tma_tensor(tma_params.shape_vg)); Tensor mG = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_kvg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_vg)); Tensor mA = domain_offset( make_coord(token_offset, _0{}, _0{}), tma_params.tma_akk.get_tma_tensor(tma_params.shape_Akk)); - // Q GMEM tensor (only used when StoreQG=true) + // Q GMEM tensor (only used when StoreQG=true). Q lives in h_qk space. [[maybe_unused]] auto mQ = [&]() { if constexpr (StoreQG) { return domain_offset( make_coord(token_offset, _0{}, _0{}), - tma_params.tma_q.get_tma_tensor(tma_params.shape_kvg)); + tma_params.tma_q.get_tma_tensor(tma_params.shape_qk)); } else { return 0; // unused placeholder } @@ -933,8 +948,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { Tensor sG = make_tensor( make_smem_ptr(shared_plan->g[g_pipe_state_write.index()].data()), SmemLayoutInputFP32{}); + // GVA slicing: K uses qk_head_idx; V and G use the v-head index. Tensor gK = local_tile( - mK(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); + mK(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); Tensor gV = local_tile( mV(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); Tensor gG = local_tile( @@ -960,8 +976,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { Tensor sQ = make_tensor( make_smem_ptr(shared_plan->q_buf.q[q_pipe_state_write.index()].data()), SmemLayoutInputBF16{}); + // Q (StoreQG) lives in h_qk space → slice by qk_head_idx. Tensor gQ = local_tile( - mQ(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); + mQ(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); q_pipeline.producer_acquire(q_pipe_state_write); ku::launch_tma_copy( tma_params.tma_q, gQ, sQ, *q_pipeline.producer_get_barrier(q_pipe_state_write)); @@ -994,7 +1011,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // LoadAux: beta is per v-head (row stride = h_v). + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -1010,7 +1029,7 @@ struct KdaChunkFwdRecompWUMainloopSm100 { float beta_val = (thread_idx < sub_seq_len) ? float(reinterpret_cast( - params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]) + params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h_v + head_idx]) : float(0); shared_plan->beta_smem[beta_pipe_state_write.index()][thread_idx] = beta_val; } diff --git a/csrc/kda/sm100/tile_scheduler.hpp b/csrc/kda/sm100/tile_scheduler.hpp index 47044aa..695bb26 100644 --- a/csrc/kda/sm100/tile_scheduler.hpp +++ b/csrc/kda/sm100/tile_scheduler.hpp @@ -26,11 +26,20 @@ // No smem synchronization needed — every CTA processes tiles starting // at blockIdx.x and striding by gridDim.x. All warps within a CTA // independently maintain the same tile_id, so no tile pipeline is needed. +// +// GVA (Grouped V-head Attention) support: +// Q/K are sized by `num_qk_heads`; V, g, beta, O and state tensors are +// sized by `num_v_heads`. We enumerate tiles by `num_v_heads` so that +// each v-head is scheduled independently, and derive the companion +// `qk_head_idx = v_head_idx / heads_per_group` on the device side. +// `heads_per_group = num_v_heads / num_qk_heads` is precomputed on the +// host to avoid a per-tile integer division. // =================================================================== struct StaticPersistentTileScheduler { struct Params { - int num_blocks; // number of sequence chunks (from chunk_indices) - int num_heads; + int num_blocks; // number of sequence chunks (from chunk_indices) + int num_heads; // == num_v_heads; tiles are enumerated by v-head + int heads_per_group; // == num_v_heads / num_qk_heads, precomputed on host int num_sm; int* tile_counter; // unused @@ -77,14 +86,22 @@ struct StaticPersistentTileScheduler { return current_tile_id < total_tiles(); } + // Decode tile_id -> (batch_idx, v_head_idx, seq_idx, qk_head_idx). + // `num_v_heads` is the number of V/O/g/beta heads; tile enumeration is + // done in v-head space. `heads_per_group` (= num_v_heads/num_qk_heads) + // is used to derive the companion Q/K head index for GVA. + // For backward compatibility, when HV == HQK, `heads_per_group == 1` + // and `qk_head_idx == v_head_idx`. CUTLASS_DEVICE static auto - decode_tile_coord(int tile_id, int num_heads, int* chunk_indices_ptr, int* cu_seqlens_ptr) { + decode_tile_coord( + int tile_id, int num_v_heads, int heads_per_group, int* chunk_indices_ptr, int* /*cu_seqlens_ptr*/) { using namespace cute; - int tile_idx_raw = tile_id / num_heads; - int head_idx = tile_id % num_heads; + int tile_idx_raw = tile_id / num_v_heads; + int v_head_idx = tile_id % num_v_heads; + int qk_head_idx = v_head_idx / heads_per_group; int batch_idx = chunk_indices_ptr[tile_idx_raw * 2]; int seq_idx = chunk_indices_ptr[tile_idx_raw * 2 + 1]; - return make_coord(batch_idx, head_idx, seq_idx, 0); + return make_coord(batch_idx, v_head_idx, seq_idx, qk_head_idx); } }; \ No newline at end of file diff --git a/cula/kda/chunk_intra.py b/cula/kda/chunk_intra.py index 0703638..9fcc93a 100644 --- a/cula/kda/chunk_intra.py +++ b/cula/kda/chunk_intra.py @@ -759,7 +759,22 @@ def chunk_kda_fwd_intra( unified_gref: bool = False, # Set True for ~5% extra perf (slightly lower precision) ): assert safe_gate, "Only safe_gate=True is supported in chunk_kda_fwd_intra for now" - B, T, H, K = k.shape + # GVA support: Q/K have head-dim HQK; V/g/beta/Aqk/Akk/w/u/kg/qg have head-dim HV. + # Pre-GVA behaviour is preserved when HV == HQK. + B, T, HQK, K = k.shape + HV = v.shape[2] + assert v.shape[0] == B and v.shape[1] == T, ( + f"v must share (B, T) with k; got k.shape={k.shape}, v.shape={v.shape}" + ) + assert HV > 0 and HQK > 0 and HV % HQK == 0, ( + f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})" + ) + if gk is not None: + assert gk.shape[0] == B and gk.shape[1] == T and gk.shape[2] == HV, ( + f"gk shape must be (B, T, HV={HV}, K); got {tuple(gk.shape)}" + ) + if beta is not None: + assert beta.shape[-1] == HV, f"beta last dim must equal HV={HV}; got {tuple(beta.shape)}" BT = chunk_size if cu_seqlens is None: @@ -773,18 +788,20 @@ def chunk_kda_fwd_intra( "cu_seqlens and chunk_indices must be int32 for cuda impl" ) - Aqk = torch.empty(B, T, H, BT, device=k.device, dtype=k.dtype) - Akk = torch.empty(B, T, H, BT, device=k.device, dtype=k.dtype) + # Aqk/Akk are produced per v-head (they live in v-head space because g/beta are per v-head). + Aqk = torch.empty(B, T, HV, BT, device=k.device, dtype=k.dtype) + Akk = torch.empty(B, T, HV, BT, device=k.device, dtype=k.dtype) tile_counter = torch.zeros(1, dtype=torch.int32, device=q.device) cula_cuda.chunk_kda_fwd_intra_cuda( q, k, gk, beta, cu_seqlens, chunk_indices, Aqk, Akk, tile_counter, scale, chunk_size, use_tf32_inverse, unified_gref ) - w = torch.empty_like(k) + # w/u/kg/qg are all per-v-head outputs. + w = torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype) u = torch.empty_like(v) - qg = torch.empty_like(q) if disable_recompute else None - kg = torch.empty_like(k) if gk is not None else None + qg = torch.empty(B, T, HV, K, device=q.device, dtype=q.dtype) if disable_recompute else None + kg = torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype) if gk is not None else None cula_cuda.recompute_w_u_cuda( k, v, beta, Akk, gk, cu_seqlens, chunk_indices, w, u, kg, chunk_size, q if disable_recompute else None, qg From 1f2c75b41148ed43c81105e46dd6be2dfb8e9e10 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Thu, 7 May 2026 11:37:17 +0800 Subject: [PATCH 02/14] add sm100 test --- tests/test_kda_gva_intra_sm100.py | 375 ++++++++++++++++++++++++++++++ 1 file changed, 375 insertions(+) create mode 100644 tests/test_kda_gva_intra_sm100.py diff --git a/tests/test_kda_gva_intra_sm100.py b/tests/test_kda_gva_intra_sm100.py new file mode 100644 index 0000000..a86e56f --- /dev/null +++ b/tests/test_kda_gva_intra_sm100.py @@ -0,0 +1,375 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for SM100 KDA GVA (HV > HQK) support in chunk_kda_fwd_intra. + +The SM100 kernels (kda_fwd_intra / kda_fwd_recomp_w_u) now accept: + * q, k with head-dim ``HQK`` + * v, g, beta with head-dim ``HV`` where ``HV = group_size * HQK`` (group_size >= 1) + +This file verifies that the cuLA GVA path produces numerically matching results +compared to the FLA Triton reference, where the FLA reference does not natively +support GVA and therefore receives ``k`` replicated along the head axis to +``HV`` heads. Both uniform-length and varlen layouts are covered, and an +additional degeneracy test asserts that ``HV == HQK`` (group_size == 1) keeps +the non-GVA behaviour untouched. +""" + +from __future__ import annotations + +import pytest +import torch +from fla.ops.kda.chunk_intra import chunk_kda_fwd_intra as fla_chunk_kda_fwd_intra +from fla.ops.kda.gate import kda_gate_chunk_cumsum +from fla.ops.utils.constant import RCP_LN2 +from fla.ops.utils.index import prepare_chunk_indices +from fla.utils import assert_close, device + +from cula.kda.chunk_intra import chunk_kda_fwd_intra as cula_chunk_kda_fwd_intra +from cula.utils import prepare_uniform_cu_seqlens + +pytestmark = pytest.mark.sm100_only + + +# ========================================================================= +# Helpers +# ========================================================================= + +def _l2norm_last(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.normalize(x.float(), p=2, dim=-1).to(x.dtype) + + +def _repeat_head(x: torch.Tensor, group_size: int, head_dim: int = 2) -> torch.Tensor: + """Replicate ``x`` along the head axis by ``group_size``. + + Mirrors GVA's broadcasting semantics: each QK head is paired with + ``group_size`` consecutive V heads, so ``k[..., h_qk, :]`` is used by + ``v[..., h_qk * group_size : (h_qk + 1) * group_size, :]``. + """ + return x.repeat_interleave(group_size, dim=head_dim).contiguous() + + +def _make_gva_inputs( + B: int, + T: int, + HQK: int, + HV: int, + D: int, + chunk_size: int, + cu_seqlens: torch.Tensor | None = None, + dtype: torch.dtype = torch.bfloat16, + seed: int = 42, +): + """Construct inputs for chunk_kda_fwd_intra in GVA layout. + + Returns: + q, k : (B, T, HQK, D) dtype + v : (B, T, HV, D) dtype + g : (B, T, HV, D) float32, after kda_gate_chunk_cumsum + beta : (B, T, HV) float32 in (0, 1) + scale : float + cu_seqlens : (N+1,) int32 or None + chunk_indices: (NT, 2) int32 or None + """ + assert HV % HQK == 0 and HV >= HQK, f"invalid HV/HQK: {HV}/{HQK}" + + torch.manual_seed(seed) + scale = D ** (-0.5) + + # QK are in HQK head space; V / gates / beta live in HV space. + q = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + k = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + v = torch.randn(B, T, HV, D, dtype=dtype, device=device) + g_raw = torch.randn(B, T, HV, D, dtype=dtype, device=device) + beta = torch.randn(B, T, HV, dtype=torch.float, device=device).sigmoid() + + # l2-normalise q/k so that scale/gate ranges match production use. + q = _l2norm_last(q) + k = _l2norm_last(k) + + # Per-HV gate preprocessing (cumsum inside chunks). + A_log = torch.randn(HV, dtype=torch.float, device=device) + dt_bias = torch.randn(HV * D, dtype=torch.float, device=device) + + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + ) + g = kda_gate_chunk_cumsum( + g=g_raw, + A_log=A_log, + dt_bias=dt_bias, + scale=RCP_LN2, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=-5.0, + ) + return q, k, v, g, beta, scale, cu_seqlens, chunk_indices + + +def _run_fla_ref(q, k_hqk, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, group_size, disable_recompute): + """Reference: replicate k along head axis to HV, then call FLA intra. + + FLA's chunk_kda_fwd_intra assumes H == HQK == HV (no GVA), so we construct + the HV-head view of k and q before invoking it. + """ + k_hv = _repeat_head(k_hqk, group_size) + q_hv = _repeat_head(q, group_size) + return fla_chunk_kda_fwd_intra( + q=q_hv, + k=k_hv, + v=v, + gk=g, + beta=beta, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + chunk_indices=chunk_indices, + safe_gate=True, + disable_recompute=disable_recompute, + ) + + +def _run_cula_gva(q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute): + return cula_chunk_kda_fwd_intra( + q=q, + k=k, + v=v, + gk=g, + beta=beta, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + chunk_indices=chunk_indices, + safe_gate=True, + disable_recompute=disable_recompute, + ) + + +# ========================================================================= +# Uniform-length tests +# ========================================================================= + +@pytest.mark.parametrize("disable_recompute", [False, True], ids=["recomp", "no_recomp"]) +@pytest.mark.parametrize( + ("B", "T", "HQK", "group_size", "D"), + [ + pytest.param(*cfg, id="B{}-T{}-HQK{}-gs{}-D{}".format(*cfg)) + for cfg in [ + # group_size == 2: classic GVA 2:1 + (1, 256, 2, 2, 128), + (2, 512, 4, 2, 128), + # group_size == 4: wider grouping + (1, 1024, 2, 4, 128), + (2, 1024, 4, 4, 128), + # Non-multiple-of-BT sequence length to stress boundary handling. + (1, 500, 2, 2, 128), + (1, 1000, 4, 2, 128), + ] + ], +) +def test_gva_intra_uniform(B, T, HQK, group_size, D, disable_recompute): + """cuLA GVA path must match FLA(k-replicated-to-HV) for uniform seqlens.""" + HV = HQK * group_size + chunk_size = 64 + + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = _make_gva_inputs( + B=B, T=T, HQK=HQK, HV=HV, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens, + ) + + # cuLA GVA path (k in HQK head space). + w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute, + ) + + # FLA reference (k replicated to HV). + w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r = _run_fla_ref( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, group_size, disable_recompute, + ) + + # All outputs live in HV head space → shapes must match directly. + assert Aqk_c.shape == Aqk_r.shape, (Aqk_c.shape, Aqk_r.shape) + assert Akk_c.shape == Akk_r.shape, (Akk_c.shape, Akk_r.shape) + assert w_c.shape == w_r.shape, (w_c.shape, w_r.shape) + assert u_c.shape == u_r.shape, (u_c.shape, u_r.shape) + assert kg_c.shape == kg_r.shape, (kg_c.shape, kg_r.shape) + + # Aqk / Akk are the core A-matrices; they drive w/u, so keep tolerances tight. + assert_close("Aqk", Aqk_r, Aqk_c, 0.005) + assert_close("Akk", Akk_r, Akk_c, 0.008) + + # recompute_w_u outputs + assert_close("w", w_r, w_c, 0.008) + assert_close("u", u_r, u_c, 0.008) + assert_close("kg", kg_r, kg_c, 0.005) + + if disable_recompute: + assert qg_c is not None and qg_r is not None + assert qg_c.shape == qg_r.shape, (qg_c.shape, qg_r.shape) + assert_close("qg", qg_r, qg_c, 0.005) + else: + assert qg_c is None, "cuLA must not materialise qg when disable_recompute=False" + + +# ========================================================================= +# Varlen tests +# ========================================================================= + +@pytest.mark.parametrize("disable_recompute", [False, True], ids=["recomp", "no_recomp"]) +@pytest.mark.parametrize( + ("HQK", "group_size", "D", "cu_seqlens"), + [ + pytest.param(*cfg, id="HQK{}-gs{}-D{}-ns{}".format(cfg[0], cfg[1], cfg[2], len(cfg[3]) - 1)) + for cfg in [ + (2, 2, 128, [0, 256, 500, 1000]), + (4, 2, 128, [0, 100, 300, 1200, 2000]), + (2, 4, 128, [0, 15, 100, 300, 1200, 2048]), + # Simulated realistic trace. + ( + 4, 2, 128, + [0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096], + ), + ] + ], +) +def test_gva_intra_varlen(HQK, group_size, D, cu_seqlens, disable_recompute): + """GVA correctness under variable-length (packed) inputs.""" + HV = HQK * group_size + chunk_size = 64 + + cu_seqlens_t = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + T = int(cu_seqlens_t[-1].item()) + # Packed layout uses B=1 and a flat time axis. + q, k, v, g, beta, scale, cu_seqlens_t, chunk_indices = _make_gva_inputs( + B=1, T=T, HQK=HQK, HV=HV, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens_t, + ) + + w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens_t, chunk_indices, chunk_size, disable_recompute, + ) + w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r = _run_fla_ref( + q, k, v, g, beta, scale, cu_seqlens_t, chunk_indices, chunk_size, group_size, disable_recompute, + ) + + assert_close("Aqk", Aqk_r, Aqk_c, 0.005) + assert_close("Akk", Akk_r, Akk_c, 0.008) + assert_close("w", w_r, w_c, 0.008) + assert_close("u", u_r, u_c, 0.008) + assert_close("kg", kg_r, kg_c, 0.005) + + if disable_recompute: + assert_close("qg", qg_r, qg_c, 0.005) + else: + assert qg_c is None + + +# ========================================================================= +# Degeneracy: HV == HQK must match the non-GVA (same-shape) reference +# ========================================================================= + +@pytest.mark.parametrize("disable_recompute", [False, True], ids=["recomp", "no_recomp"]) +@pytest.mark.parametrize( + ("B", "T", "H", "D"), + [ + pytest.param(*cfg, id="B{}-T{}-H{}-D{}".format(*cfg)) + for cfg in [ + (1, 512, 4, 128), + (2, 1024, 4, 128), + ] + ], +) +def test_gva_intra_degenerate_equals_non_gva(B, T, H, D, disable_recompute): + """When HV == HQK, the GVA code path must be byte-for-byte equivalent + to the non-GVA path that existed before this change. + + We do not have a separate "non-GVA" entrypoint, but we can assert the + cuLA path matches FLA with *no* head replication (group_size=1), which + exercises the ``HV == HQK`` fast-path inside the new kernels. + """ + chunk_size = 64 + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = _make_gva_inputs( + B=B, T=T, HQK=H, HV=H, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens, + ) + + w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute, + ) + # group_size=1 → no replication; identical input shape to cuLA. + w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r = fla_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=disable_recompute, + ) + + assert_close("Aqk", Aqk_r, Aqk_c, 0.005) + assert_close("Akk", Akk_r, Akk_c, 0.008) + assert_close("w", w_r, w_c, 0.008) + assert_close("u", u_r, u_c, 0.008) + assert_close("kg", kg_r, kg_c, 0.005) + if disable_recompute: + assert_close("qg", qg_r, qg_c, 0.005) + + +# ========================================================================= +# Shape / contract sanity checks (run even without a reference) +# ========================================================================= + +@pytest.mark.parametrize("group_size", [1, 2, 4]) +def test_gva_intra_output_shapes(group_size): + """All outputs of chunk_kda_fwd_intra must live in HV-head space.""" + B, T, HQK, D = 1, 256, 2, 128 + HV = HQK * group_size + chunk_size = 64 + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = _make_gva_inputs( + B=B, T=T, HQK=HQK, HV=HV, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens, + ) + w, u, qg, kg, Aqk, Akk = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute=True, + ) + + assert Aqk.shape == (B, T, HV, chunk_size), Aqk.shape + assert Akk.shape == (B, T, HV, chunk_size), Akk.shape + assert w.shape == (B, T, HV, D), w.shape + assert u.shape == (B, T, HV, D), u.shape + assert kg.shape == (B, T, HV, D), kg.shape + assert qg is not None and qg.shape == (B, T, HV, D), (None if qg is None else qg.shape) + + +# ========================================================================= +# Negative / assertion tests +# ========================================================================= + +def test_gva_intra_rejects_non_multiple_ratio(): + """HV must be a positive integer multiple of HQK.""" + B, T, HQK, HV, D = 1, 128, 3, 5, 128 # 5 % 3 != 0 + chunk_size = 64 + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + # We intentionally do not use _make_gva_inputs because the assert fires + # before kernel launch on the python side. + dtype = torch.bfloat16 + q = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + k = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + v = torch.randn(B, T, HV, D, dtype=dtype, device=device) + g = torch.randn(B, T, HV, D, dtype=torch.float, device=device) + beta = torch.randn(B, T, HV, dtype=torch.float, device=device).sigmoid() + + with pytest.raises(AssertionError): + cula_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=D ** -0.5, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, + safe_gate=True, disable_recompute=False, + ) From 3273659ab708fcd610badc289d94d60e560e4f41 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Tue, 19 May 2026 13:18:17 +0800 Subject: [PATCH 03/14] benchmark --- benchmarks/bench_kda_chunk_intra.py | 173 +++++++++++++++++++++++++++- benchmarks/utils.py | 51 ++++++++ cula/utils.py | 2 +- 3 files changed, 222 insertions(+), 4 deletions(-) diff --git a/benchmarks/bench_kda_chunk_intra.py b/benchmarks/bench_kda_chunk_intra.py index 29aa5ab..466dd70 100644 --- a/benchmarks/bench_kda_chunk_intra.py +++ b/benchmarks/bench_kda_chunk_intra.py @@ -25,7 +25,7 @@ from fla.ops.kda.chunk_intra import chunk_kda_fwd_intra as fla_chunk_kda_fwd_intra -from benchmarks.utils import SEED, exclusive_cumsum, generate_random_seq_lens, prepare_intra_inputs +from benchmarks.utils import SEED, exclusive_cumsum, generate_random_seq_lens, prepare_intra_inputs, prepare_intra_inputs_gva from cula.kda.chunk_intra import chunk_kda_fwd_intra as cula_chunk_kda_fwd_intra # Constant params @@ -39,6 +39,7 @@ VARIANCE = 1.0 DISABLE_RECOMPUTE = False # Whether to disable recompute (compute QG in forward) +GROUP_SIZE = 1 # GVA group size: HV = GROUP_SIZE * H. 1 means no GVA. def accuracy_stats(a, b): @@ -246,6 +247,158 @@ def benchmark_chunk_intra_varlen(): print("─" * 100) +# ============================================================================== +# GVA uniform seqlen benchmark +# ============================================================================== +def benchmark_chunk_intra_gva_uniform(group_size: int): + """Benchmark GVA (HV > HQK) intra chunk: cuLA vs FLA Triton (k replicated to HV). + + FLA does not natively support GVA, so the reference replicates k along the + head axis to HV before calling the kernel (same strategy as in the unit tests). + """ + device = torch.device("cuda") + chunk_size = BT + HQK = H + HV = HQK * group_size + T_vals = [512, 1024, 4096, 8192, 16384, 32768] + + print("=" * 100) + print( + f" GVA Uniform ChunkIntra Benchmark: cuLA vs FLA Triton " + f"B={B} HQK={HQK} HV={HV} (group_size={group_size}) D={D} disable_recompute={DISABLE_RECOMPUTE}" + ) + print("=" * 100) + print( + f"{'B':>4} {'T':>7} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" + ) + print("─" * 100) + + for T in T_vals: + seq_lens = [T] * B + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs_gva( + B, T, HQK, HV, D, device, cu_seqlens=cu_seqlens + ) + + # FLA reference: replicate k/q to HV heads + k_hv = k.repeat_interleave(group_size, dim=2).contiguous() + q_hv = q.repeat_interleave(group_size, dim=2).contiguous() + + # Accuracy: run once and compare + out_fla = fla_chunk_kda_fwd_intra( + q=q_hv, k=k_hv, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, + ) + out_cula = cula_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, + ) + # Compare first output (w) + o_fla = out_fla[0] if isinstance(out_fla, (tuple, list)) else out_fla + o_cula = out_cula[0] if isinstance(out_cula, (tuple, list)) else out_cula + rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) + + # Performance + ms_fla = triton.testing.do_bench( + lambda: fla_chunk_kda_fwd_intra( + q=q_hv, k=k_hv, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, + ), + ) + ms_cula = triton.testing.do_bench( + lambda: cula_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, + ), + ) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + + print( + f"{B:>4} {T:>7} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" + ) + + print("─" * 100) + + +# ============================================================================== +# GVA varlen benchmark +# ============================================================================== +def benchmark_chunk_intra_gva_varlen(group_size: int): + """Varlen GVA benchmark: cuLA vs FLA Triton (k replicated to HV).""" + device = torch.device("cuda") + chunk_size = BT + HQK = H + HV = HQK * group_size + total_len_vals = [8192, 16384, 32768, 65536] + + print() + print("=" * 110) + print( + f" GVA Varlen ChunkIntra Benchmark: cuLA vs FLA Triton " + f"NUM_SEQS={NUM_SEQS} HQK={HQK} HV={HV} (group_size={group_size}) D={D} disable_recompute={DISABLE_RECOMPUTE}" + ) + print("=" * 110) + print( + f"{'total_len':>10} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" + ) + print("─" * 110) + + for total_len in total_len_vals: + seq_lens = generate_random_seq_lens(NUM_SEQS, total_len, MIN_SEQ_LEN, VARIANCE, SEED) + T = total_len + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs_gva( + 1, T, HQK, HV, D, device, cu_seqlens=cu_seqlens + ) + + k_hv = k.repeat_interleave(group_size, dim=2).contiguous() + q_hv = q.repeat_interleave(group_size, dim=2).contiguous() + + # Accuracy + out_fla = fla_chunk_kda_fwd_intra( + q=q_hv, k=k_hv, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, + ) + out_cula = cula_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, + ) + o_fla = out_fla[0] if isinstance(out_fla, (tuple, list)) else out_fla + o_cula = out_cula[0] if isinstance(out_cula, (tuple, list)) else out_cula + rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) + + # Performance + ms_fla = triton.testing.do_bench( + lambda: fla_chunk_kda_fwd_intra( + q=q_hv, k=k_hv, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, + ), + ) + ms_cula = triton.testing.do_bench( + lambda: cula_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, + ), + ) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + + print( + f"{total_len:>10} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" + ) + + print("─" * 110) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="bench_kda_chunk_intra: cuLA vs FLA Triton for chunk_kda_fwd_intra") parser.add_argument( @@ -253,11 +406,25 @@ def benchmark_chunk_intra_varlen(): action="store_true", help="Disable recompute in both FLA and cuLA (pre-compute QG)", ) + parser.add_argument( + "--group_size", + type=int, + default=1, + help="GVA group size: HV = group_size * H. 1 (default) runs the non-GVA benchmark. " + "Values > 1 run GVA benchmarks comparing cuLA (k in HQK space) vs FLA (k replicated to HV).", + ) args = parser.parse_args() if args.disable_recompute: DISABLE_RECOMPUTE = True print("[Disable recompute] pre-compute QG in forward") - benchmark_chunk_intra_uniform() - benchmark_chunk_intra_varlen() + GROUP_SIZE = args.group_size + + if GROUP_SIZE == 1: + benchmark_chunk_intra_uniform() + benchmark_chunk_intra_varlen() + else: + assert H % 1 == 0, "H must be divisible by group_size" + benchmark_chunk_intra_gva_uniform(GROUP_SIZE) + benchmark_chunk_intra_gva_varlen(GROUP_SIZE) diff --git a/benchmarks/utils.py b/benchmarks/utils.py index bfd0761..05198ab 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -366,3 +366,54 @@ def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_siz ) return q, k, v, g, beta, scale, cu_seqlens, chunk_indices + + +def prepare_intra_inputs_gva( + batch_size, T, HQK, HV, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED +): + """Prepare preprocessed inputs for chunk_kda_fwd_intra with GVA (HV >= HQK). + + GVA layout: + q, k : (batch_size_flat, T, HQK, D) — Q/K head space + v : (batch_size_flat, T, HV, D) — V head space + g : (batch_size_flat, T, HV, D) — gate in V head space (after cumsum) + beta : (batch_size_flat, T, HV) — beta in V head space + + When HV == HQK (group_size == 1) this is identical to prepare_intra_inputs. + All tensors are flattened to batch_size=1 for cu_seqlens compatibility. + """ + assert HV > 0 and HQK > 0 and HV % HQK == 0, f"HV ({HV}) must be a positive multiple of HQK ({HQK})" + dtype = torch.bfloat16 + scale = D ** (-0.5) + + set_seed(seed) + + q = torch.randn(batch_size, T, HQK, D, dtype=dtype, device=device) + k = torch.randn(batch_size, T, HQK, D, dtype=dtype, device=device) + v = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device) + g_raw = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device) + beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid() + + q, _ = l2norm_fwd(q) + k, _ = l2norm_fwd(k) + + if batch_size != 1: + q, k, v, g_raw, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g_raw, beta)) + + A_log = torch.randn(HV, dtype=torch.float, device=device) + dt_bias = torch.randn(HV * D, dtype=torch.float, device=device) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + + g = kda_gate_chunk_cumsum( + g=g_raw, + A_log=A_log, + dt_bias=dt_bias, + scale=RCP_LN2, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=-5.0, + ) + + return q, k, v, g, beta, scale, cu_seqlens, chunk_indices diff --git a/cula/utils.py b/cula/utils.py index bd70730..3b43fe0 100644 --- a/cula/utils.py +++ b/cula/utils.py @@ -83,7 +83,7 @@ def assert_hopper(device: torch.device | str | int | None = None) -> None: def get_kda_fused_fwd(device: torch.device | str | int | None = None) -> Callable: """Return the appropriate ``kda_prefill`` implementation for *device*. - - sm100/sm103 (Blackwell) → cula.kda.kda_prefill_blackwell (not yet available) + - sm100/sm103 (Blackwell) → NotImplementedError - sm90 (Hopper) → cula.kda.kda_prefill_hopper Args: From 270ab5d865aa9dc35c753ba2e3c90265dcc3741e Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Tue, 19 May 2026 15:47:12 +0800 Subject: [PATCH 04/14] benchmark and test --- benchmarks/bench_kda.py | 227 ++++++++--------- benchmarks/bench_kda_fwd_bwd_e2e.py | 363 +++++++++++++--------------- benchmarks/utils.py | 87 +++++++ cula/kda/chunk.py | 13 +- cula/kda/chunk_bwd.py | 67 ++--- cula/kda/chunk_fwd.py | 9 +- cula/kda/chunk_intra.py | 121 +++++----- tests/test_kda.py | 187 ++++++++++++++ 8 files changed, 683 insertions(+), 391 deletions(-) diff --git a/benchmarks/bench_kda.py b/benchmarks/bench_kda.py index dc31d11..50b7399 100644 --- a/benchmarks/bench_kda.py +++ b/benchmarks/bench_kda.py @@ -15,7 +15,7 @@ """ bench_kda.py — Benchmark: cuLA CuTe DSL vs FLA Triton baseline - for chunk_kda (KDA forward) + for chunk_kda (KDA training, fwd+bwd) Compares: - Accuracy: RMSE, relative max diff between cuLA and FLA outputs @@ -25,8 +25,13 @@ - Fixed-length: B=1, B=2 with various T - Varlen: ~20 seqs with 2-3x length variation +H (number of Q/K heads) is a module-level constant; HV (number of V heads) +defaults to H and can be overridden globally via --hv to run every config in +GVA mode. In GVA mode cuLA receives native HQK q/k; FLA receives q/k +expanded to HV heads. HV must be a positive multiple of H. + Usage: - python bench_kda.py [--mode fixed|varlen|both] [--ncu] + python bench_kda.py [--mode fixed|varlen|both] [--hv HV] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: ncu --set full -o report python bench_kda.py --mode varlen --ncu @@ -48,6 +53,7 @@ build_varlen_configs, exclusive_cumsum, prepare_safe_gate_inputs, + prepare_safe_gate_inputs_gva, set_seed, ) from cula.kda import chunk_kda as cula_chunk_kda @@ -55,7 +61,10 @@ # ============================================================ # Constants # ============================================================ +# H = QK head count; HV = V head count. HV defaults to H (non-GVA / MHA). +# Override via --hv to run every config in GVA mode (HV must be a multiple of H). H, D = 64, 128 +HV = H WARMUP = 10 N_ITERS = 30 NCU_MODE = False @@ -162,74 +171,70 @@ def check_determinism(H=4, total_T=8192, num_seqs=10, iters=10000): assert torch.equal(state, ref_state), f"State mismatch at iter {i}" +def _prepare_inputs(B, T, cu_seqlens): + """Return (inputs, q_fla, k_fla, q_cula, k_cula). + + Non-GVA (HV == H): all four q/k are the same tensor. + GVA (HV > H) : cuLA gets native HQK q/k; FLA gets q/k expanded to HV. + """ + device = torch.device("cuda") + if HV > H: + inputs = prepare_safe_gate_inputs_gva(B, T, H, HV, D, device, cu_seqlens=cu_seqlens) + q_cula, k_cula = inputs["q"], inputs["k"] # [B_flat, T, H, D] + q_fla = q_cula.repeat_interleave(HV // H, dim=2).contiguous() # [B_flat, T, HV, D] + k_fla = k_cula.repeat_interleave(HV // H, dim=2).contiguous() + else: + inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens) + q_cula = q_fla = inputs["q"] + k_cula = k_fla = inputs["k"] + return inputs, q_fla, k_fla, q_cula, k_cula + + # ============================================================ # Fixed-length benchmark # ============================================================ def bench_fixed(configs): - print("\n" + "=" * 100) - print(f" Fixed-Length Benchmark: cuLA CuTe DSL vs FLA Triton disable_recompute={DISABLE_RECOMPUTE}") - print("=" * 100) + gva_note = f"GVA HV={HV} ({HV // H}x)" if HV > H else f"MHA HV=H={H}" + print("\n" + "=" * 110) + print(f" Fixed-Length Benchmark: cuLA CuTe DSL vs FLA Triton {gva_note} disable_recompute={DISABLE_RECOMPUTE}") + print("=" * 110) results = [] for B, T in configs: set_seed(SEED) - device = torch.device("cuda") torch.cuda.empty_cache() seq_lens = [T] * B - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=torch.device("cuda")) - inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens) - q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] + inputs, q_fla, k_fla, q_cula, k_cula = _prepare_inputs(B, T, cu_seqlens) + v, g, beta = inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] - common = dict( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - init_state=init_state, - cu_seqlens=cu_seqlens, - lower_bound=lower_bound, - ) + _shared = dict(v=v, g=g, beta=beta, scale=scale, A_log=A_log, dt_bias=dt_bias, + init_state=init_state, cu_seqlens=cu_seqlens, lower_bound=lower_bound) + common_fla = dict(q=q_fla, k=k_fla, **_shared) + common_cula = dict(q=q_cula, k=k_cula, **_shared) - # Accuracy: compare outputs - o_fla, _ = run_kda(**common, fn=fla_chunk_kda) - o_cula, _ = run_kda(**common, fn=cula_chunk_kda) + # Accuracy + o_fla, _ = run_kda(**common_fla, fn=fla_chunk_kda) + o_cula, _ = run_kda(**common_cula, fn=cula_chunk_kda) torch.cuda.synchronize() - rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) # Performance - def fn_fla(**common_kw): - return lambda: run_kda(**common_kw, fn=fla_chunk_kda) - - def fn_cula(**common_kw): - return lambda: run_kda(**common_kw, fn=cula_chunk_kda) - - ms_fla = time_kernel(fn_fla(**common)) - ms_cula = time_kernel(fn_cula(**common)) + ms_fla = time_kernel(lambda: run_kda(**common_fla, fn=fla_chunk_kda)) + ms_cula = time_kernel(lambda: run_kda(**common_cula, fn=cula_chunk_kda)) speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - r = { - "B": B, - "T": T, - "rmse": rmse, - "rel_max": rel_max, - "mean_diff": mean_diff, - "ms_fla": ms_fla, - "ms_cula": ms_cula, - "speedup": speedup, - } - results.append(r) - # print(f" B={B:2d} T={T:5d} done ({speedup:.2f}x)") - - del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs + results.append({ + "B": B, "T": T, "H": H, "HV": HV, + "rmse": rmse, "rel_max": rel_max, "mean_diff": mean_diff, + "ms_fla": ms_fla, "ms_cula": ms_cula, "speedup": speedup, + }) + + del o_fla, o_cula, inputs torch.cuda.empty_cache() return results @@ -239,77 +244,51 @@ def fn_cula(**common_kw): # Varlen benchmark # ============================================================ def bench_varlen(configs): - print("\n" + "=" * 100) - print(f" Varlen Benchmark: cuLA CuTe DSL vs FLA Triton disable_recompute={DISABLE_RECOMPUTE}") - print("=" * 100) + gva_note = f"GVA HV={HV} ({HV // H}x)" if HV > H else f"MHA HV=H={H}" + print("\n" + "=" * 110) + print(f" Varlen Benchmark: cuLA CuTe DSL vs FLA Triton {gva_note} disable_recompute={DISABLE_RECOMPUTE}") + print("=" * 110) results = [] for seq_lens, total_len, dist in configs: set_seed(SEED) - device = torch.device("cuda") torch.cuda.empty_cache() T = total_len - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=torch.device("cuda")) - inputs = prepare_safe_gate_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens) - q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] + inputs, q_fla, k_fla, q_cula, k_cula = _prepare_inputs(1, T, cu_seqlens) + v, g, beta = inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] - common = dict( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - init_state=init_state, - cu_seqlens=cu_seqlens, - lower_bound=lower_bound, - ) + _shared = dict(v=v, g=g, beta=beta, scale=scale, A_log=A_log, dt_bias=dt_bias, + init_state=init_state, cu_seqlens=cu_seqlens, lower_bound=lower_bound) + common_fla = dict(q=q_fla, k=k_fla, **_shared) + common_cula = dict(q=q_cula, k=k_cula, **_shared) # Accuracy - o_fla, _ = run_kda(**common, fn=fla_chunk_kda) - o_cula, _ = run_kda(**common, fn=cula_chunk_kda) + o_fla, _ = run_kda(**common_fla, fn=fla_chunk_kda) + o_cula, _ = run_kda(**common_cula, fn=cula_chunk_kda) torch.cuda.synchronize() - rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) # Performance - def fn_fla(**common_kw): - return lambda: run_kda(**common_kw, fn=fla_chunk_kda) - - def fn_cula(**common_kw): - return lambda: run_kda(**common_kw, fn=cula_chunk_kda) - - ms_fla = time_kernel(fn_fla(**common)) - ms_cula = time_kernel(fn_cula(**common)) + ms_fla = time_kernel(lambda: run_kda(**common_fla, fn=fla_chunk_kda)) + ms_cula = time_kernel(lambda: run_kda(**common_cula, fn=cula_chunk_kda)) speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") n_seqs = len(seq_lens) - min_l, max_l = min(seq_lens), max(seq_lens) - avg_l = T // n_seqs - tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min_l}..{max_l}] avg={avg_l}" - - r = { - "tag": tag, - "dist": dist, - "T_total": T, - "n_seqs": n_seqs, - "rmse": rmse, - "rel_max": rel_max, - "mean_diff": mean_diff, - "ms_fla": ms_fla, - "ms_cula": ms_cula, - "speedup": speedup, - } - results.append(r) - # print(f" {tag:45s} done ({speedup:.2f}x)") - - del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs + tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min(seq_lens)}..{max(seq_lens)}] avg={T // n_seqs}" + + results.append({ + "tag": tag, "dist": dist, "T_total": T, "n_seqs": n_seqs, + "H": H, "HV": HV, + "rmse": rmse, "rel_max": rel_max, "mean_diff": mean_diff, + "ms_fla": ms_fla, "ms_cula": ms_cula, "speedup": speedup, + }) + + del o_fla, o_cula, inputs torch.cuda.empty_cache() return results @@ -319,11 +298,13 @@ def fn_cula(**common_kw): # Report # ============================================================ def print_report(fixed_results, varlen_results): - sep = "=" * 110 + sep = "=" * 120 print(f"\n\n{sep}") print(" BENCHMARK REPORT: chunk_kda") print(" cuLA CuTe DSL vs FLA Triton") - print(f" H={H} D={D} dtype=bf16 safe_gate=True disable_recompute={DISABLE_RECOMPUTE}") + print(f" D={D} dtype=bf16 safe_gate=True disable_recompute={DISABLE_RECOMPUTE}") + gva_note = f"GVA enabled (HV={HV} > H={H}, ratio={HV // H}x)" if HV > H else f"MHA (HV=H={H})" + print(f" {gva_note}") wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") @@ -332,32 +313,39 @@ def print_report(fixed_results, varlen_results): if fixed_results: print("\n [Fixed-Length]") - print(f" {'─' * 85}") + print(f" {'─' * 110}") print( - f" {'B':>3s} {'T':>5s} │ {'RMSE':>10s} {'rel_max':>10s}" - f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s}" + f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " + f"{'RMSE':>10s} {'rel_max':>10s} │ " + f"{'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s}" ) - print(f" {'─' * 85}") + print(f" {'─' * 110}") for r in fixed_results: + gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no" print( - f" {r['B']:3d} {r['T']:5d} │ " + f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} │ " f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x" ) - print(f" {'─' * 85}") + print(f" {'─' * 110}") if varlen_results: print("\n [Varlen]") - print(f" {'─' * 100}") - print(f" {'Config':>45s} │ {'RMSE':>10s} {'rel_max':>10s} │ {'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s}") - print(f" {'─' * 100}") + print(f" {'─' * 120}") + print( + f" {'Config':>45s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " + f"{'RMSE':>10s} {'rel_max':>10s} │ " + f"{'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s}" + ) + print(f" {'─' * 120}") for r in varlen_results: + gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no" print( - f" {r['tag']:>45s} │ " + f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} │ " f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x" ) - print(f" {'─' * 100}") + print(f" {'─' * 120}") print(f"\n{sep}\n") @@ -389,9 +377,15 @@ def main(): action="store_true", help="Disable recompute in both FLA and cuLA (pre-compute QG)", ) + parser.add_argument( + "--hv", + type=int, + default=None, + help=f"Override number of V heads (HV). Default: H ({H}, no GVA). Set HV > H for GVA mode.", + ) args = parser.parse_args() - global NCU_MODE, SANITIZER_MODE, DISABLE_RECOMPUTE + global NCU_MODE, SANITIZER_MODE, DISABLE_RECOMPUTE, HV if args.ncu: NCU_MODE = True print("[NCU mode] warmup=1, iters=1") @@ -401,6 +395,12 @@ def main(): if args.disable_recompute: DISABLE_RECOMPUTE = True print("[Disable recompute] pre-compute QG in forward") + if args.hv is not None: + if args.hv < H or args.hv % H != 0: + raise ValueError(f"--hv must be a positive multiple of H ({H}), got {args.hv}") + HV = args.hv + if HV > H: + print(f"[GVA] HV={HV} (H={H}, ratio={HV // H}x)") fixed_configs = [ # (B, T) @@ -434,6 +434,7 @@ def main(): varlen_res = bench_varlen(varlen_configs) print_report(fixed_res, varlen_res) + return fixed_res, varlen_res diff --git a/benchmarks/bench_kda_fwd_bwd_e2e.py b/benchmarks/bench_kda_fwd_bwd_e2e.py index c6b4117..a63ce6e 100644 --- a/benchmarks/bench_kda_fwd_bwd_e2e.py +++ b/benchmarks/bench_kda_fwd_bwd_e2e.py @@ -29,8 +29,13 @@ - forward: forward pass only - e2e: forward + backward (end-to-end) +H (number of Q/K heads) is a module-level constant; HV (number of V heads) +defaults to H and can be overridden globally via --hv to run every config in +GVA mode. In GVA mode cuLA receives native HQK q/k; FLA receives q/k +expanded to HV heads. HV must be a positive multiple of H. + Usage: - python bench_kda_fwd_bwd_e2e.py [--mode fixed|varlen|both] [--phase forward|e2e] [--ncu] + python bench_kda_fwd_bwd_e2e.py [--mode fixed|varlen|both] [--phase forward|e2e] [--hv HV] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: ncu --set full -o report python bench_kda_fwd_bwd_e2e.py --mode varlen --ncu @@ -53,6 +58,7 @@ exclusive_cumsum, generate_random_seq_lens, prepare_safe_gate_inputs, + prepare_safe_gate_inputs_gva, set_seed, ) from cula.kda import chunk_kda as cula_chunk_kda @@ -60,7 +66,10 @@ # ============================================================ # Constants # ============================================================ +# H = QK head count; HV = V head count. HV defaults to H (non-GVA / MHA). +# Override via --hv to run every config in GVA mode (HV must be a multiple of H). H, D = 64, 128 +HV = H WARMUP = 25 N_ITERS = 100 NCU_MODE = False @@ -226,113 +235,116 @@ def check_determinism(num_seqs=5, T=512, iters=20): return True +def _prepare_inputs_e2e(B, T, cu_seqlens): + """Return (inputs, q_fla, k_fla, q_cula, k_cula). + + Non-GVA (HV == H): all four q/k are the same tensor. + GVA (HV > H) : cuLA gets native HQK q/k; FLA gets q/k expanded to HV. + """ + device = torch.device("cuda") + if HV > H: + inputs = prepare_safe_gate_inputs_gva(B, T, H, HV, D, device, cu_seqlens=cu_seqlens, has_init_state=True) + q_cula, k_cula = inputs["q"], inputs["k"] # [B_flat, T, H, D] + q_fla = q_cula.repeat_interleave(HV // H, dim=2).contiguous() # [B_flat, T, HV, D] + k_fla = k_cula.repeat_interleave(HV // H, dim=2).contiguous() + else: + inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=True) + q_cula = q_fla = inputs["q"] + k_cula = k_fla = inputs["k"] + return inputs, q_fla, k_fla, q_cula, k_cula + + +def _compare_accuracy(fla_results, cula_results): + """Compare accuracy between FLA and cuLA results, handling GVA dq/dk shape mismatch.""" + acc = {} + for name in ("o", "ht", "dv", "dg", "dbeta", "dh0"): + if name in fla_results and name in cula_results: + err_ratio, rel_max, mean_diff = accuracy_stats(fla_results[name], cula_results[name]) + acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} + if "dq" in fla_results and "dq" in cula_results: + dq_fla = fla_results["dq"] + dk_fla = fla_results["dk"] + if HV > H: + # Aggregate FLA HV-space grads back to HQK space for comparison + *head_prefix, hv_size, d_size = dq_fla.shape + dq_fla = dq_fla.reshape(*head_prefix, H, HV // H, d_size).sum(dim=-2) + dk_fla = dk_fla.reshape(*head_prefix, H, HV // H, d_size).sum(dim=-2) + for name, ref, out in (("dq", dq_fla, cula_results["dq"]), ("dk", dk_fla, cula_results["dk"])): + err_ratio, rel_max, mean_diff = accuracy_stats(ref, out) + acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} + return acc + + # ============================================================ # Fixed-length benchmark # ============================================================ def bench_fixed(configs): + gva_note = f"GVA HV={HV} ({HV // H}x)" if HV > H else f"MHA HV=H={H}" print("\n" + "=" * 120) - print(f" Fixed-Length E2E Benchmark: cuLA vs FLA phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}") + print(f" Fixed-Length E2E Benchmark: cuLA vs FLA {gva_note} phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}") print("=" * 120) results = [] for B, T in configs: set_seed(SEED) - device = torch.device("cuda") torch.cuda.empty_cache() seq_lens = [T] * B - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=torch.device("cuda")) - inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=True) - q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] + inputs, q_fla, k_fla, q_cula, k_cula = _prepare_inputs_e2e(B, T, cu_seqlens) + v, g, beta = inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] - # Generate do, dht for backward set_seed(SEED + 1) do = torch.randn_like(v) dht = torch.randn_like(init_state) - common = dict( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - init_state=init_state, - cu_seqlens=cu_seqlens, - lower_bound=lower_bound, - do=do, - dht=dht, - ) - - # Accuracy: compare outputs and gradients + _shared = dict(v=v, g=g, beta=beta, scale=scale, A_log=A_log, dt_bias=dt_bias, + init_state=init_state, cu_seqlens=cu_seqlens, lower_bound=lower_bound, + do=do, dht=dht) + common_fla = dict(q=q_fla, k=k_fla, **_shared) + common_cula = dict(q=q_cula, k=k_cula, **_shared) + + # Accuracy acc = {} if PHASE == "e2e": - fla_results = run_kda_e2e_with_grads(**common, fn=fla_chunk_kda) - cula_results = run_kda_e2e_with_grads(**common, fn=cula_chunk_kda) + fla_results = run_kda_e2e_with_grads(**common_fla, fn=fla_chunk_kda) + cula_results = run_kda_e2e_with_grads(**common_cula, fn=cula_chunk_kda) torch.cuda.synchronize() - - for name in ("o", "ht", "dq", "dk", "dv", "dg", "dbeta", "dh0"): - err_ratio, rel_max, mean_diff = accuracy_stats(fla_results[name], cula_results[name]) - acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} + acc = _compare_accuracy(fla_results, cula_results) else: - # forward-only accuracy - o_fla, ht_fla = run_kda_e2e(**common, fn=fla_chunk_kda) - o_cula, ht_cula = run_kda_e2e(**common, fn=cula_chunk_kda) + o_fla, ht_fla = run_kda_e2e(**common_fla, fn=fla_chunk_kda) + o_cula, ht_cula = run_kda_e2e(**common_cula, fn=cula_chunk_kda) torch.cuda.synchronize() for name, ref, out in [("o", o_fla, o_cula), ("ht", ht_fla, ht_cula)]: err_ratio, rel_max, mean_diff = accuracy_stats(ref, out) acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} - # For timing, use leaf tensors with requires_grad - q_t = q.detach().clone().requires_grad_(True) - k_t = k.detach().clone().requires_grad_(True) - v_t = v.detach().clone().requires_grad_(True) - g_t = g.detach().clone().requires_grad_(True) - beta_t = beta.detach().clone().requires_grad_(True) - h0_t = init_state.detach().clone().requires_grad_(True) - - timing_common = dict( - q=q_t, - k=k_t, - v=v_t, - g=g_t, - beta=beta_t, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - init_state=h0_t, - cu_seqlens=cu_seqlens, - lower_bound=lower_bound, - do=do, - dht=dht, - ) - - def fn_fla(**kw): - return lambda: run_kda_e2e(**kw, fn=fla_chunk_kda) - - def fn_cula(**kw): - return lambda: run_kda_e2e(**kw, fn=cula_chunk_kda) - - ms_fla = time_kernel(fn_fla(**timing_common)) - ms_cula = time_kernel(fn_cula(**timing_common)) + # Timing: fresh leaf tensors with requires_grad + def _make_timing(q_, k_): + return dict( + q=q_.detach().clone().requires_grad_(True), + k=k_.detach().clone().requires_grad_(True), + v=v.detach().clone().requires_grad_(True), + g=g.detach().clone().requires_grad_(True), + beta=beta.detach().clone().requires_grad_(True), + scale=scale, A_log=A_log, dt_bias=dt_bias, + init_state=init_state.detach().clone().requires_grad_(True), + cu_seqlens=cu_seqlens, lower_bound=lower_bound, do=do, dht=dht, + ) + + ms_fla = time_kernel(lambda: run_kda_e2e(**_make_timing(q_fla, k_fla), fn=fla_chunk_kda)) + ms_cula = time_kernel(lambda: run_kda_e2e(**_make_timing(q_cula, k_cula), fn=cula_chunk_kda)) speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - r = { - "B": B, - "T": T, - "accuracy": acc, - "ms_fla": ms_fla, - "ms_cula": ms_cula, - "speedup": speedup, - } - results.append(r) - - del q, k, v, g, beta, A_log, dt_bias, inputs, do, dht + results.append({ + "B": B, "T": T, "H": H, "HV": HV, + "accuracy": acc, "ms_fla": ms_fla, "ms_cula": ms_cula, "speedup": speedup, + }) + + del inputs, do, dht torch.cuda.empty_cache() return results @@ -342,115 +354,76 @@ def fn_cula(**kw): # Varlen benchmark # ============================================================ def bench_varlen(configs): + gva_note = f"GVA HV={HV} ({HV // H}x)" if HV > H else f"MHA HV=H={H}" print("\n" + "=" * 120) - print(f" Varlen E2E Benchmark: cuLA vs FLA phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}") + print(f" Varlen E2E Benchmark: cuLA vs FLA {gva_note} phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}") print("=" * 120) results = [] for seq_lens, total_len, dist in configs: set_seed(SEED) - device = torch.device("cuda") torch.cuda.empty_cache() T = total_len - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=torch.device("cuda")) - inputs = prepare_safe_gate_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=True) - q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] + inputs, q_fla, k_fla, q_cula, k_cula = _prepare_inputs_e2e(1, T, cu_seqlens) + v, g, beta = inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] - # Generate do, dht for backward set_seed(SEED + 1) do = torch.randn_like(v) dht = torch.randn_like(init_state) - common = dict( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - init_state=init_state, - cu_seqlens=cu_seqlens, - lower_bound=lower_bound, - do=do, - dht=dht, - ) - - # Accuracy: compare outputs and gradients + _shared = dict(v=v, g=g, beta=beta, scale=scale, A_log=A_log, dt_bias=dt_bias, + init_state=init_state, cu_seqlens=cu_seqlens, lower_bound=lower_bound, + do=do, dht=dht) + common_fla = dict(q=q_fla, k=k_fla, **_shared) + common_cula = dict(q=q_cula, k=k_cula, **_shared) + + # Accuracy acc = {} if PHASE == "e2e": - fla_results = run_kda_e2e_with_grads(**common, fn=fla_chunk_kda) - cula_results = run_kda_e2e_with_grads(**common, fn=cula_chunk_kda) + fla_results = run_kda_e2e_with_grads(**common_fla, fn=fla_chunk_kda) + cula_results = run_kda_e2e_with_grads(**common_cula, fn=cula_chunk_kda) torch.cuda.synchronize() - - for name in ("o", "ht", "dq", "dk", "dv", "dg", "dbeta", "dh0"): - err_ratio, rel_max, mean_diff = accuracy_stats(fla_results[name], cula_results[name]) - acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} + acc = _compare_accuracy(fla_results, cula_results) else: - o_fla, ht_fla = run_kda_e2e(**common, fn=fla_chunk_kda) - o_cula, ht_cula = run_kda_e2e(**common, fn=cula_chunk_kda) + o_fla, ht_fla = run_kda_e2e(**common_fla, fn=fla_chunk_kda) + o_cula, ht_cula = run_kda_e2e(**common_cula, fn=cula_chunk_kda) torch.cuda.synchronize() for name, ref, out in [("o", o_fla, o_cula), ("ht", ht_fla, ht_cula)]: err_ratio, rel_max, mean_diff = accuracy_stats(ref, out) acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} - # For timing, use leaf tensors with requires_grad - q_t = q.detach().clone().requires_grad_(True) - k_t = k.detach().clone().requires_grad_(True) - v_t = v.detach().clone().requires_grad_(True) - g_t = g.detach().clone().requires_grad_(True) - beta_t = beta.detach().clone().requires_grad_(True) - h0_t = init_state.detach().clone().requires_grad_(True) - - timing_common = dict( - q=q_t, - k=k_t, - v=v_t, - g=g_t, - beta=beta_t, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - init_state=h0_t, - cu_seqlens=cu_seqlens, - lower_bound=lower_bound, - do=do, - dht=dht, - ) - - def fn_fla(**kw): - return lambda: run_kda_e2e(**kw, fn=fla_chunk_kda) - - def fn_cula(**kw): - return lambda: run_kda_e2e(**kw, fn=cula_chunk_kda) - - ms_fla = time_kernel(fn_fla(**timing_common)) - ms_cula = time_kernel(fn_cula(**timing_common)) + # Timing: fresh leaf tensors with requires_grad + def _make_timing(q_, k_): + return dict( + q=q_.detach().clone().requires_grad_(True), + k=k_.detach().clone().requires_grad_(True), + v=v.detach().clone().requires_grad_(True), + g=g.detach().clone().requires_grad_(True), + beta=beta.detach().clone().requires_grad_(True), + scale=scale, A_log=A_log, dt_bias=dt_bias, + init_state=init_state.detach().clone().requires_grad_(True), + cu_seqlens=cu_seqlens, lower_bound=lower_bound, do=do, dht=dht, + ) + + ms_fla = time_kernel(lambda: run_kda_e2e(**_make_timing(q_fla, k_fla), fn=fla_chunk_kda)) + ms_cula = time_kernel(lambda: run_kda_e2e(**_make_timing(q_cula, k_cula), fn=cula_chunk_kda)) speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") n_seqs = len(seq_lens) - min_l, max_l = min(seq_lens), max(seq_lens) - avg_l = T // n_seqs - tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min_l}..{max_l}] avg={avg_l}" - - r = { - "tag": tag, - "dist": dist, - "T_total": T, - "n_seqs": n_seqs, - "accuracy": acc, - "ms_fla": ms_fla, - "ms_cula": ms_cula, - "speedup": speedup, - } - results.append(r) - - del q, k, v, g, beta, A_log, dt_bias, inputs, do, dht + tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min(seq_lens)}..{max(seq_lens)}] avg={T // n_seqs}" + + results.append({ + "tag": tag, "dist": dist, "T_total": T, "n_seqs": n_seqs, + "H": H, "HV": HV, + "accuracy": acc, "ms_fla": ms_fla, "ms_cula": ms_cula, "speedup": speedup, + }) + + del inputs, do, dht torch.cuda.empty_cache() return results @@ -465,15 +438,17 @@ def print_report(fixed_results, varlen_results): print(" BENCHMARK REPORT: chunk_kda forward+backward (E2E)") print(" cuLA CuTe DSL vs FLA Triton") print( - f" H={H} D={D} dtype=bf16 safe_gate=True phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}" + f" D={D} dtype=bf16 safe_gate=True phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}" ) + gva_note = f"GVA enabled (HV={HV} > H={H}, ratio={HV // H}x)" if HV > H else f"MHA (HV=H={H})" + print(f" {gva_note}") wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") print(f" Warmup={wu} Iters={ni}{mode_tag}") print(sep) - # Determine which accuracy keys to show + # Determine which accuracy keys to show (dq/dk present in e2e mode) if PHASE == "e2e": acc_keys = ["o", "ht", "dq", "dk", "dv", "dg", "dbeta", "dh0"] else: @@ -483,44 +458,39 @@ def print_report(fixed_results, varlen_results): if fixed_results: print("\n [Fixed-Length]") - print(f" {'─' * 125}") - - # Header - print(f" {'B':>3s} {'T':>5s} │ {'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s} │ {'':>10s}{acc_header}") - print(f" {'─' * 125}") - + print(f" {'─' * 130}") + print(f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " + f"{'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s} │ {'':>10s}{acc_header}") + print(f" {'─' * 130}") for r in fixed_results: - rel_max_vals = " ".join(f"{r['accuracy'].get(k, {}).get('rel_max', 0.0):10.6f}" for k in acc_keys) + gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no" + rel_max_vals = " ".join(f"{r['accuracy'].get(k, {}).get('rel_max', 0.0):10.6f}" for k in acc_keys) err_ratio_vals = " ".join(f"{r['accuracy'].get(k, {}).get('err_ratio', 0.0):10.6f}" for k in acc_keys) - # Line 1: timing + rel_max - print( - f" {r['B']:3d} {r['T']:5d} │ " - f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x │ " - f"{'rel_max:':>10s}{rel_max_vals}" - ) - # Line 2: err_ratio (no timing columns) - print(f" {'':3s} {'':5s} │ {'':9s} {'':11s} {'':8s} │ {'err_ratio:':>10s}{err_ratio_vals}") - print(f" {'─' * 125}") + prefix = f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " + blank = f" {'':3s} {'':6s} {'':3s} {'':3s} {'':4s} │ " + timing = f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x │ " + blank_t = f"{'':9s} {'':11s} {'':8s} │ " + print(f"{prefix}{timing}{'rel_max:':>10s}{rel_max_vals}") + print(f"{blank}{blank_t}{'err_ratio:':>10s}{err_ratio_vals}") + print(f" {'─' * 130}") if varlen_results: print("\n [Varlen]") - print(f" {'─' * 140}") - - print(f" {'Config':>45s} │ {'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s} │ {'':>10s}{acc_header}") - print(f" {'─' * 140}") - + print(f" {'─' * 145}") + print(f" {'Config':>45s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " + f"{'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s} │ {'':>10s}{acc_header}") + print(f" {'─' * 145}") for r in varlen_results: - rel_max_vals = " ".join(f"{r['accuracy'].get(k, {}).get('rel_max', 0.0):10.6f}" for k in acc_keys) + gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no" + rel_max_vals = " ".join(f"{r['accuracy'].get(k, {}).get('rel_max', 0.0):10.6f}" for k in acc_keys) err_ratio_vals = " ".join(f"{r['accuracy'].get(k, {}).get('err_ratio', 0.0):10.6f}" for k in acc_keys) - # Line 1: timing + rel_max - print( - f" {r['tag']:>45s} │ " - f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x │ " - f"{'rel_max:':>10s}{rel_max_vals}" - ) - # Line 2: err_ratio (no config/timing columns) - print(f" {'':>45s} │ {'':9s} {'':11s} {'':8s} │ {'err_ratio:':>10s}{err_ratio_vals}") - print(f" {'─' * 140}") + prefix = f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " + blank = f" {'':>45s} {'':3s} {'':3s} {'':4s} │ " + timing = f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x │ " + blank_t = f"{'':9s} {'':11s} {'':8s} │ " + print(f"{prefix}{timing}{'rel_max:':>10s}{rel_max_vals}") + print(f"{blank}{blank_t}{'err_ratio:':>10s}{err_ratio_vals}") + print(f" {'─' * 145}") print(f"\n{sep}\n") @@ -564,9 +534,15 @@ def main(): action="store_true", help="Run determinism check: verify cuLA produces identical outputs across repeated runs", ) + parser.add_argument( + "--hv", + type=int, + default=None, + help=f"Override number of V heads (HV). Default: H ({H}, no GVA). Set HV > H for GVA mode.", + ) args = parser.parse_args() - global NCU_MODE, SANITIZER_MODE, DISABLE_RECOMPUTE, PHASE + global NCU_MODE, SANITIZER_MODE, DISABLE_RECOMPUTE, PHASE, HV if args.ncu: NCU_MODE = True print("[NCU mode] warmup=1, iters=1") @@ -577,6 +553,12 @@ def main(): DISABLE_RECOMPUTE = True print("[Disable recompute] pre-compute QG in forward") PHASE = args.phase + if args.hv is not None: + if args.hv < H or args.hv % H != 0: + raise ValueError(f"--hv must be a positive multiple of H ({H}), got {args.hv}") + HV = args.hv + if HV > H: + print(f"[GVA] HV={HV} (H={H}, ratio={HV // H}x)") if args.check_determinism: det_configs = [(5, 1024), (10, 4096), (10, 8192), (10, 16384)] @@ -616,6 +598,7 @@ def main(): varlen_res = bench_varlen(varlen_configs) print_report(fixed_res, varlen_res) + return fixed_res, varlen_res diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 05198ab..fcd9f17 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -324,6 +324,93 @@ def prepare_safe_gate_inputs( ) +def prepare_safe_gate_inputs_gva( + batch_size, + T, + HQK, + HV, + D, + device, + cu_seqlens=None, + chunk_size=CHUNK_SIZE, + seed=SEED, + has_init_state=False, +): + """Prepare native GVA inputs for cuLA chunk_kda (q/k stay in HQK head space). + + Unlike ``prepare_safe_gate_inputs`` which expands q/k to HV heads, + this function returns q/k in the smaller HQK space so that cuLA can + exercise its native GVA path (HQK < HV). + + GVA layout: + q, k : (batch_flat, T, HQK, D) — l2-normalised, no grad + v : (batch_flat, T, HV, D) + g, beta : (batch_flat, T, HV, D / ∅) — after kda_gate_chunk_cumsum + A_log : (HV,) + dt_bias : (HV * D,) + init_state : (num_seqs, HV, D, D) or None + + The ``g`` tensor is the *pre-processed* chunk-cumsum gate, ready to be fed + directly into chunk_kda with ``use_gate_in_kernel=True``. + """ + assert HV > 0 and HQK > 0 and HV % HQK == 0, f"HV ({HV}) must be a positive multiple of HQK ({HQK})" + + dtype = torch.bfloat16 + scale = D ** (-0.5) + + set_seed(seed) + + q = torch.randn(batch_size, T, HQK, D, dtype=dtype, device=device).requires_grad_(False) + k = torch.randn(batch_size, T, HQK, D, dtype=dtype, device=device).requires_grad_(False) + v = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) + g = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) + beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid().requires_grad_(False) + + # l2-normalise q/k in the HQK head space (native GVA: one QK head per group). + q, _ = l2norm_fwd(q) + k, _ = l2norm_fwd(k) + + A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False) + dt_bias = torch.randn(HV * D, dtype=torch.float, device=device).requires_grad_(False) + + # Flatten to batch_size=1 for cu_seqlens compatibility. + if batch_size != 1: + q, k, v, g, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g, beta)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + + g = kda_gate_chunk_cumsum( + g=g, + A_log=A_log, + dt_bias=dt_bias, + scale=RCP_LN2, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=-5.0, + ) + + init_state = None + if has_init_state: + num_seqs = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch_size + init_state = torch.randn(num_seqs, HV, D, D, dtype=torch.float, device=device).requires_grad_(False) + + return dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + A_log=A_log, + dt_bias=dt_bias, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + init_state=init_state, + lower_bound=-5.0, + ) + + def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED): """Prepare preprocessed inputs ready for chunk_kda_fwd_intra. diff --git a/cula/kda/chunk.py b/cula/kda/chunk.py index b89c780..e0a9320 100644 --- a/cula/kda/chunk.py +++ b/cula/kda/chunk.py @@ -379,10 +379,17 @@ def chunk_kda( if not (-5 <= lower_bound < 0): raise ValueError(f"`lower_bound` must be in the safe range [-5, 0), got {lower_bound}.") - assert q.shape == k.shape == g.shape, "q, k, g must have the same shape." + B, T, HQK, K_dim = q.shape + HV = g.shape[2] + assert q.shape == k.shape, "q and k must have the same shape." + assert q.shape[:2] == g.shape[:2], "q/k and g must share batch and sequence dimensions." + assert HV % HQK == 0 and HV >= HQK, ( + f"g.shape[2] (HV={HV}) must be a positive multiple of q.shape[2] (HQK={HQK})." + ) + assert g.shape == (B, T, HV, K_dim), f"g must be [B,T,HV,K]=({B},{T},{HV},{K_dim}), got {tuple(g.shape)}." + assert beta.shape[:3] == (B, T, HV), f"beta must have shape [B,T,HV,...], got {tuple(beta.shape)}." + assert v.shape[:3] == (B, T, HV), f"v must have shape [B,T,HV,...], got {tuple(v.shape)}." assert k.shape[-1] <= 256, "Currently we only support key headdim <=256 for KDA :-(" - assert beta.shape == q.shape[:3], "beta must be of shape (batch size, seq len, num of head)." - assert v.shape == (*q.shape[:3], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)." assert q.dtype == k.dtype == v.dtype == torch.bfloat16, "q, k, v must be in bfloat16." assert beta.dtype == torch.bfloat16 or beta.dtype == torch.float32, "beta must be in bfloat16 or float32." assert q.shape[-1] == k.shape[-1] == v.shape[-1] == 128, "Currently we only support head dim of 128 for KDA" diff --git a/cula/kda/chunk_bwd.py b/cula/kda/chunk_bwd.py index 859b6be..de7074d 100644 --- a/cula/kda/chunk_bwd.py +++ b/cula/kda/chunk_bwd.py @@ -61,8 +61,6 @@ ) @triton.jit(do_not_specialize=["T"]) def chunk_kda_bwd_kernel_dAv( - q, - k, v, A, do, @@ -80,6 +78,7 @@ def chunk_kda_bwd_kernel_dAv( BV: tl.constexpr, IS_VARLEN: tl.constexpr, ): + # H here is HV (v-head count). q/k are not needed for this computation. i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: @@ -90,8 +89,6 @@ def chunk_kda_bwd_kernel_dAv( bos, eos = i_b * T, i_b * T + T # offset calculation - q += (bos * H + i_h) * K - k += (bos * H + i_h) * K v += (bos * H + i_h) * V do += (bos * H + i_h) * V dv += (bos * H + i_h) * V @@ -165,6 +162,7 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( scale, T, H: tl.constexpr, + HQK: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -173,8 +171,12 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( TRANSPOSE_STATE: tl.constexpr, IS_VARLEN: tl.constexpr, ): + # H = HV (v-head count); grid enumerates B * HV tile-pairs. + # HQK = qk-head count (HQK <= H; when HQK == H this is standard non-GVA). + # For each v-head i_h, the paired qk-head is i_hqk = i_h // (H // HQK). i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H + i_hqk = i_h // (H // HQK) if IS_VARLEN: i_tg = i_t.to(tl.int64) @@ -191,8 +193,9 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( m_t = o_t < T m_last = o_t == min(T, i_t * BT + BT) - 1 - q += (bos * H + i_h) * K - k += (bos * H + i_h) * K + # q/k/dq/dk use qk-head stride (HQK); all others use v-head stride (H = HV). + q += (bos * HQK + i_hqk) * K + k += (bos * HQK + i_hqk) * K v += (bos * H + i_h) * V v_new += (bos * H + i_h) * V g += (bos * H + i_h) * K @@ -201,8 +204,8 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( h += (i_tg * H + i_h) * K * V do += (bos * H + i_h) * V dh += (i_tg * H + i_h) * K * V - dq += (bos * H + i_h) * K - dk += (bos * H + i_h) * K + dq += (bos * HQK + i_hqk) * K + dk += (bos * HQK + i_hqk) * K dv += (bos * H + i_h) * V dv2 += (bos * H + i_h) * V dg += (bos * H + i_h) * K @@ -212,7 +215,7 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( p_beta = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_beta = tl.load(p_beta, boundary_check=(0,)) - p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) # H = HV b_A = tl.load(p_A, boundary_check=(0, 1)) b_dA = tl.zeros([BT, BT], dtype=tl.float32) @@ -222,7 +225,7 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K - p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (HQK * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) @@ -287,15 +290,15 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( b_dkgb = tl.dot(b_A, b_dw) b_db += tl.sum(b_dkgb * b_kg, 1) - p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (T, K), (HQK * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_kdk = b_k * b_dk b_dgk += tl.sum(b_kdk, axis=0) b_dg = b_q * b_dq - b_kdk + m_last[:, None] * b_dgk + b_kg * b_dkgb * b_beta[:, None] b_dk = b_dk + b_dkgb * b_gb - p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (HQK * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (HQK * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) @@ -307,8 +310,8 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) b_dA = tl.where(m_A, -b_dA, 0) - p_dA = tl.make_block_ptr(dA, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) - p_db = tl.make_block_ptr(db, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_dA = tl.make_block_ptr(dA, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) # H = HV + p_db = tl.make_block_ptr(db, (T,), (H,), (i_t * BT,), (BT,), (0,)) # H = HV tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) @@ -324,7 +327,10 @@ def chunk_kda_bwd_dAv( chunk_size: int = 64, chunk_indices: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - B, T, H, K, V = *k.shape, do.shape[-1] + # q/k are accepted for API compatibility but not forwarded to the kernel + # (they are unused in the dAv computation). + B, T, HV, V = v.shape[0], v.shape[1], v.shape[2], v.shape[3] + K = k.shape[-1] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) @@ -339,12 +345,11 @@ def chunk_kda_bwd_dAv( BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - dA = v.new_empty(B, T, H, BT, dtype=torch.float) + # dA and dv live in v-head space (HV), matching Aqk/v_new/do shapes. + dA = v.new_empty(B, T, HV, BT, dtype=torch.float) dv = torch.empty_like(do) - grid = (NT, B * H) + grid = (NT, B * HV) chunk_kda_bwd_kernel_dAv[grid]( - q=q, - k=k, v=v, A=A, do=do, @@ -354,7 +359,7 @@ def chunk_kda_bwd_dAv( chunk_indices=chunk_indices, scale=scale, T=T, - H=H, + H=HV, K=K, V=V, BT=BT, @@ -382,13 +387,16 @@ def chunk_kda_bwd_wy_dqkg_fused( chunk_indices: torch.LongTensor | None = None, transpose_state_layout: bool = False, ): - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, HQK, K = k.shape + HV = v.shape[2] + V = v.shape[-1] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + # dq/dk live in qk-head space (HQK); dv/dg/db/dA in v-head space (HV). dq = torch.empty_like(q, dtype=torch.float) dk = torch.empty_like(k, dtype=torch.float) dv2 = torch.empty_like(v) @@ -396,7 +404,7 @@ def chunk_kda_bwd_wy_dqkg_fused( db = torch.empty_like(beta, dtype=torch.float) dA = torch.empty_like(A, dtype=torch.float) - grid = (NT, B * H) + grid = (NT, B * HV) chunk_kda_bwd_kernel_wy_dqkg_fused[grid]( q=q, k=k, @@ -419,7 +427,8 @@ def chunk_kda_bwd_wy_dqkg_fused( chunk_indices=chunk_indices, scale=scale, T=T, - H=H, + H=HV, + HQK=HQK, K=K, V=V, BT=BT, @@ -457,7 +466,8 @@ def chunk_kda_bwd( ): assert transpose_state_layout is False, "transpose_state_layout=True is not supported for training." if disable_recompute is False: - B, T, _, _ = k.shape + B, T, HQK, K_dim = k.shape + HV = v.shape[2] if use_gate_in_kernel: g = kda_gate_chunk_cumsum( g=g_org, @@ -475,10 +485,11 @@ def chunk_kda_bwd( cu_seqlens = prepare_uniform_cu_seqlens(B, T, q.device, torch.int32) if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) - w = torch.empty_like(k) + # In GVA mode, w/u/qg/kg live in v-head space (HV), not qk-head space (HQK). + w = torch.empty(B, T, HV, K_dim, device=k.device, dtype=k.dtype) u = torch.empty_like(v) - qg = torch.empty_like(q) if q is not None else None - kg = torch.empty_like(k) if g is not None else None + qg = torch.empty(B, T, HV, K_dim, device=q.device, dtype=q.dtype) if q is not None else None + kg = torch.empty(B, T, HV, K_dim, device=k.device, dtype=k.dtype) if g is not None else None cula_cuda.recompute_w_u_cuda(k, v, beta, Akk, g, cu_seqlens, chunk_indices, w, u, kg, chunk_size, q, qg) if cp_context is not None: # Restore the full initial_state tensor from the compressed version. diff --git a/cula/kda/chunk_fwd.py b/cula/kda/chunk_fwd.py index 9fab235..8ebb782 100644 --- a/cula/kda/chunk_fwd.py +++ b/cula/kda/chunk_fwd.py @@ -126,10 +126,17 @@ def chunk_kda_fwd( # only the first state in the tensor is relevant. We compress it to optimize memory for `save_for_backward`. initial_state = compress_h0(initial_state, context=cp_context) + # GVA: if HQK < HV, broadcast q from HQK heads to HV heads so fwd_o sees + # consistent head dimensions. repeat_interleave mirrors the broadcast semantics + # already used inside chunk_kda_fwd_intra (each QK head paired with + # heads_per_group = HV // HQK consecutive V heads). + HQK, HV = q.shape[2], v.shape[2] + q_fwd_o = q.repeat_interleave(HV // HQK, dim=2) if HV != HQK else q + # Please ensure zeros, since vllm will use padding v o = torch.zeros_like(v) chunk_gla_fwd_o( - q=q, + q=q_fwd_o, v=v_new, g=g, A=Aqk, diff --git a/cula/kda/chunk_intra.py b/cula/kda/chunk_intra.py index 9fcc93a..48e1f38 100644 --- a/cula/kda/chunk_intra.py +++ b/cula/kda/chunk_intra.py @@ -383,6 +383,7 @@ def chunk_kda_bwd_kernel_intra( B, T, H: tl.constexpr, + HV: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, @@ -392,8 +393,12 @@ def chunk_kda_bwd_kernel_intra( SAFE_GATE: tl.constexpr, USE_GATHER: tl.constexpr, ): + # H = HQK (qk-head count); q/k/dq/dk use H stride. + # HV = v-head count; g/beta/dAqk/dAkk/dg/db use HV stride. + # Grid enumerates B * HV: i_h is the v-head index, i_hqk is the qk-head. i_kc, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_b, i_h = i_bh // H, i_bh % H + i_b, i_h = i_bh // HV, i_bh % HV + i_hqk = i_h // (HV // H) i_k, i_i = i_kc // NC, i_kc % NC all = B * T @@ -411,38 +416,38 @@ def chunk_kda_bwd_kernel_intra( o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K - q += (bos * H + i_h) * K - k += (bos * H + i_h) * K - g += (bos * H + i_h) * K - beta += bos * H + i_h - - dAqk += (bos * H + i_h) * BT - dAkk += (bos * H + i_h) * BT - dq += (bos * H + i_h) * K - dq2 += (bos * H + i_h) * K - dk += (bos * H + i_h) * K - dk2 += (bos * H + i_h) * K - dg += (bos * H + i_h) * K - dg2 += (bos * H + i_h) * K - db += (i_k * all + bos) * H + i_h - - p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + q += (bos * H + i_hqk) * K + k += (bos * H + i_hqk) * K + g += (bos * HV + i_h) * K + beta += bos * HV + i_h + + dAqk += (bos * HV + i_h) * BT + dAkk += (bos * HV + i_h) * BT + dq += (bos * H + i_hqk) * K + dq2 += (bos * H + i_hqk) * K + dk += (bos * H + i_hqk) * K + dk2 += (bos * H + i_hqk) * K + dg += (bos * HV + i_h) * K + dg2 += (bos * HV + i_h) * K + db += (i_k * all + bos) * HV + i_h + + p_g = tl.make_block_ptr(g, (T, K), (HV * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) - p_b = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) + p_b = tl.make_block_ptr(beta, (T,), (HV,), (i_ti,), (BC,), (0,)) b_b = tl.load(p_b, boundary_check=(0,)) b_dq2 = tl.zeros([BC, BK], dtype=tl.float32) b_dk2 = tl.zeros([BC, BK], dtype=tl.float32) if i_i > 0: - p_gn = g + i_ti * H * K + o_k + p_gn = g + i_ti * HV * K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] for i_j in range(0, i_i): p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (H * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) - p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (H * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (HV * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (HV * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (HV * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) # [BC, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) @@ -459,9 +464,9 @@ def chunk_kda_bwd_kernel_intra( o_i = tl.arange(0, BC) m_dA = (i_ti + o_i) < T - o_dA = (i_ti + o_i) * H * BT + i_i * BC + o_dA = (i_ti + o_i) * HV * BT + i_i * BC p_kj = k + i_ti * H * K + o_k - p_gkj = g + i_ti * H * K + o_k + p_gkj = g + i_ti * HV * K + o_k p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) @@ -475,8 +480,8 @@ def chunk_kda_bwd_kernel_intra( p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * H * K + o_k b_gn = tl.load(p_gn, mask=m_k, other=0)[None, :] - p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (H * BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) - p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (H * BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (HV * BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) + p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (HV * BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) b_dAqk_diag_qk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32) b_dAkk_diag_qk = tl.load(p_dAkk, boundary_check=(0, 1)).to(tl.float32) @@ -507,15 +512,15 @@ def chunk_kda_bwd_kernel_intra( b_dq2 += tl.where(m_i, b_dAqk[:, None] * b_kj[None, :] * b_gqk, 0.0) b_dk2 += tl.where(m_i, b_dAkk[:, None] * b_kj[None, :] * b_gqk, 0.0) - p_kj += H * K - p_gkj += H * K + p_kj += H * K # k stride: HQK + p_gkj += HV * K # g stride: HV b_db = tl.sum(b_dk2 * b_k, 1) b_dk2 *= b_b[:, None] - p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) - p_dq2 = tl.make_block_ptr(dq2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) - p_db = tl.make_block_ptr(db, (T,), (H,), (i_ti,), (BC,), (0,)) + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) # H = HQK + p_dq2 = tl.make_block_ptr(dq2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) # H = HQK + p_db = tl.make_block_ptr(db, (T,), (HV,), (i_ti,), (BC,), (0,)) b_dg2 = b_q * b_dq2 b_dq2 = b_dq2 + tl.load(p_dq, boundary_check=(0, 1)) @@ -527,16 +532,16 @@ def chunk_kda_bwd_kernel_intra( NC = min(NC, tl.cdiv(T - i_t * BT, BC)) if i_i < NC - 1: - p_gn = g + (min(i_ti + BC, T) - 1) * H * K + o_k + p_gn = g + (min(i_ti + BC, T) - 1) * HV * K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] for i_j in range(i_i + 1, NC): p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_b = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT + i_j * BC,), (BC,), (0,)) - p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, H * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) - p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, H * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_gk = tl.make_block_ptr(g, (T, K), (HV * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(beta, (T,), (HV,), (i_t * BT + i_j * BC,), (BC,), (0,)) + p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, HV * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, HV * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) # [BC] b_b = tl.load(p_b, boundary_check=(0,)) # [BC, BK] @@ -558,25 +563,25 @@ def chunk_kda_bwd_kernel_intra( b_dkt += tl.dot(b_dAqk, b_qg) b_dkt += tl.dot(b_dAkk, b_kbg) b_dkt *= exp2(b_gn - b_g) - o_dA = i_ti * H * BT + i_i * BC + o_i + o_dA = i_ti * HV * BT + i_i * BC + o_i p_qj = q + i_ti * H * K + o_k p_kj = k + i_ti * H * K + o_k - p_gkj = g + i_ti * H * K + o_k - p_bj = beta + i_ti * H + p_gkj = g + i_ti * HV * K + o_k + p_bj = beta + i_ti * HV if SAFE_GATE: if USE_GATHER: b_gn = gather(b_g, tl.full([1, BK], min(BC // 2, T - i_ti - 1), dtype=tl.int16), axis=0) else: - p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * H * K + o_k + p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * HV * K + o_k b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) - p_b = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) + p_b = tl.make_block_ptr(beta, (T,), (HV,), (i_ti,), (BC,), (0,)) b_b = tl.load(p_b, boundary_check=(0,)) - p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, H * BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) - p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, H * BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) + p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, HV * BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) + p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, HV * BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) b_dAqk_diag_kk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32) b_dAkk_diag_kk = tl.load(p_dAkk, boundary_check=(0, 1)).to(tl.float32) @@ -598,8 +603,8 @@ def chunk_kda_bwd_kernel_intra( else: for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC,] - b_dAqk = tl.load(dAqk + o_dA + j * H * BT) - b_dAkk = tl.load(dAkk + o_dA + j * H * BT) + b_dAqk = tl.load(dAqk + o_dA + j * HV * BT) + b_dAkk = tl.load(dAkk + o_dA + j * HV * BT) # [BK,] b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) b_kbj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) * tl.load(p_bj) @@ -610,14 +615,14 @@ def chunk_kda_bwd_kernel_intra( b_dkt += tl.where(m_i, b_dAqk[:, None] * b_qj[None, :] * b_gkq, 0.0) b_dkt += tl.where(m_i, b_dAkk[:, None] * b_kbj[None, :] * b_gkq, 0.0) - p_qj += H * K - p_kj += H * K - p_gkj += H * K - p_bj += H - p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) - p_dk2 = tl.make_block_ptr(dk2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) - p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) - p_dg2 = tl.make_block_ptr(dg2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_qj += H * K # k-head stride for q + p_kj += H * K # k-head stride for k + p_gkj += HV * K # v-head stride for g + p_bj += HV # v-head stride for beta + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) # H = HQK + p_dk2 = tl.make_block_ptr(dk2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) # H = HQK + p_dg = tl.make_block_ptr(dg, (T, K), (HV * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dg2 = tl.make_block_ptr(dg2, (T, K), (HV * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) b_dg2 += (b_dk2 - b_dkt) * b_k + tl.load(p_dg, boundary_check=(0, 1)) b_dk2 += tl.load(p_dk, boundary_check=(0, 1)) @@ -826,7 +831,9 @@ def chunk_kda_bwd_intra( chunk_size: int = 64, safe_gate: bool = False, ): - B, T, H, K = k.shape + B, T, HQK, K = k.shape + # g/beta/dAqk/dAkk/dg/db live in v-head space (HV). + HV = g.shape[2] BT = chunk_size BC = min(16, BT) BK = min(32, triton.next_power_of_2(K)) @@ -837,11 +844,12 @@ def chunk_kda_bwd_intra( NC = triton.cdiv(BT, BC) NK = triton.cdiv(K, BK) + # dq2/dk2 are in qk-head space; db2/dg2 are in v-head space. dq2 = torch.empty_like(q) dk2 = torch.empty_like(k) db2 = beta.new_empty(NK, *beta.shape, dtype=torch.float) dg2 = torch.empty_like(dg, dtype=torch.float) - grid = (NK * NC, NT, B * H) + grid = (NK * NC, NT, B * HV) chunk_kda_bwd_kernel_intra[grid]( q=q, k=k, @@ -860,7 +868,8 @@ def chunk_kda_bwd_intra( chunk_indices=chunk_indices, B=B, T=T, - H=H, + H=HQK, + HV=HV, K=K, BT=BT, BC=BC, diff --git a/tests/test_kda.py b/tests/test_kda.py index fadadbf..99b3ca1 100644 --- a/tests/test_kda.py +++ b/tests/test_kda.py @@ -25,6 +25,12 @@ from cula.kda import chunk_kda +# ─── helpers ────────────────────────────────────────────────────────────────── + +def _repeat_head(x: torch.Tensor, group_size: int) -> torch.Tensor: + """Replicate tensor along head dim (dim=2) by group_size, keeping contiguous layout.""" + return x.repeat_interleave(group_size, dim=2).contiguous() + pytestmark = pytest.mark.sm100_only @@ -281,3 +287,184 @@ def test_safe_gate_chunk_varlen( assert_close("dg", ref_dg, tri_dg, 0.015) assert_close("db", ref_db, tri_db, 0.015) assert_close("dh0", ref_dh0, tri_dh0, 0.007) + + +# ============================================================================= +# GVA (Grouped Value Attention) end-to-end tests +# ============================================================================= + +@pytest.mark.parametrize("disable_recompute", [True, False], ids=["no_recomp", "recomp"]) +@pytest.mark.parametrize("group_size", [2, 4], ids=["gs2", "gs4"]) +@pytest.mark.parametrize( + ("B", "T", "HQK", "D"), + [ + pytest.param(*cfg, id="B{}-T{}-HQK{}-D{}".format(*cfg)) + for cfg in [ + (1, 256, 2, 128), + (2, 512, 4, 128), + (1, 1000, 4, 128), # non-multiple-of-BT boundary stress + (2, 1024, 4, 128), + ] + ], +) +def test_chunk_kda_gva(B, T, HQK, D, group_size, disable_recompute): + """chunk_kda with native GVA (HQK < HV = HQK * group_size) must produce the + same forward outputs and the same v/g/beta/h0 gradients as running the + reference (FLA naive_recurrent_kda) with q/k expanded to HV heads. + + dq/dk gradients are also verified after summing over the group axis, + since the reference receives k replicated to HV and therefore accumulates + gradients across the group dimension. + """ + HV = HQK * group_size + torch.manual_seed(42) + + # ---- raw tensors -------------------------------------------------------- + q_raw = torch.randn(B, T, HQK, D, dtype=torch.bfloat16) + k_raw = torch.randn(B, T, HQK, D, dtype=torch.bfloat16) + v_raw = torch.randn(B, T, HV, D, dtype=torch.bfloat16) + # Gates must satisfy safe_gate: log-sigmoid values clamped to [-5, 0] + g_raw = F.logsigmoid(torch.randn(B, T, HV, D, dtype=torch.float)).clamp(-5.0, 0.0) + beta_raw = torch.randn(B, T, HV, dtype=torch.float32).sigmoid() + h0_raw = torch.randn(B, HV, D, D, dtype=torch.float32) + + # ---- reference (FLA naive_recurrent_kda with expanded q/k) -------------- + # Apply l2norm before expanding to make both paths numerically comparable. + q_norm = F.normalize(q_raw.float(), p=2, dim=-1).to(torch.bfloat16) + k_norm = F.normalize(k_raw.float(), p=2, dim=-1).to(torch.bfloat16) + + q_hv = _repeat_head(q_norm, group_size).to(device).requires_grad_(True) + k_hv = _repeat_head(k_norm, group_size).to(device).requires_grad_(True) + v_ref = v_raw.to(device).requires_grad_(True) + g_ref = g_raw.to(device).requires_grad_(True) + b_ref = beta_raw.to(device).requires_grad_(True) + h0_ref = h0_raw.to(device).requires_grad_(True) + + ref_o, ref_ht = naive_recurrent_kda( + q=q_hv, k=k_hv, v=v_ref, g=g_ref, beta=b_ref, + initial_state=h0_ref, output_final_state=True, + ) + do = torch.randn_like(ref_o) + dht = torch.randn_like(ref_ht) + ((ref_o * do).sum() + (ref_ht * dht).sum()).backward() + + ref_dq_hv = q_hv.grad # [B,T,HV,D] + ref_dk_hv = k_hv.grad # [B,T,HV,D] + ref_dv = v_ref.grad # [B,T,HV,D] + ref_dg = g_ref.grad # [B,T,HV,D] + ref_db = b_ref.grad # [B,T,HV] + ref_dh0 = h0_ref.grad # [B,HV,D,D] + + # Sum group contributions for dq/dk: [B,T,HV,D] → [B,T,HQK,D] + ref_dq = ref_dq_hv.view(B, T, HQK, group_size, D).sum(dim=3) + ref_dk = ref_dk_hv.view(B, T, HQK, group_size, D).sum(dim=3) + + # ---- cuLA chunk_kda with native GVA ------------------------------------ + q_c = q_norm.to(device).requires_grad_(True) + k_c = k_norm.to(device).requires_grad_(True) + v_c = v_raw.to(device).requires_grad_(True) + g_c = g_raw.to(device).requires_grad_(True) + b_c = beta_raw.to(device).requires_grad_(True) + h0_c = h0_raw.to(device).requires_grad_(True) + + tri_o, tri_ht = chunk_kda( + q=q_c, k=k_c, v=v_c, g=g_c, beta=b_c, + initial_state=h0_c, output_final_state=True, + safe_gate=True, lower_bound=-5.0, + disable_recompute=disable_recompute, + ) + ((tri_o * do).sum() + (tri_ht * dht).sum()).backward() + + # ---- compare ------------------------------------------------------------ + assert_close("o", ref_o, tri_o, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + assert_close("dq", ref_dq, q_c.grad, 0.01) + assert_close("dk", ref_dk, k_c.grad, 0.01) + assert_close("dv", ref_dv, v_c.grad, 0.008) + assert_close("dg", ref_dg, g_c.grad, 0.02) + assert_close("db", ref_db, b_c.grad, 0.02) + assert_close("dh0", ref_dh0, h0_c.grad, 0.008) + + +@pytest.mark.parametrize("disable_recompute", [True, False], ids=["no_recomp", "recomp"]) +@pytest.mark.parametrize("group_size", [2], ids=["gs2"]) +@pytest.mark.parametrize( + ("HQK", "D", "cu_seqlens"), + [ + pytest.param(*cfg, id="HQK{}-D{}-ns{}".format(cfg[0], cfg[1], len(cfg[2]) - 1)) + for cfg in [ + (2, 128, [0, 256, 500, 1000]), + (4, 128, [0, 100, 300, 1200, 2000]), + (4, 128, [0, 15, 100, 300, 1200, 3000, 4096]), + ] + ], +) +def test_chunk_kda_gva_varlen(HQK, D, cu_seqlens, group_size, disable_recompute): + """GVA chunk_kda correctness under variable-length (packed) inputs.""" + HV = HQK * group_size + torch.manual_seed(42) + + cu_seqlens_t = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + cu_seqlens_cpu = cu_seqlens_t.cpu() + T = int(cu_seqlens_t[-1].item()) + N = len(cu_seqlens) - 1 + + q_raw = torch.randn(1, T, HQK, D, dtype=torch.bfloat16) + k_raw = torch.randn(1, T, HQK, D, dtype=torch.bfloat16) + v_raw = torch.randn(1, T, HV, D, dtype=torch.bfloat16) + g_raw = F.logsigmoid(torch.randn(1, T, HV, D, dtype=torch.float)).clamp(-5.0, 0.0) + beta_raw = torch.randn(1, T, HV, dtype=torch.float32).sigmoid() + h0_raw = torch.randn(N, HV, D, D, dtype=torch.float32) + + q_norm = F.normalize(q_raw.float(), p=2, dim=-1).to(torch.bfloat16) + k_norm = F.normalize(k_raw.float(), p=2, dim=-1).to(torch.bfloat16) + + # ---- reference ---------------------------------------------------------- + q_hv = _repeat_head(q_norm, group_size).to(device).requires_grad_(True) + k_hv = _repeat_head(k_norm, group_size).to(device).requires_grad_(True) + v_ref = v_raw.to(device).requires_grad_(True) + g_ref = g_raw.to(device).requires_grad_(True) + b_ref = beta_raw.to(device).requires_grad_(True) + h0_ref = h0_raw.to(device).requires_grad_(True) + + ref_o, ref_ht = chunk_kda( + q=q_hv, k=k_hv, v=v_ref, g=g_ref, beta=b_ref, + initial_state=h0_ref, output_final_state=True, + cu_seqlens=cu_seqlens_t, cu_seqlens_cpu=cu_seqlens_cpu, + safe_gate=True, lower_bound=-5.0, + disable_recompute=disable_recompute, + ) + do = torch.randn_like(ref_o) + dht = torch.randn_like(ref_ht) + ((ref_o * do).sum() + (ref_ht * dht).sum()).backward() + + ref_dq = q_hv.grad.view(1, T, HQK, group_size, D).sum(dim=3) + ref_dk = k_hv.grad.view(1, T, HQK, group_size, D).sum(dim=3) + ref_dv, ref_dg, ref_db, ref_dh0 = v_ref.grad, g_ref.grad, b_ref.grad, h0_ref.grad + + # ---- cuLA native GVA ---------------------------------------------------- + q_c = q_norm.to(device).requires_grad_(True) + k_c = k_norm.to(device).requires_grad_(True) + v_c = v_raw.to(device).requires_grad_(True) + g_c = g_raw.to(device).requires_grad_(True) + b_c = beta_raw.to(device).requires_grad_(True) + h0_c = h0_raw.to(device).requires_grad_(True) + + tri_o, tri_ht = chunk_kda( + q=q_c, k=k_c, v=v_c, g=g_c, beta=b_c, + initial_state=h0_c, output_final_state=True, + cu_seqlens=cu_seqlens_t, cu_seqlens_cpu=cu_seqlens_cpu, + safe_gate=True, lower_bound=-5.0, + disable_recompute=disable_recompute, + ) + ((tri_o * do).sum() + (tri_ht * dht).sum()).backward() + + # ---- compare ------------------------------------------------------------ + assert_close("o", ref_o, tri_o, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + assert_close("dq", ref_dq, q_c.grad, 0.01) + assert_close("dk", ref_dk, k_c.grad, 0.01) + assert_close("dv", ref_dv, v_c.grad, 0.008) + assert_close("dg", ref_dg, g_c.grad, 0.02) + assert_close("db", ref_db, b_c.grad, 0.02) + assert_close("dh0", ref_dh0, h0_c.grad, 0.008) From 28c40110139cdda18b01ecfd814876402da495b4 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Tue, 19 May 2026 16:07:51 +0800 Subject: [PATCH 05/14] benchmark and test --- cula/kda/blackwell_fused_fwd.py | 80 ++++++++++++++++++++++++--------- cula/ops/kda_fully_fused_wip.py | 4 ++ 2 files changed, 63 insertions(+), 21 deletions(-) diff --git a/cula/kda/blackwell_fused_fwd.py b/cula/kda/blackwell_fused_fwd.py index 291d966..ea3b994 100644 --- a/cula/kda/blackwell_fused_fwd.py +++ b/cula/kda/blackwell_fused_fwd.py @@ -68,11 +68,22 @@ def forward( chunk_indices: torch.IntTensor | None = None, ): chunk_size = 64 - assert q.shape[-2] == v.shape[-2] == k.shape[-2], "Number of heads must be the same for q, k, v." + + # GVA: q/k are in HQK head space; v/g/beta/o/state are in HV head space. + # HV must be a positive multiple of HQK. When HV == HQK this is standard MHA. + assert q.shape == k.shape, "q and k must have the same shape." + assert q.shape[:2] == v.shape[:2] == g.shape[:2], "q, k, v, g must share (B, T) dimensions." + B, S, HQK, D = q.shape + HV = v.shape[2] + assert HQK > 0 and HV > 0 and HV % HQK == 0, ( + f"v/g head count (HV={HV}) must be a positive multiple of q/k head count (HQK={HQK})." + ) + assert g.shape == (B, S, HV, D), f"g must be [B,S,HV,D]=({B},{S},{HV},{D}), got {tuple(g.shape)}." + assert beta.shape[:3] == (B, S, HV), f"beta must have shape [B,S,HV,...], got {tuple(beta.shape)}." + heads_per_group = HV // HQK global compiled_kernel_cache - B, S, H, D = q.shape is_varlen = cu_seqlens is not None if is_varlen: assert B == 1, "For varlen, batch size must be 1. Flatten variable-length inputs first." @@ -120,20 +131,29 @@ def forward( q, q_rstd = l2norm_fwd(q) k, k_rstd = l2norm_fwd(k) + # GVA compat: the WIP CuTe kernel currently assumes a single head count for + # all operands. Until P1 (native GVA in the kernel), expand q/k from HQK to + # HV heads so the kernel sees consistent head dimensions everywhere. + # P1 will remove this expansion and thread HQK/HV through the kernel directly. + if heads_per_group > 1: + q = q.repeat_interleave(heads_per_group, dim=2).contiguous() + k = k.repeat_interleave(heads_per_group, dim=2).contiguous() + q_cute = from_dlpack(q.detach()) k_cute = from_dlpack(k.detach()) v_cute = from_dlpack(v.detach()) g_cute = from_dlpack(g.detach()) beta_cute = from_dlpack(beta.detach()) - # FIXME: support return final_states - o = torch.empty_like(q) + # Output lives in HV head space (same as v). + o = torch.empty_like(v) o_cute = from_dlpack(o.detach()) stream = cutlass_torch.default_stream() has_initial_state = initial_state is not None - cache_key = (has_initial_state, output_final_state, safe_gate, is_varlen, scale, chunk_size, D, USE_FAST_MATH) + # Include HQK and HV so that MHA and GVA compilations are cached separately. + cache_key = (has_initial_state, output_final_state, safe_gate, is_varlen, scale, chunk_size, HQK, HV, D, USE_FAST_MATH) # Prepare cu_seqlens as int32 for kernel if is_varlen: @@ -184,7 +204,7 @@ def forward( dc["workspace_cute"] = from_dlpack(ws_buf.detach()) workspace_cute = dc["workspace_cute"] - # State shape: [num_seqs, H, D, D] + # State shape: [num_seqs, HV, D, D] — state is per v-head under GVA. # Prepare initial_state and final_state tensors if has_initial_state: initial_state_f32 = initial_state.to(torch.float32).contiguous() @@ -195,15 +215,18 @@ def forward( initial_state_cute = _dummy_cache[q.device]["state_cute"] if output_final_state: - final_state_f32 = torch.zeros(num_seqs, H, D, D, dtype=torch.float32, device=q.device) + final_state_f32 = torch.zeros(num_seqs, HV, D, D, dtype=torch.float32, device=q.device) final_state_cute = from_dlpack(final_state_f32.detach()) else: # Use cached tiny dummy (pointer won't be dereferenced when output_final_state=False) final_state_f32 = None final_state_cute = _dummy_cache[q.device]["state_cute"] - # problem_size: (num_seqs, total_tokens_or_seq_len, H, D) - problem_size = (num_seqs, S, H, D) + # problem_size: (num_seqs, total_tokens_or_seq_len, HV, D) + # After the q/k expand above, all operands share HV as the head count, so the + # WIP kernel receives a consistent single head dimension. + # P1 will change this to pass (num_seqs, S, HQK, HV, D) for native GVA. + problem_size = (num_seqs, S, HV, D) if cache_key in compiled_kernel_cache: compiled_kernel = compiled_kernel_cache[cache_key] @@ -219,6 +242,10 @@ def forward( output_final_state=output_final_state, is_varlen=is_varlen, use_fast_math=USE_FAST_MATH, + # heads_per_group is stored for future P1 native-GVA kernel support. + # The current WIP kernel receives expanded q/k so heads_per_group == 1 + # from its perspective; this field is a no-op until P1. + heads_per_group=heads_per_group, ) compiled_kernel = cute.compile( attn_kernel, @@ -292,24 +319,24 @@ def flash_kda_prefill( r""" Args: q (torch.Tensor): - queries of shape `[B, T, H, K]`. + queries of shape `[B, T, HQK, K]`. k (torch.Tensor): - keys of shape `[B, T, H, K]`. + keys of shape `[B, T, HQK, K]`. v (torch.Tensor): - values of shape `[B, T, H, V]`. + values of shape `[B, T, HV, V]` where ``HV`` is a positive multiple of ``HQK``. g (torch.Tensor): - (forget) gating tensor (in log space!) of shape `[B, T, H, K]`. + (forget) gating tensor (in log space!) of shape `[B, T, HV, K]`. beta (torch.Tensor): - betas of shape `[B, T, H]`. + betas of shape `[B, T, HV]`. scale (Optional[float]): Scale factor for the KDA attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. initial_state (Optional[torch.Tensor]): - Initial state of shape `[N, H, K, V]` for `N` input sequences. + Initial state of shape `[N, HV, K, V]` for `N` input sequences. For equal-length input sequences, `N` equals the batch size `B`. Default: `None`. output_final_state (Optional[bool]): - Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`. use_qk_l2norm_in_kernel (bool): Whether to apply L2norm to the q,k tensor internally. Default: `False`. use_gate_in_kernel (bool): @@ -336,9 +363,9 @@ def flash_kda_prefill( Returns: o (torch.Tensor): - Outputs of shape `[B, T, H, V]`. + Outputs of shape `[B, T, HV, V]`. final_state (torch.Tensor): - Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + Final state of shape `[N, HV, K, V]` if `output_final_state=True` else `None`. """ assert_blackwell() # initial_state is now supported @@ -369,9 +396,20 @@ def flash_kda_prefill( if not (-5 <= lower_bound < 0): raise ValueError(f"`lower_bound` must be in the safe range [-5, 0), got {lower_bound}.") - assert q.shape == k.shape == g.shape, "q, k, g must have the same shape." - assert beta.shape == q.shape[:3], "beta must be of shape (batch size, seq len, num of head)." - assert v.shape == (*q.shape[:3], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)." + # GVA shape validation — mirrors chunk.py / hopper_fused_fwd.py. + B_outer, T_outer, HQK_outer, D_outer = q.shape + HV_outer = v.shape[2] + assert q.shape == k.shape, "q and k must have the same shape." + assert q.shape[:2] == v.shape[:2] == g.shape[:2], "q, k, v, g must share (B, T) dimensions." + assert HQK_outer > 0 and HV_outer > 0 and HV_outer % HQK_outer == 0, ( + f"v/g head count (HV={HV_outer}) must be a positive multiple of q/k head count (HQK={HQK_outer})." + ) + assert g.shape == (B_outer, T_outer, HV_outer, D_outer), ( + f"g must be [B,T,HV,D]=({B_outer},{T_outer},{HV_outer},{D_outer}), got {tuple(g.shape)}." + ) + assert beta.shape[:3] == (B_outer, T_outer, HV_outer), ( + f"beta must have shape [B,T,HV,...], got {tuple(beta.shape)}." + ) assert q.dtype == k.dtype == v.dtype == torch.bfloat16, "q, k, v must be in bfloat16." assert q.shape[-1] == k.shape[-1] == v.shape[-1] == 128, "Currently we only support head dim of 128 for KDA" if scale is None: diff --git a/cula/ops/kda_fully_fused_wip.py b/cula/ops/kda_fully_fused_wip.py index adcf809..c05c645 100644 --- a/cula/ops/kda_fully_fused_wip.py +++ b/cula/ops/kda_fully_fused_wip.py @@ -134,6 +134,9 @@ def __init__( num_regs_cuda: int = 224, num_regs_subchunk: int = 192, num_regs_others: int = 64, # Optimized: best config from comprehensive sweep + # GVA: ratio of v-heads to qk-heads. 1 = standard MHA. + # P1 will thread this through the kernel tile scheduler and memory layouts. + heads_per_group: int = 1, ): assert_blackwell() # make scale a constant @@ -143,6 +146,7 @@ def __init__( self.output_final_state = output_final_state self.is_varlen = is_varlen self.use_fast_math = use_fast_math + self.heads_per_group = heads_per_group self.chunk_size = chunk_size self.subchunk_size = 16 From f8edbd1f30aac67cfaffe65a75d14ee3768a436f Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Tue, 19 May 2026 16:31:28 +0800 Subject: [PATCH 06/14] benchmark and test --- cula/ops/kda_fully_fused_wip.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cula/ops/kda_fully_fused_wip.py b/cula/ops/kda_fully_fused_wip.py index c05c645..adcf809 100644 --- a/cula/ops/kda_fully_fused_wip.py +++ b/cula/ops/kda_fully_fused_wip.py @@ -134,9 +134,6 @@ def __init__( num_regs_cuda: int = 224, num_regs_subchunk: int = 192, num_regs_others: int = 64, # Optimized: best config from comprehensive sweep - # GVA: ratio of v-heads to qk-heads. 1 = standard MHA. - # P1 will thread this through the kernel tile scheduler and memory layouts. - heads_per_group: int = 1, ): assert_blackwell() # make scale a constant @@ -146,7 +143,6 @@ def __init__( self.output_final_state = output_final_state self.is_varlen = is_varlen self.use_fast_math = use_fast_math - self.heads_per_group = heads_per_group self.chunk_size = chunk_size self.subchunk_size = 16 From 485774cb3ea167ad4e89768a78934a67ea7b55b3 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Tue, 19 May 2026 16:50:54 +0800 Subject: [PATCH 07/14] benchmark and test --- cula/kda/blackwell_fused_fwd.py | 80 +++++++++------------------------ 1 file changed, 21 insertions(+), 59 deletions(-) diff --git a/cula/kda/blackwell_fused_fwd.py b/cula/kda/blackwell_fused_fwd.py index ea3b994..291d966 100644 --- a/cula/kda/blackwell_fused_fwd.py +++ b/cula/kda/blackwell_fused_fwd.py @@ -68,22 +68,11 @@ def forward( chunk_indices: torch.IntTensor | None = None, ): chunk_size = 64 - - # GVA: q/k are in HQK head space; v/g/beta/o/state are in HV head space. - # HV must be a positive multiple of HQK. When HV == HQK this is standard MHA. - assert q.shape == k.shape, "q and k must have the same shape." - assert q.shape[:2] == v.shape[:2] == g.shape[:2], "q, k, v, g must share (B, T) dimensions." - B, S, HQK, D = q.shape - HV = v.shape[2] - assert HQK > 0 and HV > 0 and HV % HQK == 0, ( - f"v/g head count (HV={HV}) must be a positive multiple of q/k head count (HQK={HQK})." - ) - assert g.shape == (B, S, HV, D), f"g must be [B,S,HV,D]=({B},{S},{HV},{D}), got {tuple(g.shape)}." - assert beta.shape[:3] == (B, S, HV), f"beta must have shape [B,S,HV,...], got {tuple(beta.shape)}." - heads_per_group = HV // HQK + assert q.shape[-2] == v.shape[-2] == k.shape[-2], "Number of heads must be the same for q, k, v." global compiled_kernel_cache + B, S, H, D = q.shape is_varlen = cu_seqlens is not None if is_varlen: assert B == 1, "For varlen, batch size must be 1. Flatten variable-length inputs first." @@ -131,29 +120,20 @@ def forward( q, q_rstd = l2norm_fwd(q) k, k_rstd = l2norm_fwd(k) - # GVA compat: the WIP CuTe kernel currently assumes a single head count for - # all operands. Until P1 (native GVA in the kernel), expand q/k from HQK to - # HV heads so the kernel sees consistent head dimensions everywhere. - # P1 will remove this expansion and thread HQK/HV through the kernel directly. - if heads_per_group > 1: - q = q.repeat_interleave(heads_per_group, dim=2).contiguous() - k = k.repeat_interleave(heads_per_group, dim=2).contiguous() - q_cute = from_dlpack(q.detach()) k_cute = from_dlpack(k.detach()) v_cute = from_dlpack(v.detach()) g_cute = from_dlpack(g.detach()) beta_cute = from_dlpack(beta.detach()) - # Output lives in HV head space (same as v). - o = torch.empty_like(v) + # FIXME: support return final_states + o = torch.empty_like(q) o_cute = from_dlpack(o.detach()) stream = cutlass_torch.default_stream() has_initial_state = initial_state is not None - # Include HQK and HV so that MHA and GVA compilations are cached separately. - cache_key = (has_initial_state, output_final_state, safe_gate, is_varlen, scale, chunk_size, HQK, HV, D, USE_FAST_MATH) + cache_key = (has_initial_state, output_final_state, safe_gate, is_varlen, scale, chunk_size, D, USE_FAST_MATH) # Prepare cu_seqlens as int32 for kernel if is_varlen: @@ -204,7 +184,7 @@ def forward( dc["workspace_cute"] = from_dlpack(ws_buf.detach()) workspace_cute = dc["workspace_cute"] - # State shape: [num_seqs, HV, D, D] — state is per v-head under GVA. + # State shape: [num_seqs, H, D, D] # Prepare initial_state and final_state tensors if has_initial_state: initial_state_f32 = initial_state.to(torch.float32).contiguous() @@ -215,18 +195,15 @@ def forward( initial_state_cute = _dummy_cache[q.device]["state_cute"] if output_final_state: - final_state_f32 = torch.zeros(num_seqs, HV, D, D, dtype=torch.float32, device=q.device) + final_state_f32 = torch.zeros(num_seqs, H, D, D, dtype=torch.float32, device=q.device) final_state_cute = from_dlpack(final_state_f32.detach()) else: # Use cached tiny dummy (pointer won't be dereferenced when output_final_state=False) final_state_f32 = None final_state_cute = _dummy_cache[q.device]["state_cute"] - # problem_size: (num_seqs, total_tokens_or_seq_len, HV, D) - # After the q/k expand above, all operands share HV as the head count, so the - # WIP kernel receives a consistent single head dimension. - # P1 will change this to pass (num_seqs, S, HQK, HV, D) for native GVA. - problem_size = (num_seqs, S, HV, D) + # problem_size: (num_seqs, total_tokens_or_seq_len, H, D) + problem_size = (num_seqs, S, H, D) if cache_key in compiled_kernel_cache: compiled_kernel = compiled_kernel_cache[cache_key] @@ -242,10 +219,6 @@ def forward( output_final_state=output_final_state, is_varlen=is_varlen, use_fast_math=USE_FAST_MATH, - # heads_per_group is stored for future P1 native-GVA kernel support. - # The current WIP kernel receives expanded q/k so heads_per_group == 1 - # from its perspective; this field is a no-op until P1. - heads_per_group=heads_per_group, ) compiled_kernel = cute.compile( attn_kernel, @@ -319,24 +292,24 @@ def flash_kda_prefill( r""" Args: q (torch.Tensor): - queries of shape `[B, T, HQK, K]`. + queries of shape `[B, T, H, K]`. k (torch.Tensor): - keys of shape `[B, T, HQK, K]`. + keys of shape `[B, T, H, K]`. v (torch.Tensor): - values of shape `[B, T, HV, V]` where ``HV`` is a positive multiple of ``HQK``. + values of shape `[B, T, H, V]`. g (torch.Tensor): - (forget) gating tensor (in log space!) of shape `[B, T, HV, K]`. + (forget) gating tensor (in log space!) of shape `[B, T, H, K]`. beta (torch.Tensor): - betas of shape `[B, T, HV]`. + betas of shape `[B, T, H]`. scale (Optional[float]): Scale factor for the KDA attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. initial_state (Optional[torch.Tensor]): - Initial state of shape `[N, HV, K, V]` for `N` input sequences. + Initial state of shape `[N, H, K, V]` for `N` input sequences. For equal-length input sequences, `N` equals the batch size `B`. Default: `None`. output_final_state (Optional[bool]): - Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`. + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. use_qk_l2norm_in_kernel (bool): Whether to apply L2norm to the q,k tensor internally. Default: `False`. use_gate_in_kernel (bool): @@ -363,9 +336,9 @@ def flash_kda_prefill( Returns: o (torch.Tensor): - Outputs of shape `[B, T, HV, V]`. + Outputs of shape `[B, T, H, V]`. final_state (torch.Tensor): - Final state of shape `[N, HV, K, V]` if `output_final_state=True` else `None`. + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. """ assert_blackwell() # initial_state is now supported @@ -396,20 +369,9 @@ def flash_kda_prefill( if not (-5 <= lower_bound < 0): raise ValueError(f"`lower_bound` must be in the safe range [-5, 0), got {lower_bound}.") - # GVA shape validation — mirrors chunk.py / hopper_fused_fwd.py. - B_outer, T_outer, HQK_outer, D_outer = q.shape - HV_outer = v.shape[2] - assert q.shape == k.shape, "q and k must have the same shape." - assert q.shape[:2] == v.shape[:2] == g.shape[:2], "q, k, v, g must share (B, T) dimensions." - assert HQK_outer > 0 and HV_outer > 0 and HV_outer % HQK_outer == 0, ( - f"v/g head count (HV={HV_outer}) must be a positive multiple of q/k head count (HQK={HQK_outer})." - ) - assert g.shape == (B_outer, T_outer, HV_outer, D_outer), ( - f"g must be [B,T,HV,D]=({B_outer},{T_outer},{HV_outer},{D_outer}), got {tuple(g.shape)}." - ) - assert beta.shape[:3] == (B_outer, T_outer, HV_outer), ( - f"beta must have shape [B,T,HV,...], got {tuple(beta.shape)}." - ) + assert q.shape == k.shape == g.shape, "q, k, g must have the same shape." + assert beta.shape == q.shape[:3], "beta must be of shape (batch size, seq len, num of head)." + assert v.shape == (*q.shape[:3], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)." assert q.dtype == k.dtype == v.dtype == torch.bfloat16, "q, k, v must be in bfloat16." assert q.shape[-1] == k.shape[-1] == v.shape[-1] == 128, "Currently we only support head dim of 128 for KDA" if scale is None: From c6a492bede8df953b65f56a93a343ade07300b66 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Tue, 19 May 2026 19:55:47 +0800 Subject: [PATCH 08/14] benchmark and test --- benchmarks/bench_recompute_wu.py | 241 ++++++++++++++++++++++++++++++- benchmarks/utils.py | 87 ----------- cula/kda/chunk_intra.py | 152 ++++++++----------- 3 files changed, 300 insertions(+), 180 deletions(-) diff --git a/benchmarks/bench_recompute_wu.py b/benchmarks/bench_recompute_wu.py index c7d0873..c713a15 100644 --- a/benchmarks/bench_recompute_wu.py +++ b/benchmarks/bench_recompute_wu.py @@ -27,7 +27,8 @@ from fla.ops.kda.wy_fast import recompute_w_u_fwd as fla_recompute_w_u_fwd import cula.cudac as cula_cuda -from benchmarks.utils import SEED, exclusive_cumsum, generate_random_seq_lens, prepare_intra_inputs +from benchmarks.utils import SEED, exclusive_cumsum, generate_random_seq_lens, prepare_intra_inputs, prepare_intra_inputs_gva +from cula.kda.chunk_intra import chunk_kda_fwd_intra as cula_chunk_kda_fwd_intra # Constant params B, H, D = 2, 64, 128 @@ -40,6 +41,7 @@ VARIANCE = 1.0 DISABLE_RECOMPUTE = False # Whether to disable recompute (compute QG in forward) +GROUP_SIZE = 1 # GVA group size: HV = GROUP_SIZE * H. 1 means no GVA. def accuracy_stats(a, b): @@ -97,7 +99,7 @@ def run_fla_recompute_wu(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, disa def run_cula_recompute_wu(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, chunk_size, disable_recompute): - """Run cuLA recompute_w_u_cuda.""" + """Run cuLA recompute_w_u_cuda (MHA: all tensors share the same head dim).""" w = torch.empty_like(k) u = torch.empty_like(v) qg = torch.empty_like(q) if disable_recompute else None @@ -109,6 +111,76 @@ def run_cula_recompute_wu(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, chu return w, u, qg, kg +# ============================================================================== +# GVA helpers +# ============================================================================== + +def prepare_recompute_wu_inputs_gva(B, T, HQK, HV, D, device, cu_seqlens=None, chunk_size=BT): + """Prepare GVA inputs for recompute_w_u benchmarking. + + Produces Akk via cuLA's GVA-aware chunk_kda_fwd_intra so the tensor lives in + HV-head space (shape [B, T, HV, BT]). Both FLA (k replicated to HV) and cuLA + (k compact in HQK) receive the same Akk. + """ + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs_gva( + B, T, HQK, HV, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size + ) + + # Use cuLA GVA intra to produce Akk in HV space. + _, _, _, _, _, Akk = cula_chunk_kda_fwd_intra( + q=q, + k=k, + v=v, + gk=g, + beta=beta, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + chunk_indices=chunk_indices, + safe_gate=True, + disable_recompute=False, + ) + + return q, k, v, g, beta, Akk, cu_seqlens, chunk_indices + + +def run_fla_recompute_wu_gva(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, disable_recompute, group_size): + """FLA reference for GVA recompute_w_u. + + FLA does not natively support GVA, so k and q are replicated to HV heads via + repeat_interleave before the call — mirroring the strategy in bench_kda_chunk_intra.py. + """ + k_hv = k.repeat_interleave(group_size, dim=2).contiguous() + q_hv = q.repeat_interleave(group_size, dim=2).contiguous() + return fla_recompute_w_u_fwd( + k=k_hv, + v=v, + beta=beta, + A=Akk, + q=q_hv if disable_recompute else None, + gk=gk, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) + + +def run_cula_recompute_wu_gva(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, chunk_size, disable_recompute): + """Run cuLA recompute_w_u_cuda with GVA layout. + + k/q live in HQK head space; v/gk/beta/Akk/w/u/kg/qg all live in HV head space. + """ + B_flat, T, HV, Dv = v.shape + w = torch.empty(B_flat, T, HV, Dv, device=k.device, dtype=k.dtype) + u = torch.empty_like(v) + qg = torch.empty(B_flat, T, HV, Dv, device=q.device, dtype=q.dtype) if disable_recompute else None + kg = torch.empty(B_flat, T, HV, Dv, device=k.device, dtype=k.dtype) if gk is not None else None + + cula_cuda.recompute_w_u_cuda( + k, v, beta, Akk, gk, cu_seqlens, chunk_indices, w, u, kg, chunk_size, q if disable_recompute else None, qg + ) + return w, u, qg, kg + + # ============================================================================== # Uniform seqlen benchmark # ============================================================================== @@ -241,6 +313,154 @@ def benchmark_recompute_wu_varlen(): print("─" * 100) +# ============================================================================== +# GVA uniform seqlen benchmark +# ============================================================================== +def benchmark_recompute_wu_gva_uniform(group_size: int): + """Benchmark GVA (HV > HQK) recompute_w_u: cuLA vs FLA Triton (k replicated to HV). + + FLA does not natively support GVA, so the reference replicates k/q along the + head axis to HV before calling recompute_w_u_fwd. + """ + device = torch.device("cuda") + chunk_size = BT + HQK = H + HV = HQK * group_size + T_vals = [512, 1024, 4096, 8192, 16384, 32768] + + print("=" * 100) + print( + f" GVA Uniform RecomputeWU Benchmark: cuLA vs FLA Triton " + f"B={B} HQK={HQK} HV={HV} (group_size={group_size}) D={D} disable_recompute={DISABLE_RECOMPUTE}" + ) + print("=" * 100) + print( + f"{'B':>4} {'T':>7} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" + ) + print("─" * 100) + + for T in T_vals: + seq_lens = [T] * B + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + q, k, v, g, beta, Akk, cu_seqlens, chunk_indices = prepare_recompute_wu_inputs_gva( + B, T, HQK, HV, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size + ) + + # Accuracy: run once and compare + w_fla, u_fla, qg_fla, kg_fla = run_fla_recompute_wu_gva( + k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, DISABLE_RECOMPUTE, group_size + ) + w_cula, u_cula, qg_cula, kg_cula = run_cula_recompute_wu_gva( + k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE + ) + + stats = {} + for name, t_fla, t_cula in [ + ("w", w_fla, w_cula), + ("u", u_fla, u_cula), + ("qg", qg_fla, qg_cula), + ("kg", kg_fla, kg_cula), + ]: + if t_fla is not None and t_cula is not None: + stats[name] = accuracy_stats(t_fla, t_cula) + rmse = max(s[0] for s in stats.values()) + rel_max = max(s[1] for s in stats.values()) + mean_diff = max(s[2] for s in stats.values()) + + # Performance + ms_fla = triton.testing.do_bench( + lambda: run_fla_recompute_wu_gva( + k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, DISABLE_RECOMPUTE, group_size + ), + ) + ms_cula = triton.testing.do_bench( + lambda: run_cula_recompute_wu_gva( + k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE + ), + ) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + + print( + f"{B:>4} {T:>7} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" + ) + + print("─" * 100) + + +# ============================================================================== +# GVA varlen benchmark +# ============================================================================== +def benchmark_recompute_wu_gva_varlen(group_size: int): + """Varlen GVA benchmark for recompute_w_u: cuLA vs FLA Triton (k replicated to HV).""" + device = torch.device("cuda") + chunk_size = BT + HQK = H + HV = HQK * group_size + total_len_vals = [8192, 16384, 32768, 65536] + + print() + print("=" * 110) + print( + f" GVA Varlen RecomputeWU Benchmark: cuLA vs FLA Triton " + f"NUM_SEQS={NUM_SEQS} HQK={HQK} HV={HV} (group_size={group_size}) D={D} disable_recompute={DISABLE_RECOMPUTE}" + ) + print("=" * 110) + print( + f"{'total_len':>10} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" + ) + print("─" * 110) + + for total_len in total_len_vals: + seq_lens = generate_random_seq_lens(NUM_SEQS, total_len, MIN_SEQ_LEN, VARIANCE, SEED) + T = total_len + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + q, k, v, g, beta, Akk, cu_seqlens, chunk_indices = prepare_recompute_wu_inputs_gva( + 1, T, HQK, HV, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size + ) + + # Accuracy + w_fla, u_fla, qg_fla, kg_fla = run_fla_recompute_wu_gva( + k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, DISABLE_RECOMPUTE, group_size + ) + w_cula, u_cula, qg_cula, kg_cula = run_cula_recompute_wu_gva( + k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE + ) + + stats = {} + for name, t_fla, t_cula in [ + ("w", w_fla, w_cula), + ("u", u_fla, u_cula), + ("qg", qg_fla, qg_cula), + ("kg", kg_fla, kg_cula), + ]: + if t_fla is not None and t_cula is not None: + stats[name] = accuracy_stats(t_fla, t_cula) + rmse = max(s[0] for s in stats.values()) + rel_max = max(s[1] for s in stats.values()) + mean_diff = max(s[2] for s in stats.values()) + + # Performance + ms_fla = triton.testing.do_bench( + lambda: run_fla_recompute_wu_gva( + k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, DISABLE_RECOMPUTE, group_size + ), + ) + ms_cula = triton.testing.do_bench( + lambda: run_cula_recompute_wu_gva( + k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE + ), + ) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + + print( + f"{total_len:>10} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" + ) + + print("─" * 110) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="bench_recompute_wu: cuLA vs FLA Triton for recompute_w_u") parser.add_argument( @@ -248,11 +468,24 @@ def benchmark_recompute_wu_varlen(): action="store_true", help="Disable recompute in both FLA and cuLA (pre-compute QG)", ) + parser.add_argument( + "--group_size", + type=int, + default=1, + help="GVA group size: HV = group_size * H. 1 (default) runs the non-GVA benchmark. " + "Values > 1 run GVA benchmarks comparing cuLA (k in HQK space) vs FLA (k replicated to HV).", + ) args = parser.parse_args() if args.disable_recompute: DISABLE_RECOMPUTE = True print("[Disable recompute] pre-compute QG in forward") - benchmark_recompute_wu_uniform() - benchmark_recompute_wu_varlen() + GROUP_SIZE = args.group_size + + if GROUP_SIZE == 1: + benchmark_recompute_wu_uniform() + benchmark_recompute_wu_varlen() + else: + benchmark_recompute_wu_gva_uniform(GROUP_SIZE) + benchmark_recompute_wu_gva_varlen(GROUP_SIZE) diff --git a/benchmarks/utils.py b/benchmarks/utils.py index fcd9f17..05198ab 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -324,93 +324,6 @@ def prepare_safe_gate_inputs( ) -def prepare_safe_gate_inputs_gva( - batch_size, - T, - HQK, - HV, - D, - device, - cu_seqlens=None, - chunk_size=CHUNK_SIZE, - seed=SEED, - has_init_state=False, -): - """Prepare native GVA inputs for cuLA chunk_kda (q/k stay in HQK head space). - - Unlike ``prepare_safe_gate_inputs`` which expands q/k to HV heads, - this function returns q/k in the smaller HQK space so that cuLA can - exercise its native GVA path (HQK < HV). - - GVA layout: - q, k : (batch_flat, T, HQK, D) — l2-normalised, no grad - v : (batch_flat, T, HV, D) - g, beta : (batch_flat, T, HV, D / ∅) — after kda_gate_chunk_cumsum - A_log : (HV,) - dt_bias : (HV * D,) - init_state : (num_seqs, HV, D, D) or None - - The ``g`` tensor is the *pre-processed* chunk-cumsum gate, ready to be fed - directly into chunk_kda with ``use_gate_in_kernel=True``. - """ - assert HV > 0 and HQK > 0 and HV % HQK == 0, f"HV ({HV}) must be a positive multiple of HQK ({HQK})" - - dtype = torch.bfloat16 - scale = D ** (-0.5) - - set_seed(seed) - - q = torch.randn(batch_size, T, HQK, D, dtype=dtype, device=device).requires_grad_(False) - k = torch.randn(batch_size, T, HQK, D, dtype=dtype, device=device).requires_grad_(False) - v = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) - g = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) - beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid().requires_grad_(False) - - # l2-normalise q/k in the HQK head space (native GVA: one QK head per group). - q, _ = l2norm_fwd(q) - k, _ = l2norm_fwd(k) - - A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False) - dt_bias = torch.randn(HV * D, dtype=torch.float, device=device).requires_grad_(False) - - # Flatten to batch_size=1 for cu_seqlens compatibility. - if batch_size != 1: - q, k, v, g, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g, beta)) - - chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None - - g = kda_gate_chunk_cumsum( - g=g, - A_log=A_log, - dt_bias=dt_bias, - scale=RCP_LN2, - chunk_size=chunk_size, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - lower_bound=-5.0, - ) - - init_state = None - if has_init_state: - num_seqs = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch_size - init_state = torch.randn(num_seqs, HV, D, D, dtype=torch.float, device=device).requires_grad_(False) - - return dict( - q=q, - k=k, - v=v, - g=g, - beta=beta, - A_log=A_log, - dt_bias=dt_bias, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - init_state=init_state, - lower_bound=-5.0, - ) - - def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED): """Prepare preprocessed inputs ready for chunk_kda_fwd_intra. diff --git a/cula/kda/chunk_intra.py b/cula/kda/chunk_intra.py index 48e1f38..a9c620a 100644 --- a/cula/kda/chunk_intra.py +++ b/cula/kda/chunk_intra.py @@ -383,7 +383,6 @@ def chunk_kda_bwd_kernel_intra( B, T, H: tl.constexpr, - HV: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, @@ -393,12 +392,8 @@ def chunk_kda_bwd_kernel_intra( SAFE_GATE: tl.constexpr, USE_GATHER: tl.constexpr, ): - # H = HQK (qk-head count); q/k/dq/dk use H stride. - # HV = v-head count; g/beta/dAqk/dAkk/dg/db use HV stride. - # Grid enumerates B * HV: i_h is the v-head index, i_hqk is the qk-head. i_kc, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_b, i_h = i_bh // HV, i_bh % HV - i_hqk = i_h // (HV // H) + i_b, i_h = i_bh // H, i_bh % H i_k, i_i = i_kc // NC, i_kc % NC all = B * T @@ -416,38 +411,38 @@ def chunk_kda_bwd_kernel_intra( o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K - q += (bos * H + i_hqk) * K - k += (bos * H + i_hqk) * K - g += (bos * HV + i_h) * K - beta += bos * HV + i_h - - dAqk += (bos * HV + i_h) * BT - dAkk += (bos * HV + i_h) * BT - dq += (bos * H + i_hqk) * K - dq2 += (bos * H + i_hqk) * K - dk += (bos * H + i_hqk) * K - dk2 += (bos * H + i_hqk) * K - dg += (bos * HV + i_h) * K - dg2 += (bos * HV + i_h) * K - db += (i_k * all + bos) * HV + i_h - - p_g = tl.make_block_ptr(g, (T, K), (HV * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + beta += bos * H + i_h + + dAqk += (bos * H + i_h) * BT + dAkk += (bos * H + i_h) * BT + dq += (bos * H + i_h) * K + dq2 += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dk2 += (bos * H + i_h) * K + dg += (bos * H + i_h) * K + dg2 += (bos * H + i_h) * K + db += (i_k * all + bos) * H + i_h + + p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) - p_b = tl.make_block_ptr(beta, (T,), (HV,), (i_ti,), (BC,), (0,)) + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) b_b = tl.load(p_b, boundary_check=(0,)) b_dq2 = tl.zeros([BC, BK], dtype=tl.float32) b_dk2 = tl.zeros([BC, BK], dtype=tl.float32) if i_i > 0: - p_gn = g + i_ti * HV * K + o_k + p_gn = g + i_ti * H * K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] for i_j in range(0, i_i): p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g, (T, K), (HV * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (HV * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) - p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (HV * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (H * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (H * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) # [BC, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) @@ -464,9 +459,9 @@ def chunk_kda_bwd_kernel_intra( o_i = tl.arange(0, BC) m_dA = (i_ti + o_i) < T - o_dA = (i_ti + o_i) * HV * BT + i_i * BC + o_dA = (i_ti + o_i) * H * BT + i_i * BC p_kj = k + i_ti * H * K + o_k - p_gkj = g + i_ti * HV * K + o_k + p_gkj = g + i_ti * H * K + o_k p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) @@ -480,8 +475,8 @@ def chunk_kda_bwd_kernel_intra( p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * H * K + o_k b_gn = tl.load(p_gn, mask=m_k, other=0)[None, :] - p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (HV * BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) - p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (HV * BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (H * BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) + p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (H * BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) b_dAqk_diag_qk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32) b_dAkk_diag_qk = tl.load(p_dAkk, boundary_check=(0, 1)).to(tl.float32) @@ -512,15 +507,15 @@ def chunk_kda_bwd_kernel_intra( b_dq2 += tl.where(m_i, b_dAqk[:, None] * b_kj[None, :] * b_gqk, 0.0) b_dk2 += tl.where(m_i, b_dAkk[:, None] * b_kj[None, :] * b_gqk, 0.0) - p_kj += H * K # k stride: HQK - p_gkj += HV * K # g stride: HV + p_kj += H * K + p_gkj += H * K b_db = tl.sum(b_dk2 * b_k, 1) b_dk2 *= b_b[:, None] - p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) # H = HQK - p_dq2 = tl.make_block_ptr(dq2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) # H = HQK - p_db = tl.make_block_ptr(db, (T,), (HV,), (i_ti,), (BC,), (0,)) + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dq2 = tl.make_block_ptr(dq2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T,), (H,), (i_ti,), (BC,), (0,)) b_dg2 = b_q * b_dq2 b_dq2 = b_dq2 + tl.load(p_dq, boundary_check=(0, 1)) @@ -532,16 +527,16 @@ def chunk_kda_bwd_kernel_intra( NC = min(NC, tl.cdiv(T - i_t * BT, BC)) if i_i < NC - 1: - p_gn = g + (min(i_ti + BC, T) - 1) * HV * K + o_k + p_gn = g + (min(i_ti + BC, T) - 1) * H * K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] for i_j in range(i_i + 1, NC): p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g, (T, K), (HV * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_b = tl.make_block_ptr(beta, (T,), (HV,), (i_t * BT + i_j * BC,), (BC,), (0,)) - p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, HV * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) - p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, HV * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT + i_j * BC,), (BC,), (0,)) + p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, H * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, H * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) # [BC] b_b = tl.load(p_b, boundary_check=(0,)) # [BC, BK] @@ -563,25 +558,25 @@ def chunk_kda_bwd_kernel_intra( b_dkt += tl.dot(b_dAqk, b_qg) b_dkt += tl.dot(b_dAkk, b_kbg) b_dkt *= exp2(b_gn - b_g) - o_dA = i_ti * HV * BT + i_i * BC + o_i + o_dA = i_ti * H * BT + i_i * BC + o_i p_qj = q + i_ti * H * K + o_k p_kj = k + i_ti * H * K + o_k - p_gkj = g + i_ti * HV * K + o_k - p_bj = beta + i_ti * HV + p_gkj = g + i_ti * H * K + o_k + p_bj = beta + i_ti * H if SAFE_GATE: if USE_GATHER: b_gn = gather(b_g, tl.full([1, BK], min(BC // 2, T - i_ti - 1), dtype=tl.int16), axis=0) else: - p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * HV * K + o_k + p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * H * K + o_k b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) - p_b = tl.make_block_ptr(beta, (T,), (HV,), (i_ti,), (BC,), (0,)) + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) b_b = tl.load(p_b, boundary_check=(0,)) - p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, HV * BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) - p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, HV * BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) + p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, H * BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) + p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, H * BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) b_dAqk_diag_kk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32) b_dAkk_diag_kk = tl.load(p_dAkk, boundary_check=(0, 1)).to(tl.float32) @@ -603,8 +598,8 @@ def chunk_kda_bwd_kernel_intra( else: for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC,] - b_dAqk = tl.load(dAqk + o_dA + j * HV * BT) - b_dAkk = tl.load(dAkk + o_dA + j * HV * BT) + b_dAqk = tl.load(dAqk + o_dA + j * H * BT) + b_dAkk = tl.load(dAkk + o_dA + j * H * BT) # [BK,] b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) b_kbj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) * tl.load(p_bj) @@ -615,14 +610,14 @@ def chunk_kda_bwd_kernel_intra( b_dkt += tl.where(m_i, b_dAqk[:, None] * b_qj[None, :] * b_gkq, 0.0) b_dkt += tl.where(m_i, b_dAkk[:, None] * b_kbj[None, :] * b_gkq, 0.0) - p_qj += H * K # k-head stride for q - p_kj += H * K # k-head stride for k - p_gkj += HV * K # v-head stride for g - p_bj += HV # v-head stride for beta - p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) # H = HQK - p_dk2 = tl.make_block_ptr(dk2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) # H = HQK - p_dg = tl.make_block_ptr(dg, (T, K), (HV * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) - p_dg2 = tl.make_block_ptr(dg2, (T, K), (HV * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_qj += H * K + p_kj += H * K + p_gkj += H * K + p_bj += H + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dk2 = tl.make_block_ptr(dk2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dg2 = tl.make_block_ptr(dg2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) b_dg2 += (b_dk2 - b_dkt) * b_k + tl.load(p_dg, boundary_check=(0, 1)) b_dk2 += tl.load(p_dk, boundary_check=(0, 1)) @@ -764,22 +759,7 @@ def chunk_kda_fwd_intra( unified_gref: bool = False, # Set True for ~5% extra perf (slightly lower precision) ): assert safe_gate, "Only safe_gate=True is supported in chunk_kda_fwd_intra for now" - # GVA support: Q/K have head-dim HQK; V/g/beta/Aqk/Akk/w/u/kg/qg have head-dim HV. - # Pre-GVA behaviour is preserved when HV == HQK. - B, T, HQK, K = k.shape - HV = v.shape[2] - assert v.shape[0] == B and v.shape[1] == T, ( - f"v must share (B, T) with k; got k.shape={k.shape}, v.shape={v.shape}" - ) - assert HV > 0 and HQK > 0 and HV % HQK == 0, ( - f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})" - ) - if gk is not None: - assert gk.shape[0] == B and gk.shape[1] == T and gk.shape[2] == HV, ( - f"gk shape must be (B, T, HV={HV}, K); got {tuple(gk.shape)}" - ) - if beta is not None: - assert beta.shape[-1] == HV, f"beta last dim must equal HV={HV}; got {tuple(beta.shape)}" + B, T, H, K = k.shape BT = chunk_size if cu_seqlens is None: @@ -793,20 +773,18 @@ def chunk_kda_fwd_intra( "cu_seqlens and chunk_indices must be int32 for cuda impl" ) - # Aqk/Akk are produced per v-head (they live in v-head space because g/beta are per v-head). - Aqk = torch.empty(B, T, HV, BT, device=k.device, dtype=k.dtype) - Akk = torch.empty(B, T, HV, BT, device=k.device, dtype=k.dtype) + Aqk = torch.empty(B, T, H, BT, device=k.device, dtype=k.dtype) + Akk = torch.empty(B, T, H, BT, device=k.device, dtype=k.dtype) tile_counter = torch.zeros(1, dtype=torch.int32, device=q.device) cula_cuda.chunk_kda_fwd_intra_cuda( q, k, gk, beta, cu_seqlens, chunk_indices, Aqk, Akk, tile_counter, scale, chunk_size, use_tf32_inverse, unified_gref ) - # w/u/kg/qg are all per-v-head outputs. - w = torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype) + w = torch.empty_like(k) u = torch.empty_like(v) - qg = torch.empty(B, T, HV, K, device=q.device, dtype=q.dtype) if disable_recompute else None - kg = torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype) if gk is not None else None + qg = torch.empty_like(q) if disable_recompute else None + kg = torch.empty_like(k) if gk is not None else None cula_cuda.recompute_w_u_cuda( k, v, beta, Akk, gk, cu_seqlens, chunk_indices, w, u, kg, chunk_size, q if disable_recompute else None, qg @@ -831,9 +809,7 @@ def chunk_kda_bwd_intra( chunk_size: int = 64, safe_gate: bool = False, ): - B, T, HQK, K = k.shape - # g/beta/dAqk/dAkk/dg/db live in v-head space (HV). - HV = g.shape[2] + B, T, H, K = k.shape BT = chunk_size BC = min(16, BT) BK = min(32, triton.next_power_of_2(K)) @@ -844,12 +820,11 @@ def chunk_kda_bwd_intra( NC = triton.cdiv(BT, BC) NK = triton.cdiv(K, BK) - # dq2/dk2 are in qk-head space; db2/dg2 are in v-head space. dq2 = torch.empty_like(q) dk2 = torch.empty_like(k) db2 = beta.new_empty(NK, *beta.shape, dtype=torch.float) dg2 = torch.empty_like(dg, dtype=torch.float) - grid = (NK * NC, NT, B * HV) + grid = (NK * NC, NT, B * H) chunk_kda_bwd_kernel_intra[grid]( q=q, k=k, @@ -868,8 +843,7 @@ def chunk_kda_bwd_intra( chunk_indices=chunk_indices, B=B, T=T, - H=HQK, - HV=HV, + H=H, K=K, BT=BT, BC=BC, @@ -883,4 +857,4 @@ def chunk_kda_bwd_intra( db = db2.sum(0).add_(db) dg = dg2 - return dq, dk, db, dg + return dq, dk, db, dg \ No newline at end of file From e6c579322134bd35cef45c90e2f888d3f5211004 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Tue, 19 May 2026 19:58:29 +0800 Subject: [PATCH 09/14] benchmark and test --- benchmarks/bench_kda.py | 227 +++++++++-------- benchmarks/bench_kda_fwd_bwd_e2e.py | 363 +++++++++++++++------------- cula/kda/chunk.py | 13 +- cula/kda/chunk_bwd.py | 67 +++-- cula/kda/chunk_fwd.py | 9 +- cula/utils.py | 2 +- tests/test_kda.py | 187 -------------- 7 files changed, 336 insertions(+), 532 deletions(-) diff --git a/benchmarks/bench_kda.py b/benchmarks/bench_kda.py index 50b7399..dc31d11 100644 --- a/benchmarks/bench_kda.py +++ b/benchmarks/bench_kda.py @@ -15,7 +15,7 @@ """ bench_kda.py — Benchmark: cuLA CuTe DSL vs FLA Triton baseline - for chunk_kda (KDA training, fwd+bwd) + for chunk_kda (KDA forward) Compares: - Accuracy: RMSE, relative max diff between cuLA and FLA outputs @@ -25,13 +25,8 @@ - Fixed-length: B=1, B=2 with various T - Varlen: ~20 seqs with 2-3x length variation -H (number of Q/K heads) is a module-level constant; HV (number of V heads) -defaults to H and can be overridden globally via --hv to run every config in -GVA mode. In GVA mode cuLA receives native HQK q/k; FLA receives q/k -expanded to HV heads. HV must be a positive multiple of H. - Usage: - python bench_kda.py [--mode fixed|varlen|both] [--hv HV] [--ncu] + python bench_kda.py [--mode fixed|varlen|both] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: ncu --set full -o report python bench_kda.py --mode varlen --ncu @@ -53,7 +48,6 @@ build_varlen_configs, exclusive_cumsum, prepare_safe_gate_inputs, - prepare_safe_gate_inputs_gva, set_seed, ) from cula.kda import chunk_kda as cula_chunk_kda @@ -61,10 +55,7 @@ # ============================================================ # Constants # ============================================================ -# H = QK head count; HV = V head count. HV defaults to H (non-GVA / MHA). -# Override via --hv to run every config in GVA mode (HV must be a multiple of H). H, D = 64, 128 -HV = H WARMUP = 10 N_ITERS = 30 NCU_MODE = False @@ -171,70 +162,74 @@ def check_determinism(H=4, total_T=8192, num_seqs=10, iters=10000): assert torch.equal(state, ref_state), f"State mismatch at iter {i}" -def _prepare_inputs(B, T, cu_seqlens): - """Return (inputs, q_fla, k_fla, q_cula, k_cula). - - Non-GVA (HV == H): all four q/k are the same tensor. - GVA (HV > H) : cuLA gets native HQK q/k; FLA gets q/k expanded to HV. - """ - device = torch.device("cuda") - if HV > H: - inputs = prepare_safe_gate_inputs_gva(B, T, H, HV, D, device, cu_seqlens=cu_seqlens) - q_cula, k_cula = inputs["q"], inputs["k"] # [B_flat, T, H, D] - q_fla = q_cula.repeat_interleave(HV // H, dim=2).contiguous() # [B_flat, T, HV, D] - k_fla = k_cula.repeat_interleave(HV // H, dim=2).contiguous() - else: - inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens) - q_cula = q_fla = inputs["q"] - k_cula = k_fla = inputs["k"] - return inputs, q_fla, k_fla, q_cula, k_cula - - # ============================================================ # Fixed-length benchmark # ============================================================ def bench_fixed(configs): - gva_note = f"GVA HV={HV} ({HV // H}x)" if HV > H else f"MHA HV=H={H}" - print("\n" + "=" * 110) - print(f" Fixed-Length Benchmark: cuLA CuTe DSL vs FLA Triton {gva_note} disable_recompute={DISABLE_RECOMPUTE}") - print("=" * 110) + print("\n" + "=" * 100) + print(f" Fixed-Length Benchmark: cuLA CuTe DSL vs FLA Triton disable_recompute={DISABLE_RECOMPUTE}") + print("=" * 100) results = [] for B, T in configs: set_seed(SEED) + device = torch.device("cuda") torch.cuda.empty_cache() seq_lens = [T] * B - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=torch.device("cuda")) + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - inputs, q_fla, k_fla, q_cula, k_cula = _prepare_inputs(B, T, cu_seqlens) - v, g, beta = inputs["v"], inputs["g"], inputs["beta"] + inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens) + q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] - _shared = dict(v=v, g=g, beta=beta, scale=scale, A_log=A_log, dt_bias=dt_bias, - init_state=init_state, cu_seqlens=cu_seqlens, lower_bound=lower_bound) - common_fla = dict(q=q_fla, k=k_fla, **_shared) - common_cula = dict(q=q_cula, k=k_cula, **_shared) + common = dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=init_state, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + ) - # Accuracy - o_fla, _ = run_kda(**common_fla, fn=fla_chunk_kda) - o_cula, _ = run_kda(**common_cula, fn=cula_chunk_kda) + # Accuracy: compare outputs + o_fla, _ = run_kda(**common, fn=fla_chunk_kda) + o_cula, _ = run_kda(**common, fn=cula_chunk_kda) torch.cuda.synchronize() + rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) # Performance - ms_fla = time_kernel(lambda: run_kda(**common_fla, fn=fla_chunk_kda)) - ms_cula = time_kernel(lambda: run_kda(**common_cula, fn=cula_chunk_kda)) - speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + def fn_fla(**common_kw): + return lambda: run_kda(**common_kw, fn=fla_chunk_kda) + + def fn_cula(**common_kw): + return lambda: run_kda(**common_kw, fn=cula_chunk_kda) - results.append({ - "B": B, "T": T, "H": H, "HV": HV, - "rmse": rmse, "rel_max": rel_max, "mean_diff": mean_diff, - "ms_fla": ms_fla, "ms_cula": ms_cula, "speedup": speedup, - }) + ms_fla = time_kernel(fn_fla(**common)) + ms_cula = time_kernel(fn_cula(**common)) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - del o_fla, o_cula, inputs + r = { + "B": B, + "T": T, + "rmse": rmse, + "rel_max": rel_max, + "mean_diff": mean_diff, + "ms_fla": ms_fla, + "ms_cula": ms_cula, + "speedup": speedup, + } + results.append(r) + # print(f" B={B:2d} T={T:5d} done ({speedup:.2f}x)") + + del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs torch.cuda.empty_cache() return results @@ -244,51 +239,77 @@ def bench_fixed(configs): # Varlen benchmark # ============================================================ def bench_varlen(configs): - gva_note = f"GVA HV={HV} ({HV // H}x)" if HV > H else f"MHA HV=H={H}" - print("\n" + "=" * 110) - print(f" Varlen Benchmark: cuLA CuTe DSL vs FLA Triton {gva_note} disable_recompute={DISABLE_RECOMPUTE}") - print("=" * 110) + print("\n" + "=" * 100) + print(f" Varlen Benchmark: cuLA CuTe DSL vs FLA Triton disable_recompute={DISABLE_RECOMPUTE}") + print("=" * 100) results = [] for seq_lens, total_len, dist in configs: set_seed(SEED) + device = torch.device("cuda") torch.cuda.empty_cache() T = total_len - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=torch.device("cuda")) + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - inputs, q_fla, k_fla, q_cula, k_cula = _prepare_inputs(1, T, cu_seqlens) - v, g, beta = inputs["v"], inputs["g"], inputs["beta"] + inputs = prepare_safe_gate_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens) + q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] - _shared = dict(v=v, g=g, beta=beta, scale=scale, A_log=A_log, dt_bias=dt_bias, - init_state=init_state, cu_seqlens=cu_seqlens, lower_bound=lower_bound) - common_fla = dict(q=q_fla, k=k_fla, **_shared) - common_cula = dict(q=q_cula, k=k_cula, **_shared) + common = dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=init_state, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + ) # Accuracy - o_fla, _ = run_kda(**common_fla, fn=fla_chunk_kda) - o_cula, _ = run_kda(**common_cula, fn=cula_chunk_kda) + o_fla, _ = run_kda(**common, fn=fla_chunk_kda) + o_cula, _ = run_kda(**common, fn=cula_chunk_kda) torch.cuda.synchronize() + rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) # Performance - ms_fla = time_kernel(lambda: run_kda(**common_fla, fn=fla_chunk_kda)) - ms_cula = time_kernel(lambda: run_kda(**common_cula, fn=cula_chunk_kda)) - speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + def fn_fla(**common_kw): + return lambda: run_kda(**common_kw, fn=fla_chunk_kda) - n_seqs = len(seq_lens) - tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min(seq_lens)}..{max(seq_lens)}] avg={T // n_seqs}" + def fn_cula(**common_kw): + return lambda: run_kda(**common_kw, fn=cula_chunk_kda) - results.append({ - "tag": tag, "dist": dist, "T_total": T, "n_seqs": n_seqs, - "H": H, "HV": HV, - "rmse": rmse, "rel_max": rel_max, "mean_diff": mean_diff, - "ms_fla": ms_fla, "ms_cula": ms_cula, "speedup": speedup, - }) + ms_fla = time_kernel(fn_fla(**common)) + ms_cula = time_kernel(fn_cula(**common)) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - del o_fla, o_cula, inputs + n_seqs = len(seq_lens) + min_l, max_l = min(seq_lens), max(seq_lens) + avg_l = T // n_seqs + tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min_l}..{max_l}] avg={avg_l}" + + r = { + "tag": tag, + "dist": dist, + "T_total": T, + "n_seqs": n_seqs, + "rmse": rmse, + "rel_max": rel_max, + "mean_diff": mean_diff, + "ms_fla": ms_fla, + "ms_cula": ms_cula, + "speedup": speedup, + } + results.append(r) + # print(f" {tag:45s} done ({speedup:.2f}x)") + + del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs torch.cuda.empty_cache() return results @@ -298,13 +319,11 @@ def bench_varlen(configs): # Report # ============================================================ def print_report(fixed_results, varlen_results): - sep = "=" * 120 + sep = "=" * 110 print(f"\n\n{sep}") print(" BENCHMARK REPORT: chunk_kda") print(" cuLA CuTe DSL vs FLA Triton") - print(f" D={D} dtype=bf16 safe_gate=True disable_recompute={DISABLE_RECOMPUTE}") - gva_note = f"GVA enabled (HV={HV} > H={H}, ratio={HV // H}x)" if HV > H else f"MHA (HV=H={H})" - print(f" {gva_note}") + print(f" H={H} D={D} dtype=bf16 safe_gate=True disable_recompute={DISABLE_RECOMPUTE}") wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") @@ -313,39 +332,32 @@ def print_report(fixed_results, varlen_results): if fixed_results: print("\n [Fixed-Length]") - print(f" {'─' * 110}") + print(f" {'─' * 85}") print( - f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " - f"{'RMSE':>10s} {'rel_max':>10s} │ " - f"{'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s}" + f" {'B':>3s} {'T':>5s} │ {'RMSE':>10s} {'rel_max':>10s}" + f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s}" ) - print(f" {'─' * 110}") + print(f" {'─' * 85}") for r in fixed_results: - gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no" print( - f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " + f" {r['B']:3d} {r['T']:5d} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} │ " f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x" ) - print(f" {'─' * 110}") + print(f" {'─' * 85}") if varlen_results: print("\n [Varlen]") - print(f" {'─' * 120}") - print( - f" {'Config':>45s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " - f"{'RMSE':>10s} {'rel_max':>10s} │ " - f"{'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s}" - ) - print(f" {'─' * 120}") + print(f" {'─' * 100}") + print(f" {'Config':>45s} │ {'RMSE':>10s} {'rel_max':>10s} │ {'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s}") + print(f" {'─' * 100}") for r in varlen_results: - gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no" print( - f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " + f" {r['tag']:>45s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} │ " f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x" ) - print(f" {'─' * 120}") + print(f" {'─' * 100}") print(f"\n{sep}\n") @@ -377,15 +389,9 @@ def main(): action="store_true", help="Disable recompute in both FLA and cuLA (pre-compute QG)", ) - parser.add_argument( - "--hv", - type=int, - default=None, - help=f"Override number of V heads (HV). Default: H ({H}, no GVA). Set HV > H for GVA mode.", - ) args = parser.parse_args() - global NCU_MODE, SANITIZER_MODE, DISABLE_RECOMPUTE, HV + global NCU_MODE, SANITIZER_MODE, DISABLE_RECOMPUTE if args.ncu: NCU_MODE = True print("[NCU mode] warmup=1, iters=1") @@ -395,12 +401,6 @@ def main(): if args.disable_recompute: DISABLE_RECOMPUTE = True print("[Disable recompute] pre-compute QG in forward") - if args.hv is not None: - if args.hv < H or args.hv % H != 0: - raise ValueError(f"--hv must be a positive multiple of H ({H}), got {args.hv}") - HV = args.hv - if HV > H: - print(f"[GVA] HV={HV} (H={H}, ratio={HV // H}x)") fixed_configs = [ # (B, T) @@ -434,7 +434,6 @@ def main(): varlen_res = bench_varlen(varlen_configs) print_report(fixed_res, varlen_res) - return fixed_res, varlen_res diff --git a/benchmarks/bench_kda_fwd_bwd_e2e.py b/benchmarks/bench_kda_fwd_bwd_e2e.py index a63ce6e..c6b4117 100644 --- a/benchmarks/bench_kda_fwd_bwd_e2e.py +++ b/benchmarks/bench_kda_fwd_bwd_e2e.py @@ -29,13 +29,8 @@ - forward: forward pass only - e2e: forward + backward (end-to-end) -H (number of Q/K heads) is a module-level constant; HV (number of V heads) -defaults to H and can be overridden globally via --hv to run every config in -GVA mode. In GVA mode cuLA receives native HQK q/k; FLA receives q/k -expanded to HV heads. HV must be a positive multiple of H. - Usage: - python bench_kda_fwd_bwd_e2e.py [--mode fixed|varlen|both] [--phase forward|e2e] [--hv HV] [--ncu] + python bench_kda_fwd_bwd_e2e.py [--mode fixed|varlen|both] [--phase forward|e2e] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: ncu --set full -o report python bench_kda_fwd_bwd_e2e.py --mode varlen --ncu @@ -58,7 +53,6 @@ exclusive_cumsum, generate_random_seq_lens, prepare_safe_gate_inputs, - prepare_safe_gate_inputs_gva, set_seed, ) from cula.kda import chunk_kda as cula_chunk_kda @@ -66,10 +60,7 @@ # ============================================================ # Constants # ============================================================ -# H = QK head count; HV = V head count. HV defaults to H (non-GVA / MHA). -# Override via --hv to run every config in GVA mode (HV must be a multiple of H). H, D = 64, 128 -HV = H WARMUP = 25 N_ITERS = 100 NCU_MODE = False @@ -235,116 +226,113 @@ def check_determinism(num_seqs=5, T=512, iters=20): return True -def _prepare_inputs_e2e(B, T, cu_seqlens): - """Return (inputs, q_fla, k_fla, q_cula, k_cula). - - Non-GVA (HV == H): all four q/k are the same tensor. - GVA (HV > H) : cuLA gets native HQK q/k; FLA gets q/k expanded to HV. - """ - device = torch.device("cuda") - if HV > H: - inputs = prepare_safe_gate_inputs_gva(B, T, H, HV, D, device, cu_seqlens=cu_seqlens, has_init_state=True) - q_cula, k_cula = inputs["q"], inputs["k"] # [B_flat, T, H, D] - q_fla = q_cula.repeat_interleave(HV // H, dim=2).contiguous() # [B_flat, T, HV, D] - k_fla = k_cula.repeat_interleave(HV // H, dim=2).contiguous() - else: - inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=True) - q_cula = q_fla = inputs["q"] - k_cula = k_fla = inputs["k"] - return inputs, q_fla, k_fla, q_cula, k_cula - - -def _compare_accuracy(fla_results, cula_results): - """Compare accuracy between FLA and cuLA results, handling GVA dq/dk shape mismatch.""" - acc = {} - for name in ("o", "ht", "dv", "dg", "dbeta", "dh0"): - if name in fla_results and name in cula_results: - err_ratio, rel_max, mean_diff = accuracy_stats(fla_results[name], cula_results[name]) - acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} - if "dq" in fla_results and "dq" in cula_results: - dq_fla = fla_results["dq"] - dk_fla = fla_results["dk"] - if HV > H: - # Aggregate FLA HV-space grads back to HQK space for comparison - *head_prefix, hv_size, d_size = dq_fla.shape - dq_fla = dq_fla.reshape(*head_prefix, H, HV // H, d_size).sum(dim=-2) - dk_fla = dk_fla.reshape(*head_prefix, H, HV // H, d_size).sum(dim=-2) - for name, ref, out in (("dq", dq_fla, cula_results["dq"]), ("dk", dk_fla, cula_results["dk"])): - err_ratio, rel_max, mean_diff = accuracy_stats(ref, out) - acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} - return acc - - # ============================================================ # Fixed-length benchmark # ============================================================ def bench_fixed(configs): - gva_note = f"GVA HV={HV} ({HV // H}x)" if HV > H else f"MHA HV=H={H}" print("\n" + "=" * 120) - print(f" Fixed-Length E2E Benchmark: cuLA vs FLA {gva_note} phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}") + print(f" Fixed-Length E2E Benchmark: cuLA vs FLA phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}") print("=" * 120) results = [] for B, T in configs: set_seed(SEED) + device = torch.device("cuda") torch.cuda.empty_cache() seq_lens = [T] * B - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=torch.device("cuda")) + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - inputs, q_fla, k_fla, q_cula, k_cula = _prepare_inputs_e2e(B, T, cu_seqlens) - v, g, beta = inputs["v"], inputs["g"], inputs["beta"] + inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=True) + q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] + # Generate do, dht for backward set_seed(SEED + 1) do = torch.randn_like(v) dht = torch.randn_like(init_state) - _shared = dict(v=v, g=g, beta=beta, scale=scale, A_log=A_log, dt_bias=dt_bias, - init_state=init_state, cu_seqlens=cu_seqlens, lower_bound=lower_bound, - do=do, dht=dht) - common_fla = dict(q=q_fla, k=k_fla, **_shared) - common_cula = dict(q=q_cula, k=k_cula, **_shared) - - # Accuracy + common = dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=init_state, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + do=do, + dht=dht, + ) + + # Accuracy: compare outputs and gradients acc = {} if PHASE == "e2e": - fla_results = run_kda_e2e_with_grads(**common_fla, fn=fla_chunk_kda) - cula_results = run_kda_e2e_with_grads(**common_cula, fn=cula_chunk_kda) + fla_results = run_kda_e2e_with_grads(**common, fn=fla_chunk_kda) + cula_results = run_kda_e2e_with_grads(**common, fn=cula_chunk_kda) torch.cuda.synchronize() - acc = _compare_accuracy(fla_results, cula_results) + + for name in ("o", "ht", "dq", "dk", "dv", "dg", "dbeta", "dh0"): + err_ratio, rel_max, mean_diff = accuracy_stats(fla_results[name], cula_results[name]) + acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} else: - o_fla, ht_fla = run_kda_e2e(**common_fla, fn=fla_chunk_kda) - o_cula, ht_cula = run_kda_e2e(**common_cula, fn=cula_chunk_kda) + # forward-only accuracy + o_fla, ht_fla = run_kda_e2e(**common, fn=fla_chunk_kda) + o_cula, ht_cula = run_kda_e2e(**common, fn=cula_chunk_kda) torch.cuda.synchronize() for name, ref, out in [("o", o_fla, o_cula), ("ht", ht_fla, ht_cula)]: err_ratio, rel_max, mean_diff = accuracy_stats(ref, out) acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} - # Timing: fresh leaf tensors with requires_grad - def _make_timing(q_, k_): - return dict( - q=q_.detach().clone().requires_grad_(True), - k=k_.detach().clone().requires_grad_(True), - v=v.detach().clone().requires_grad_(True), - g=g.detach().clone().requires_grad_(True), - beta=beta.detach().clone().requires_grad_(True), - scale=scale, A_log=A_log, dt_bias=dt_bias, - init_state=init_state.detach().clone().requires_grad_(True), - cu_seqlens=cu_seqlens, lower_bound=lower_bound, do=do, dht=dht, - ) - - ms_fla = time_kernel(lambda: run_kda_e2e(**_make_timing(q_fla, k_fla), fn=fla_chunk_kda)) - ms_cula = time_kernel(lambda: run_kda_e2e(**_make_timing(q_cula, k_cula), fn=cula_chunk_kda)) + # For timing, use leaf tensors with requires_grad + q_t = q.detach().clone().requires_grad_(True) + k_t = k.detach().clone().requires_grad_(True) + v_t = v.detach().clone().requires_grad_(True) + g_t = g.detach().clone().requires_grad_(True) + beta_t = beta.detach().clone().requires_grad_(True) + h0_t = init_state.detach().clone().requires_grad_(True) + + timing_common = dict( + q=q_t, + k=k_t, + v=v_t, + g=g_t, + beta=beta_t, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=h0_t, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + do=do, + dht=dht, + ) + + def fn_fla(**kw): + return lambda: run_kda_e2e(**kw, fn=fla_chunk_kda) + + def fn_cula(**kw): + return lambda: run_kda_e2e(**kw, fn=cula_chunk_kda) + + ms_fla = time_kernel(fn_fla(**timing_common)) + ms_cula = time_kernel(fn_cula(**timing_common)) speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - results.append({ - "B": B, "T": T, "H": H, "HV": HV, - "accuracy": acc, "ms_fla": ms_fla, "ms_cula": ms_cula, "speedup": speedup, - }) - - del inputs, do, dht + r = { + "B": B, + "T": T, + "accuracy": acc, + "ms_fla": ms_fla, + "ms_cula": ms_cula, + "speedup": speedup, + } + results.append(r) + + del q, k, v, g, beta, A_log, dt_bias, inputs, do, dht torch.cuda.empty_cache() return results @@ -354,76 +342,115 @@ def _make_timing(q_, k_): # Varlen benchmark # ============================================================ def bench_varlen(configs): - gva_note = f"GVA HV={HV} ({HV // H}x)" if HV > H else f"MHA HV=H={H}" print("\n" + "=" * 120) - print(f" Varlen E2E Benchmark: cuLA vs FLA {gva_note} phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}") + print(f" Varlen E2E Benchmark: cuLA vs FLA phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}") print("=" * 120) results = [] for seq_lens, total_len, dist in configs: set_seed(SEED) + device = torch.device("cuda") torch.cuda.empty_cache() T = total_len - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=torch.device("cuda")) + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - inputs, q_fla, k_fla, q_cula, k_cula = _prepare_inputs_e2e(1, T, cu_seqlens) - v, g, beta = inputs["v"], inputs["g"], inputs["beta"] + inputs = prepare_safe_gate_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=True) + q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] + # Generate do, dht for backward set_seed(SEED + 1) do = torch.randn_like(v) dht = torch.randn_like(init_state) - _shared = dict(v=v, g=g, beta=beta, scale=scale, A_log=A_log, dt_bias=dt_bias, - init_state=init_state, cu_seqlens=cu_seqlens, lower_bound=lower_bound, - do=do, dht=dht) - common_fla = dict(q=q_fla, k=k_fla, **_shared) - common_cula = dict(q=q_cula, k=k_cula, **_shared) - - # Accuracy + common = dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=init_state, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + do=do, + dht=dht, + ) + + # Accuracy: compare outputs and gradients acc = {} if PHASE == "e2e": - fla_results = run_kda_e2e_with_grads(**common_fla, fn=fla_chunk_kda) - cula_results = run_kda_e2e_with_grads(**common_cula, fn=cula_chunk_kda) + fla_results = run_kda_e2e_with_grads(**common, fn=fla_chunk_kda) + cula_results = run_kda_e2e_with_grads(**common, fn=cula_chunk_kda) torch.cuda.synchronize() - acc = _compare_accuracy(fla_results, cula_results) + + for name in ("o", "ht", "dq", "dk", "dv", "dg", "dbeta", "dh0"): + err_ratio, rel_max, mean_diff = accuracy_stats(fla_results[name], cula_results[name]) + acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} else: - o_fla, ht_fla = run_kda_e2e(**common_fla, fn=fla_chunk_kda) - o_cula, ht_cula = run_kda_e2e(**common_cula, fn=cula_chunk_kda) + o_fla, ht_fla = run_kda_e2e(**common, fn=fla_chunk_kda) + o_cula, ht_cula = run_kda_e2e(**common, fn=cula_chunk_kda) torch.cuda.synchronize() for name, ref, out in [("o", o_fla, o_cula), ("ht", ht_fla, ht_cula)]: err_ratio, rel_max, mean_diff = accuracy_stats(ref, out) acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} - # Timing: fresh leaf tensors with requires_grad - def _make_timing(q_, k_): - return dict( - q=q_.detach().clone().requires_grad_(True), - k=k_.detach().clone().requires_grad_(True), - v=v.detach().clone().requires_grad_(True), - g=g.detach().clone().requires_grad_(True), - beta=beta.detach().clone().requires_grad_(True), - scale=scale, A_log=A_log, dt_bias=dt_bias, - init_state=init_state.detach().clone().requires_grad_(True), - cu_seqlens=cu_seqlens, lower_bound=lower_bound, do=do, dht=dht, - ) - - ms_fla = time_kernel(lambda: run_kda_e2e(**_make_timing(q_fla, k_fla), fn=fla_chunk_kda)) - ms_cula = time_kernel(lambda: run_kda_e2e(**_make_timing(q_cula, k_cula), fn=cula_chunk_kda)) + # For timing, use leaf tensors with requires_grad + q_t = q.detach().clone().requires_grad_(True) + k_t = k.detach().clone().requires_grad_(True) + v_t = v.detach().clone().requires_grad_(True) + g_t = g.detach().clone().requires_grad_(True) + beta_t = beta.detach().clone().requires_grad_(True) + h0_t = init_state.detach().clone().requires_grad_(True) + + timing_common = dict( + q=q_t, + k=k_t, + v=v_t, + g=g_t, + beta=beta_t, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=h0_t, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + do=do, + dht=dht, + ) + + def fn_fla(**kw): + return lambda: run_kda_e2e(**kw, fn=fla_chunk_kda) + + def fn_cula(**kw): + return lambda: run_kda_e2e(**kw, fn=cula_chunk_kda) + + ms_fla = time_kernel(fn_fla(**timing_common)) + ms_cula = time_kernel(fn_cula(**timing_common)) speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") n_seqs = len(seq_lens) - tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min(seq_lens)}..{max(seq_lens)}] avg={T // n_seqs}" - - results.append({ - "tag": tag, "dist": dist, "T_total": T, "n_seqs": n_seqs, - "H": H, "HV": HV, - "accuracy": acc, "ms_fla": ms_fla, "ms_cula": ms_cula, "speedup": speedup, - }) - - del inputs, do, dht + min_l, max_l = min(seq_lens), max(seq_lens) + avg_l = T // n_seqs + tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min_l}..{max_l}] avg={avg_l}" + + r = { + "tag": tag, + "dist": dist, + "T_total": T, + "n_seqs": n_seqs, + "accuracy": acc, + "ms_fla": ms_fla, + "ms_cula": ms_cula, + "speedup": speedup, + } + results.append(r) + + del q, k, v, g, beta, A_log, dt_bias, inputs, do, dht torch.cuda.empty_cache() return results @@ -438,17 +465,15 @@ def print_report(fixed_results, varlen_results): print(" BENCHMARK REPORT: chunk_kda forward+backward (E2E)") print(" cuLA CuTe DSL vs FLA Triton") print( - f" D={D} dtype=bf16 safe_gate=True phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}" + f" H={H} D={D} dtype=bf16 safe_gate=True phase={PHASE} disable_recompute={DISABLE_RECOMPUTE}" ) - gva_note = f"GVA enabled (HV={HV} > H={H}, ratio={HV // H}x)" if HV > H else f"MHA (HV=H={H})" - print(f" {gva_note}") wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") print(f" Warmup={wu} Iters={ni}{mode_tag}") print(sep) - # Determine which accuracy keys to show (dq/dk present in e2e mode) + # Determine which accuracy keys to show if PHASE == "e2e": acc_keys = ["o", "ht", "dq", "dk", "dv", "dg", "dbeta", "dh0"] else: @@ -458,39 +483,44 @@ def print_report(fixed_results, varlen_results): if fixed_results: print("\n [Fixed-Length]") - print(f" {'─' * 130}") - print(f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " - f"{'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s} │ {'':>10s}{acc_header}") - print(f" {'─' * 130}") + print(f" {'─' * 125}") + + # Header + print(f" {'B':>3s} {'T':>5s} │ {'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s} │ {'':>10s}{acc_header}") + print(f" {'─' * 125}") + for r in fixed_results: - gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no" - rel_max_vals = " ".join(f"{r['accuracy'].get(k, {}).get('rel_max', 0.0):10.6f}" for k in acc_keys) + rel_max_vals = " ".join(f"{r['accuracy'].get(k, {}).get('rel_max', 0.0):10.6f}" for k in acc_keys) err_ratio_vals = " ".join(f"{r['accuracy'].get(k, {}).get('err_ratio', 0.0):10.6f}" for k in acc_keys) - prefix = f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " - blank = f" {'':3s} {'':6s} {'':3s} {'':3s} {'':4s} │ " - timing = f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x │ " - blank_t = f"{'':9s} {'':11s} {'':8s} │ " - print(f"{prefix}{timing}{'rel_max:':>10s}{rel_max_vals}") - print(f"{blank}{blank_t}{'err_ratio:':>10s}{err_ratio_vals}") - print(f" {'─' * 130}") + # Line 1: timing + rel_max + print( + f" {r['B']:3d} {r['T']:5d} │ " + f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x │ " + f"{'rel_max:':>10s}{rel_max_vals}" + ) + # Line 2: err_ratio (no timing columns) + print(f" {'':3s} {'':5s} │ {'':9s} {'':11s} {'':8s} │ {'err_ratio:':>10s}{err_ratio_vals}") + print(f" {'─' * 125}") if varlen_results: print("\n [Varlen]") - print(f" {'─' * 145}") - print(f" {'Config':>45s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " - f"{'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s} │ {'':>10s}{acc_header}") - print(f" {'─' * 145}") + print(f" {'─' * 140}") + + print(f" {'Config':>45s} │ {'FLA(ms)':>9s} {'cuLA(ms)':>11s} {'Speedup':>8s} │ {'':>10s}{acc_header}") + print(f" {'─' * 140}") + for r in varlen_results: - gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no" - rel_max_vals = " ".join(f"{r['accuracy'].get(k, {}).get('rel_max', 0.0):10.6f}" for k in acc_keys) + rel_max_vals = " ".join(f"{r['accuracy'].get(k, {}).get('rel_max', 0.0):10.6f}" for k in acc_keys) err_ratio_vals = " ".join(f"{r['accuracy'].get(k, {}).get('err_ratio', 0.0):10.6f}" for k in acc_keys) - prefix = f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " - blank = f" {'':>45s} {'':3s} {'':3s} {'':4s} │ " - timing = f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x │ " - blank_t = f"{'':9s} {'':11s} {'':8s} │ " - print(f"{prefix}{timing}{'rel_max:':>10s}{rel_max_vals}") - print(f"{blank}{blank_t}{'err_ratio:':>10s}{err_ratio_vals}") - print(f" {'─' * 145}") + # Line 1: timing + rel_max + print( + f" {r['tag']:>45s} │ " + f"{r['ms_fla']:9.4f} {r['ms_cula']:11.4f} {r['speedup']:7.2f}x │ " + f"{'rel_max:':>10s}{rel_max_vals}" + ) + # Line 2: err_ratio (no config/timing columns) + print(f" {'':>45s} │ {'':9s} {'':11s} {'':8s} │ {'err_ratio:':>10s}{err_ratio_vals}") + print(f" {'─' * 140}") print(f"\n{sep}\n") @@ -534,15 +564,9 @@ def main(): action="store_true", help="Run determinism check: verify cuLA produces identical outputs across repeated runs", ) - parser.add_argument( - "--hv", - type=int, - default=None, - help=f"Override number of V heads (HV). Default: H ({H}, no GVA). Set HV > H for GVA mode.", - ) args = parser.parse_args() - global NCU_MODE, SANITIZER_MODE, DISABLE_RECOMPUTE, PHASE, HV + global NCU_MODE, SANITIZER_MODE, DISABLE_RECOMPUTE, PHASE if args.ncu: NCU_MODE = True print("[NCU mode] warmup=1, iters=1") @@ -553,12 +577,6 @@ def main(): DISABLE_RECOMPUTE = True print("[Disable recompute] pre-compute QG in forward") PHASE = args.phase - if args.hv is not None: - if args.hv < H or args.hv % H != 0: - raise ValueError(f"--hv must be a positive multiple of H ({H}), got {args.hv}") - HV = args.hv - if HV > H: - print(f"[GVA] HV={HV} (H={H}, ratio={HV // H}x)") if args.check_determinism: det_configs = [(5, 1024), (10, 4096), (10, 8192), (10, 16384)] @@ -598,7 +616,6 @@ def main(): varlen_res = bench_varlen(varlen_configs) print_report(fixed_res, varlen_res) - return fixed_res, varlen_res diff --git a/cula/kda/chunk.py b/cula/kda/chunk.py index e0a9320..b89c780 100644 --- a/cula/kda/chunk.py +++ b/cula/kda/chunk.py @@ -379,17 +379,10 @@ def chunk_kda( if not (-5 <= lower_bound < 0): raise ValueError(f"`lower_bound` must be in the safe range [-5, 0), got {lower_bound}.") - B, T, HQK, K_dim = q.shape - HV = g.shape[2] - assert q.shape == k.shape, "q and k must have the same shape." - assert q.shape[:2] == g.shape[:2], "q/k and g must share batch and sequence dimensions." - assert HV % HQK == 0 and HV >= HQK, ( - f"g.shape[2] (HV={HV}) must be a positive multiple of q.shape[2] (HQK={HQK})." - ) - assert g.shape == (B, T, HV, K_dim), f"g must be [B,T,HV,K]=({B},{T},{HV},{K_dim}), got {tuple(g.shape)}." - assert beta.shape[:3] == (B, T, HV), f"beta must have shape [B,T,HV,...], got {tuple(beta.shape)}." - assert v.shape[:3] == (B, T, HV), f"v must have shape [B,T,HV,...], got {tuple(v.shape)}." + assert q.shape == k.shape == g.shape, "q, k, g must have the same shape." assert k.shape[-1] <= 256, "Currently we only support key headdim <=256 for KDA :-(" + assert beta.shape == q.shape[:3], "beta must be of shape (batch size, seq len, num of head)." + assert v.shape == (*q.shape[:3], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)." assert q.dtype == k.dtype == v.dtype == torch.bfloat16, "q, k, v must be in bfloat16." assert beta.dtype == torch.bfloat16 or beta.dtype == torch.float32, "beta must be in bfloat16 or float32." assert q.shape[-1] == k.shape[-1] == v.shape[-1] == 128, "Currently we only support head dim of 128 for KDA" diff --git a/cula/kda/chunk_bwd.py b/cula/kda/chunk_bwd.py index de7074d..859b6be 100644 --- a/cula/kda/chunk_bwd.py +++ b/cula/kda/chunk_bwd.py @@ -61,6 +61,8 @@ ) @triton.jit(do_not_specialize=["T"]) def chunk_kda_bwd_kernel_dAv( + q, + k, v, A, do, @@ -78,7 +80,6 @@ def chunk_kda_bwd_kernel_dAv( BV: tl.constexpr, IS_VARLEN: tl.constexpr, ): - # H here is HV (v-head count). q/k are not needed for this computation. i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: @@ -89,6 +90,8 @@ def chunk_kda_bwd_kernel_dAv( bos, eos = i_b * T, i_b * T + T # offset calculation + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K v += (bos * H + i_h) * V do += (bos * H + i_h) * V dv += (bos * H + i_h) * V @@ -162,7 +165,6 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( scale, T, H: tl.constexpr, - HQK: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -171,12 +173,8 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( TRANSPOSE_STATE: tl.constexpr, IS_VARLEN: tl.constexpr, ): - # H = HV (v-head count); grid enumerates B * HV tile-pairs. - # HQK = qk-head count (HQK <= H; when HQK == H this is standard non-GVA). - # For each v-head i_h, the paired qk-head is i_hqk = i_h // (H // HQK). i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - i_hqk = i_h // (H // HQK) if IS_VARLEN: i_tg = i_t.to(tl.int64) @@ -193,9 +191,8 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( m_t = o_t < T m_last = o_t == min(T, i_t * BT + BT) - 1 - # q/k/dq/dk use qk-head stride (HQK); all others use v-head stride (H = HV). - q += (bos * HQK + i_hqk) * K - k += (bos * HQK + i_hqk) * K + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K v += (bos * H + i_h) * V v_new += (bos * H + i_h) * V g += (bos * H + i_h) * K @@ -204,8 +201,8 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( h += (i_tg * H + i_h) * K * V do += (bos * H + i_h) * V dh += (i_tg * H + i_h) * K * V - dq += (bos * HQK + i_hqk) * K - dk += (bos * HQK + i_hqk) * K + dq += (bos * H + i_h) * K + dk += (bos * H + i_h) * K dv += (bos * H + i_h) * V dv2 += (bos * H + i_h) * V dg += (bos * H + i_h) * K @@ -215,7 +212,7 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( p_beta = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_beta = tl.load(p_beta, boundary_check=(0,)) - p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) # H = HV + p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) b_A = tl.load(p_A, boundary_check=(0, 1)) b_dA = tl.zeros([BT, BT], dtype=tl.float32) @@ -225,7 +222,7 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K - p_k = tl.make_block_ptr(k, (T, K), (HQK * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) @@ -290,15 +287,15 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( b_dkgb = tl.dot(b_A, b_dw) b_db += tl.sum(b_dkgb * b_kg, 1) - p_q = tl.make_block_ptr(q, (T, K), (HQK * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_kdk = b_k * b_dk b_dgk += tl.sum(b_kdk, axis=0) b_dg = b_q * b_dq - b_kdk + m_last[:, None] * b_dgk + b_kg * b_dkgb * b_beta[:, None] b_dk = b_dk + b_dkgb * b_gb - p_dq = tl.make_block_ptr(dq, (T, K), (HQK * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk, (T, K), (HQK * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) @@ -310,8 +307,8 @@ def chunk_kda_bwd_kernel_wy_dqkg_fused( b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) b_dA = tl.where(m_A, -b_dA, 0) - p_dA = tl.make_block_ptr(dA, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) # H = HV - p_db = tl.make_block_ptr(db, (T,), (H,), (i_t * BT,), (BT,), (0,)) # H = HV + p_dA = tl.make_block_ptr(dA, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_db = tl.make_block_ptr(db, (T,), (H,), (i_t * BT,), (BT,), (0,)) tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) @@ -327,10 +324,7 @@ def chunk_kda_bwd_dAv( chunk_size: int = 64, chunk_indices: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - # q/k are accepted for API compatibility but not forwarded to the kernel - # (they are unused in the dAv computation). - B, T, HV, V = v.shape[0], v.shape[1], v.shape[2], v.shape[3] - K = k.shape[-1] + B, T, H, K, V = *k.shape, do.shape[-1] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) @@ -345,11 +339,12 @@ def chunk_kda_bwd_dAv( BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - # dA and dv live in v-head space (HV), matching Aqk/v_new/do shapes. - dA = v.new_empty(B, T, HV, BT, dtype=torch.float) + dA = v.new_empty(B, T, H, BT, dtype=torch.float) dv = torch.empty_like(do) - grid = (NT, B * HV) + grid = (NT, B * H) chunk_kda_bwd_kernel_dAv[grid]( + q=q, + k=k, v=v, A=A, do=do, @@ -359,7 +354,7 @@ def chunk_kda_bwd_dAv( chunk_indices=chunk_indices, scale=scale, T=T, - H=HV, + H=H, K=K, V=V, BT=BT, @@ -387,16 +382,13 @@ def chunk_kda_bwd_wy_dqkg_fused( chunk_indices: torch.LongTensor | None = None, transpose_state_layout: bool = False, ): - B, T, HQK, K = k.shape - HV = v.shape[2] - V = v.shape[-1] + B, T, H, K, V = *k.shape, v.shape[-1] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - # dq/dk live in qk-head space (HQK); dv/dg/db/dA in v-head space (HV). dq = torch.empty_like(q, dtype=torch.float) dk = torch.empty_like(k, dtype=torch.float) dv2 = torch.empty_like(v) @@ -404,7 +396,7 @@ def chunk_kda_bwd_wy_dqkg_fused( db = torch.empty_like(beta, dtype=torch.float) dA = torch.empty_like(A, dtype=torch.float) - grid = (NT, B * HV) + grid = (NT, B * H) chunk_kda_bwd_kernel_wy_dqkg_fused[grid]( q=q, k=k, @@ -427,8 +419,7 @@ def chunk_kda_bwd_wy_dqkg_fused( chunk_indices=chunk_indices, scale=scale, T=T, - H=HV, - HQK=HQK, + H=H, K=K, V=V, BT=BT, @@ -466,8 +457,7 @@ def chunk_kda_bwd( ): assert transpose_state_layout is False, "transpose_state_layout=True is not supported for training." if disable_recompute is False: - B, T, HQK, K_dim = k.shape - HV = v.shape[2] + B, T, _, _ = k.shape if use_gate_in_kernel: g = kda_gate_chunk_cumsum( g=g_org, @@ -485,11 +475,10 @@ def chunk_kda_bwd( cu_seqlens = prepare_uniform_cu_seqlens(B, T, q.device, torch.int32) if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) - # In GVA mode, w/u/qg/kg live in v-head space (HV), not qk-head space (HQK). - w = torch.empty(B, T, HV, K_dim, device=k.device, dtype=k.dtype) + w = torch.empty_like(k) u = torch.empty_like(v) - qg = torch.empty(B, T, HV, K_dim, device=q.device, dtype=q.dtype) if q is not None else None - kg = torch.empty(B, T, HV, K_dim, device=k.device, dtype=k.dtype) if g is not None else None + qg = torch.empty_like(q) if q is not None else None + kg = torch.empty_like(k) if g is not None else None cula_cuda.recompute_w_u_cuda(k, v, beta, Akk, g, cu_seqlens, chunk_indices, w, u, kg, chunk_size, q, qg) if cp_context is not None: # Restore the full initial_state tensor from the compressed version. diff --git a/cula/kda/chunk_fwd.py b/cula/kda/chunk_fwd.py index 8ebb782..9fab235 100644 --- a/cula/kda/chunk_fwd.py +++ b/cula/kda/chunk_fwd.py @@ -126,17 +126,10 @@ def chunk_kda_fwd( # only the first state in the tensor is relevant. We compress it to optimize memory for `save_for_backward`. initial_state = compress_h0(initial_state, context=cp_context) - # GVA: if HQK < HV, broadcast q from HQK heads to HV heads so fwd_o sees - # consistent head dimensions. repeat_interleave mirrors the broadcast semantics - # already used inside chunk_kda_fwd_intra (each QK head paired with - # heads_per_group = HV // HQK consecutive V heads). - HQK, HV = q.shape[2], v.shape[2] - q_fwd_o = q.repeat_interleave(HV // HQK, dim=2) if HV != HQK else q - # Please ensure zeros, since vllm will use padding v o = torch.zeros_like(v) chunk_gla_fwd_o( - q=q_fwd_o, + q=q, v=v_new, g=g, A=Aqk, diff --git a/cula/utils.py b/cula/utils.py index 3b43fe0..bd70730 100644 --- a/cula/utils.py +++ b/cula/utils.py @@ -83,7 +83,7 @@ def assert_hopper(device: torch.device | str | int | None = None) -> None: def get_kda_fused_fwd(device: torch.device | str | int | None = None) -> Callable: """Return the appropriate ``kda_prefill`` implementation for *device*. - - sm100/sm103 (Blackwell) → NotImplementedError + - sm100/sm103 (Blackwell) → cula.kda.kda_prefill_blackwell (not yet available) - sm90 (Hopper) → cula.kda.kda_prefill_hopper Args: diff --git a/tests/test_kda.py b/tests/test_kda.py index 99b3ca1..fadadbf 100644 --- a/tests/test_kda.py +++ b/tests/test_kda.py @@ -25,12 +25,6 @@ from cula.kda import chunk_kda -# ─── helpers ────────────────────────────────────────────────────────────────── - -def _repeat_head(x: torch.Tensor, group_size: int) -> torch.Tensor: - """Replicate tensor along head dim (dim=2) by group_size, keeping contiguous layout.""" - return x.repeat_interleave(group_size, dim=2).contiguous() - pytestmark = pytest.mark.sm100_only @@ -287,184 +281,3 @@ def test_safe_gate_chunk_varlen( assert_close("dg", ref_dg, tri_dg, 0.015) assert_close("db", ref_db, tri_db, 0.015) assert_close("dh0", ref_dh0, tri_dh0, 0.007) - - -# ============================================================================= -# GVA (Grouped Value Attention) end-to-end tests -# ============================================================================= - -@pytest.mark.parametrize("disable_recompute", [True, False], ids=["no_recomp", "recomp"]) -@pytest.mark.parametrize("group_size", [2, 4], ids=["gs2", "gs4"]) -@pytest.mark.parametrize( - ("B", "T", "HQK", "D"), - [ - pytest.param(*cfg, id="B{}-T{}-HQK{}-D{}".format(*cfg)) - for cfg in [ - (1, 256, 2, 128), - (2, 512, 4, 128), - (1, 1000, 4, 128), # non-multiple-of-BT boundary stress - (2, 1024, 4, 128), - ] - ], -) -def test_chunk_kda_gva(B, T, HQK, D, group_size, disable_recompute): - """chunk_kda with native GVA (HQK < HV = HQK * group_size) must produce the - same forward outputs and the same v/g/beta/h0 gradients as running the - reference (FLA naive_recurrent_kda) with q/k expanded to HV heads. - - dq/dk gradients are also verified after summing over the group axis, - since the reference receives k replicated to HV and therefore accumulates - gradients across the group dimension. - """ - HV = HQK * group_size - torch.manual_seed(42) - - # ---- raw tensors -------------------------------------------------------- - q_raw = torch.randn(B, T, HQK, D, dtype=torch.bfloat16) - k_raw = torch.randn(B, T, HQK, D, dtype=torch.bfloat16) - v_raw = torch.randn(B, T, HV, D, dtype=torch.bfloat16) - # Gates must satisfy safe_gate: log-sigmoid values clamped to [-5, 0] - g_raw = F.logsigmoid(torch.randn(B, T, HV, D, dtype=torch.float)).clamp(-5.0, 0.0) - beta_raw = torch.randn(B, T, HV, dtype=torch.float32).sigmoid() - h0_raw = torch.randn(B, HV, D, D, dtype=torch.float32) - - # ---- reference (FLA naive_recurrent_kda with expanded q/k) -------------- - # Apply l2norm before expanding to make both paths numerically comparable. - q_norm = F.normalize(q_raw.float(), p=2, dim=-1).to(torch.bfloat16) - k_norm = F.normalize(k_raw.float(), p=2, dim=-1).to(torch.bfloat16) - - q_hv = _repeat_head(q_norm, group_size).to(device).requires_grad_(True) - k_hv = _repeat_head(k_norm, group_size).to(device).requires_grad_(True) - v_ref = v_raw.to(device).requires_grad_(True) - g_ref = g_raw.to(device).requires_grad_(True) - b_ref = beta_raw.to(device).requires_grad_(True) - h0_ref = h0_raw.to(device).requires_grad_(True) - - ref_o, ref_ht = naive_recurrent_kda( - q=q_hv, k=k_hv, v=v_ref, g=g_ref, beta=b_ref, - initial_state=h0_ref, output_final_state=True, - ) - do = torch.randn_like(ref_o) - dht = torch.randn_like(ref_ht) - ((ref_o * do).sum() + (ref_ht * dht).sum()).backward() - - ref_dq_hv = q_hv.grad # [B,T,HV,D] - ref_dk_hv = k_hv.grad # [B,T,HV,D] - ref_dv = v_ref.grad # [B,T,HV,D] - ref_dg = g_ref.grad # [B,T,HV,D] - ref_db = b_ref.grad # [B,T,HV] - ref_dh0 = h0_ref.grad # [B,HV,D,D] - - # Sum group contributions for dq/dk: [B,T,HV,D] → [B,T,HQK,D] - ref_dq = ref_dq_hv.view(B, T, HQK, group_size, D).sum(dim=3) - ref_dk = ref_dk_hv.view(B, T, HQK, group_size, D).sum(dim=3) - - # ---- cuLA chunk_kda with native GVA ------------------------------------ - q_c = q_norm.to(device).requires_grad_(True) - k_c = k_norm.to(device).requires_grad_(True) - v_c = v_raw.to(device).requires_grad_(True) - g_c = g_raw.to(device).requires_grad_(True) - b_c = beta_raw.to(device).requires_grad_(True) - h0_c = h0_raw.to(device).requires_grad_(True) - - tri_o, tri_ht = chunk_kda( - q=q_c, k=k_c, v=v_c, g=g_c, beta=b_c, - initial_state=h0_c, output_final_state=True, - safe_gate=True, lower_bound=-5.0, - disable_recompute=disable_recompute, - ) - ((tri_o * do).sum() + (tri_ht * dht).sum()).backward() - - # ---- compare ------------------------------------------------------------ - assert_close("o", ref_o, tri_o, 0.005) - assert_close("ht", ref_ht, tri_ht, 0.005) - assert_close("dq", ref_dq, q_c.grad, 0.01) - assert_close("dk", ref_dk, k_c.grad, 0.01) - assert_close("dv", ref_dv, v_c.grad, 0.008) - assert_close("dg", ref_dg, g_c.grad, 0.02) - assert_close("db", ref_db, b_c.grad, 0.02) - assert_close("dh0", ref_dh0, h0_c.grad, 0.008) - - -@pytest.mark.parametrize("disable_recompute", [True, False], ids=["no_recomp", "recomp"]) -@pytest.mark.parametrize("group_size", [2], ids=["gs2"]) -@pytest.mark.parametrize( - ("HQK", "D", "cu_seqlens"), - [ - pytest.param(*cfg, id="HQK{}-D{}-ns{}".format(cfg[0], cfg[1], len(cfg[2]) - 1)) - for cfg in [ - (2, 128, [0, 256, 500, 1000]), - (4, 128, [0, 100, 300, 1200, 2000]), - (4, 128, [0, 15, 100, 300, 1200, 3000, 4096]), - ] - ], -) -def test_chunk_kda_gva_varlen(HQK, D, cu_seqlens, group_size, disable_recompute): - """GVA chunk_kda correctness under variable-length (packed) inputs.""" - HV = HQK * group_size - torch.manual_seed(42) - - cu_seqlens_t = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - cu_seqlens_cpu = cu_seqlens_t.cpu() - T = int(cu_seqlens_t[-1].item()) - N = len(cu_seqlens) - 1 - - q_raw = torch.randn(1, T, HQK, D, dtype=torch.bfloat16) - k_raw = torch.randn(1, T, HQK, D, dtype=torch.bfloat16) - v_raw = torch.randn(1, T, HV, D, dtype=torch.bfloat16) - g_raw = F.logsigmoid(torch.randn(1, T, HV, D, dtype=torch.float)).clamp(-5.0, 0.0) - beta_raw = torch.randn(1, T, HV, dtype=torch.float32).sigmoid() - h0_raw = torch.randn(N, HV, D, D, dtype=torch.float32) - - q_norm = F.normalize(q_raw.float(), p=2, dim=-1).to(torch.bfloat16) - k_norm = F.normalize(k_raw.float(), p=2, dim=-1).to(torch.bfloat16) - - # ---- reference ---------------------------------------------------------- - q_hv = _repeat_head(q_norm, group_size).to(device).requires_grad_(True) - k_hv = _repeat_head(k_norm, group_size).to(device).requires_grad_(True) - v_ref = v_raw.to(device).requires_grad_(True) - g_ref = g_raw.to(device).requires_grad_(True) - b_ref = beta_raw.to(device).requires_grad_(True) - h0_ref = h0_raw.to(device).requires_grad_(True) - - ref_o, ref_ht = chunk_kda( - q=q_hv, k=k_hv, v=v_ref, g=g_ref, beta=b_ref, - initial_state=h0_ref, output_final_state=True, - cu_seqlens=cu_seqlens_t, cu_seqlens_cpu=cu_seqlens_cpu, - safe_gate=True, lower_bound=-5.0, - disable_recompute=disable_recompute, - ) - do = torch.randn_like(ref_o) - dht = torch.randn_like(ref_ht) - ((ref_o * do).sum() + (ref_ht * dht).sum()).backward() - - ref_dq = q_hv.grad.view(1, T, HQK, group_size, D).sum(dim=3) - ref_dk = k_hv.grad.view(1, T, HQK, group_size, D).sum(dim=3) - ref_dv, ref_dg, ref_db, ref_dh0 = v_ref.grad, g_ref.grad, b_ref.grad, h0_ref.grad - - # ---- cuLA native GVA ---------------------------------------------------- - q_c = q_norm.to(device).requires_grad_(True) - k_c = k_norm.to(device).requires_grad_(True) - v_c = v_raw.to(device).requires_grad_(True) - g_c = g_raw.to(device).requires_grad_(True) - b_c = beta_raw.to(device).requires_grad_(True) - h0_c = h0_raw.to(device).requires_grad_(True) - - tri_o, tri_ht = chunk_kda( - q=q_c, k=k_c, v=v_c, g=g_c, beta=b_c, - initial_state=h0_c, output_final_state=True, - cu_seqlens=cu_seqlens_t, cu_seqlens_cpu=cu_seqlens_cpu, - safe_gate=True, lower_bound=-5.0, - disable_recompute=disable_recompute, - ) - ((tri_o * do).sum() + (tri_ht * dht).sum()).backward() - - # ---- compare ------------------------------------------------------------ - assert_close("o", ref_o, tri_o, 0.005) - assert_close("ht", ref_ht, tri_ht, 0.005) - assert_close("dq", ref_dq, q_c.grad, 0.01) - assert_close("dk", ref_dk, k_c.grad, 0.01) - assert_close("dv", ref_dv, v_c.grad, 0.008) - assert_close("dg", ref_dg, g_c.grad, 0.02) - assert_close("db", ref_db, b_c.grad, 0.02) - assert_close("dh0", ref_dh0, h0_c.grad, 0.008) From c08a33dee7f42bb61da6c930ae617708f272fdfb Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Tue, 19 May 2026 20:09:03 +0800 Subject: [PATCH 10/14] benchmark and test --- cula/kda/chunk_intra.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/cula/kda/chunk_intra.py b/cula/kda/chunk_intra.py index a9c620a..6f32ae4 100644 --- a/cula/kda/chunk_intra.py +++ b/cula/kda/chunk_intra.py @@ -759,7 +759,9 @@ def chunk_kda_fwd_intra( unified_gref: bool = False, # Set True for ~5% extra perf (slightly lower precision) ): assert safe_gate, "Only safe_gate=True is supported in chunk_kda_fwd_intra for now" - B, T, H, K = k.shape + B, T, H_QK, K = k.shape + # GVA: g/beta/v live in h_v head space; q/k live in h_qk head space. + H_V = v.size(2) BT = chunk_size if cu_seqlens is None: @@ -773,18 +775,20 @@ def chunk_kda_fwd_intra( "cu_seqlens and chunk_indices must be int32 for cuda impl" ) - Aqk = torch.empty(B, T, H, BT, device=k.device, dtype=k.dtype) - Akk = torch.empty(B, T, H, BT, device=k.device, dtype=k.dtype) + # Aqk and Akk are produced per v-head by the intra kernel. + Aqk = torch.empty(B, T, H_V, BT, device=k.device, dtype=k.dtype) + Akk = torch.empty(B, T, H_V, BT, device=k.device, dtype=k.dtype) tile_counter = torch.zeros(1, dtype=torch.int32, device=q.device) cula_cuda.chunk_kda_fwd_intra_cuda( q, k, gk, beta, cu_seqlens, chunk_indices, Aqk, Akk, tile_counter, scale, chunk_size, use_tf32_inverse, unified_gref ) - w = torch.empty_like(k) + # w, u, kg, qg all live in h_v head space. + w = torch.empty_like(v) u = torch.empty_like(v) - qg = torch.empty_like(q) if disable_recompute else None - kg = torch.empty_like(k) if gk is not None else None + qg = torch.empty(B, T, H_V, K, device=q.device, dtype=q.dtype) if disable_recompute else None + kg = torch.empty(B, T, H_V, K, device=k.device, dtype=k.dtype) if gk is not None else None cula_cuda.recompute_w_u_cuda( k, v, beta, Akk, gk, cu_seqlens, chunk_indices, w, u, kg, chunk_size, q if disable_recompute else None, qg From 1a0eda85fe4d78bf9c298c29c8a51062b540f0dc Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Tue, 19 May 2026 21:06:38 +0800 Subject: [PATCH 11/14] benchmark and test --- cula/kda/chunk_intra.py | 3 + tests/test_kda_gva_intra_sm100.py | 103 +++++++++++++++++------------- 2 files changed, 60 insertions(+), 46 deletions(-) diff --git a/cula/kda/chunk_intra.py b/cula/kda/chunk_intra.py index 6f32ae4..aeb063f 100644 --- a/cula/kda/chunk_intra.py +++ b/cula/kda/chunk_intra.py @@ -762,6 +762,9 @@ def chunk_kda_fwd_intra( B, T, H_QK, K = k.shape # GVA: g/beta/v live in h_v head space; q/k live in h_qk head space. H_V = v.size(2) + assert H_QK > 0 and H_V > 0 and H_V % H_QK == 0, ( + f"HV ({H_V}) must be a positive multiple of HQK ({H_QK})" + ) BT = chunk_size if cu_seqlens is None: diff --git a/tests/test_kda_gva_intra_sm100.py b/tests/test_kda_gva_intra_sm100.py index a86e56f..8082946 100644 --- a/tests/test_kda_gva_intra_sm100.py +++ b/tests/test_kda_gva_intra_sm100.py @@ -30,6 +30,8 @@ import pytest import torch +from einops import rearrange +from fla.modules.l2norm import l2norm_fwd from fla.ops.kda.chunk_intra import chunk_kda_fwd_intra as fla_chunk_kda_fwd_intra from fla.ops.kda.gate import kda_gate_chunk_cumsum from fla.ops.utils.constant import RCP_LN2 @@ -46,10 +48,6 @@ # Helpers # ========================================================================= -def _l2norm_last(x: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.normalize(x.float(), p=2, dim=-1).to(x.dtype) - - def _repeat_head(x: torch.Tensor, group_size: int, head_dim: int = 2) -> torch.Tensor: """Replicate ``x`` along the head axis by ``group_size``. @@ -95,8 +93,15 @@ def _make_gva_inputs( beta = torch.randn(B, T, HV, dtype=torch.float, device=device).sigmoid() # l2-normalise q/k so that scale/gate ranges match production use. - q = _l2norm_last(q) - k = _l2norm_last(k) + q, _ = l2norm_fwd(q) + k, _ = l2norm_fwd(k) + + # FLA gate cumsum only supports packed batch (B=1) when cu_seqlens is set. + if B != 1: + q, k, v, g_raw, beta = map( + lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), + (q, k, v, g_raw, beta), + ) # Per-HV gate preprocessing (cumsum inside chunks). A_log = torch.randn(HV, dtype=torch.float, device=device) @@ -157,6 +162,36 @@ def _run_cula_gva(q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size ) +def _assert_intra_outputs_match(ref, tri, disable_recompute: bool) -> None: + """Compare cuLA vs FLA on user-visible intra outputs. + + We intentionally skip ``Aqk``: the cuLA SM100 fused kernel does not + materialise every off-diagonal slot that FLA's multi-kernel path writes, + and the FLA reference can contain NaNs in unused ``Aqk`` entries. The + downstream tensors ``w`` / ``u`` / ``kg`` (and ``Akk``) are the meaningful + correctness signals and match the benchmark's comparison strategy. + """ + w_r, u_r, qg_r, kg_r, _Aqk_r, Akk_r = ref + w_c, u_c, qg_c, kg_c, _Aqk_c, Akk_c = tri + + assert Akk_c.shape == Akk_r.shape, (Akk_c.shape, Akk_r.shape) + assert w_c.shape == w_r.shape, (w_c.shape, w_r.shape) + assert u_c.shape == u_r.shape, (u_c.shape, u_r.shape) + assert kg_c.shape == kg_r.shape, (kg_c.shape, kg_r.shape) + + assert_close("Akk", Akk_r, Akk_c, 0.008) + assert_close("w", w_r, w_c, 0.008) + assert_close("u", u_r, u_c, 0.008) + assert_close("kg", kg_r, kg_c, 0.005) + + if disable_recompute: + assert qg_c is not None and qg_r is not None + assert qg_c.shape == qg_r.shape, (qg_c.shape, qg_r.shape) + assert_close("qg", qg_r, qg_c, 0.005) + else: + assert qg_c is None, "cuLA must not materialise qg when disable_recompute=False" + + # ========================================================================= # Uniform-length tests # ========================================================================= @@ -199,28 +234,11 @@ def test_gva_intra_uniform(B, T, HQK, group_size, D, disable_recompute): q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, group_size, disable_recompute, ) - # All outputs live in HV head space → shapes must match directly. - assert Aqk_c.shape == Aqk_r.shape, (Aqk_c.shape, Aqk_r.shape) - assert Akk_c.shape == Akk_r.shape, (Akk_c.shape, Akk_r.shape) - assert w_c.shape == w_r.shape, (w_c.shape, w_r.shape) - assert u_c.shape == u_r.shape, (u_c.shape, u_r.shape) - assert kg_c.shape == kg_r.shape, (kg_c.shape, kg_r.shape) - - # Aqk / Akk are the core A-matrices; they drive w/u, so keep tolerances tight. - assert_close("Aqk", Aqk_r, Aqk_c, 0.005) - assert_close("Akk", Akk_r, Akk_c, 0.008) - - # recompute_w_u outputs - assert_close("w", w_r, w_c, 0.008) - assert_close("u", u_r, u_c, 0.008) - assert_close("kg", kg_r, kg_c, 0.005) - - if disable_recompute: - assert qg_c is not None and qg_r is not None - assert qg_c.shape == qg_r.shape, (qg_c.shape, qg_r.shape) - assert_close("qg", qg_r, qg_c, 0.005) - else: - assert qg_c is None, "cuLA must not materialise qg when disable_recompute=False" + _assert_intra_outputs_match( + (w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r), + (w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c), + disable_recompute, + ) # ========================================================================= @@ -263,16 +281,11 @@ def test_gva_intra_varlen(HQK, group_size, D, cu_seqlens, disable_recompute): q, k, v, g, beta, scale, cu_seqlens_t, chunk_indices, chunk_size, group_size, disable_recompute, ) - assert_close("Aqk", Aqk_r, Aqk_c, 0.005) - assert_close("Akk", Akk_r, Akk_c, 0.008) - assert_close("w", w_r, w_c, 0.008) - assert_close("u", u_r, u_c, 0.008) - assert_close("kg", kg_r, kg_c, 0.005) - - if disable_recompute: - assert_close("qg", qg_r, qg_c, 0.005) - else: - assert qg_c is None + _assert_intra_outputs_match( + (w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r), + (w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c), + disable_recompute, + ) # ========================================================================= @@ -314,13 +327,11 @@ def test_gva_intra_degenerate_equals_non_gva(B, T, H, D, disable_recompute): safe_gate=True, disable_recompute=disable_recompute, ) - assert_close("Aqk", Aqk_r, Aqk_c, 0.005) - assert_close("Akk", Akk_r, Akk_c, 0.008) - assert_close("w", w_r, w_c, 0.008) - assert_close("u", u_r, u_c, 0.008) - assert_close("kg", kg_r, kg_c, 0.005) - if disable_recompute: - assert_close("qg", qg_r, qg_c, 0.005) + _assert_intra_outputs_match( + (w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r), + (w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c), + disable_recompute, + ) # ========================================================================= @@ -367,7 +378,7 @@ def test_gva_intra_rejects_non_multiple_ratio(): g = torch.randn(B, T, HV, D, dtype=torch.float, device=device) beta = torch.randn(B, T, HV, dtype=torch.float, device=device).sigmoid() - with pytest.raises(AssertionError): + with pytest.raises((AssertionError, RuntimeError), match=r"multiple|h_v"): cula_chunk_kda_fwd_intra( q=q, k=k, v=v, gk=g, beta=beta, scale=D ** -0.5, cu_seqlens=cu_seqlens, chunk_size=chunk_size, From feebcfd05a04bf0bf0eaed14076a7c9b0c035d5d Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Thu, 21 May 2026 23:03:54 +0800 Subject: [PATCH 12/14] benchmark --- benchmarks/bench_kda_chunk_intra.py | 362 +++++++--------------------- 1 file changed, 89 insertions(+), 273 deletions(-) diff --git a/benchmarks/bench_kda_chunk_intra.py b/benchmarks/bench_kda_chunk_intra.py index 466dd70..61b55ad 100644 --- a/benchmarks/bench_kda_chunk_intra.py +++ b/benchmarks/bench_kda_chunk_intra.py @@ -12,6 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +bench_kda_chunk_intra.py — Benchmark: cuLA vs FLA Triton for chunk_kda_fwd_intra + +Supports both standard (HV=H) and GVA (HV > H) modes via --hv / --heads flags. +In GVA mode the FLA reference replicates q/k to HV heads; cuLA operates natively +with compact q/k in HQK space. + +Usage: + python bench_kda_chunk_intra.py [--heads H] [--hv HV] [--disable_recompute] +""" + import argparse import os import pathlib @@ -30,6 +41,7 @@ # Constant params B, H, D = 2, 64, 128 +HV = H # overridable via --hv; HV > H enables GVA mode BT = 64 # chunk size # Varlen benchmark params @@ -39,7 +51,6 @@ VARIANCE = 1.0 DISABLE_RECOMPUTE = False # Whether to disable recompute (compute QG in forward) -GROUP_SIZE = 1 # GVA group size: HV = GROUP_SIZE * H. 1 means no GVA. def accuracy_stats(a, b): @@ -55,217 +66,21 @@ def accuracy_stats(a, b): # ============================================================================== -# Uniform seqlen benchmark +# Unified uniform seqlen benchmark (handles both standard and GVA) # ============================================================================== def benchmark_chunk_intra_uniform(): - device = torch.device("cuda") - chunk_size = BT - T_vals = [512, 1024, 4096, 8192, 16384, 32768] - - print("=" * 90) - print( - f" Uniform-Length ChunkIntra Benchmark: cuLA vs FLA Triton B={B} H={H} D={D} disable_recompute={DISABLE_RECOMPUTE}" - ) - print("=" * 90) - print( - f"{'B':>4} {'T':>7} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" - ) - print("─" * 90) - - for T in T_vals: - seq_lens = [T] * B - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens) - - # Accuracy: run once and compare - out_fla = fla_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ) - out_cula = cula_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ) - # Compare the first output tensor (o) - o_fla = out_fla[0] if isinstance(out_fla, (tuple, list)) else out_fla - o_cula = out_cula[0] if isinstance(out_cula, (tuple, list)) else out_cula - rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) - - # Performance - ms_fla = triton.testing.do_bench( - lambda: fla_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ), - ) - ms_cula = triton.testing.do_bench( - lambda: cula_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ), - ) - speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - - print( - f"{B:>4} {T:>7} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" - ) - - print("─" * 90) - - -# ============================================================================== -# Varlen benchmark -# ============================================================================== -def benchmark_chunk_intra_varlen(): - device = torch.device("cuda") - chunk_size = BT - total_len_vals = [8192, 16384, 32768, 65536] - - print() - print("=" * 100) - print( - f" Varlen ChunkIntra Benchmark: cuLA vs FLA Triton NUM_SEQS={NUM_SEQS} H={H} D={D} disable_recompute={DISABLE_RECOMPUTE}" - ) - print("=" * 100) - print( - f"{'total_len':>10} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" - ) - print("─" * 100) - - for total_len in total_len_vals: - seq_lens = generate_random_seq_lens(NUM_SEQS, total_len, MIN_SEQ_LEN, VARIANCE, SEED) - T = total_len - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens) - - # Accuracy - out_fla = fla_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ) - out_cula = cula_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ) - o_fla = out_fla[0] if isinstance(out_fla, (tuple, list)) else out_fla - o_cula = out_cula[0] if isinstance(out_cula, (tuple, list)) else out_cula - rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) - - # Performance - ms_fla = triton.testing.do_bench( - lambda: fla_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ), - ) - ms_cula = triton.testing.do_bench( - lambda: cula_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ), - ) - speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - - print( - f"{total_len:>10} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" - ) - - print("─" * 100) - - -# ============================================================================== -# GVA uniform seqlen benchmark -# ============================================================================== -def benchmark_chunk_intra_gva_uniform(group_size: int): - """Benchmark GVA (HV > HQK) intra chunk: cuLA vs FLA Triton (k replicated to HV). - - FLA does not natively support GVA, so the reference replicates k along the - head axis to HV before calling the kernel (same strategy as in the unit tests). - """ device = torch.device("cuda") chunk_size = BT HQK = H - HV = HQK * group_size + gva_mode = HV > HQK + group_size = HV // HQK T_vals = [512, 1024, 4096, 8192, 16384, 32768] + gva_note = f"HQK={HQK} HV={HV} (group_size={group_size})" if gva_mode else f"H={HQK}" print("=" * 100) print( - f" GVA Uniform ChunkIntra Benchmark: cuLA vs FLA Triton " - f"B={B} HQK={HQK} HV={HV} (group_size={group_size}) D={D} disable_recompute={DISABLE_RECOMPUTE}" + f" Uniform-Length ChunkIntra Benchmark: cuLA vs FLA Triton " + f"B={B} {gva_note} D={D} disable_recompute={DISABLE_RECOMPUTE}" ) print("=" * 100) print( @@ -277,45 +92,39 @@ def benchmark_chunk_intra_gva_uniform(group_size: int): seq_lens = [T] * B cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs_gva( - B, T, HQK, HV, D, device, cu_seqlens=cu_seqlens - ) - - # FLA reference: replicate k/q to HV heads - k_hv = k.repeat_interleave(group_size, dim=2).contiguous() - q_hv = q.repeat_interleave(group_size, dim=2).contiguous() - - # Accuracy: run once and compare - out_fla = fla_chunk_kda_fwd_intra( - q=q_hv, k=k_hv, v=v, gk=g, beta=beta, scale=scale, + if gva_mode: + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs_gva( + B, T, HQK, HV, D, device, cu_seqlens=cu_seqlens + ) + k_ref = k.repeat_interleave(group_size, dim=2).contiguous() + q_ref = q.repeat_interleave(group_size, dim=2).contiguous() + else: + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( + B, T, H, D, device, cu_seqlens=cu_seqlens + ) + q_ref, k_ref = q, k + + common_fla = dict( + q=q_ref, k=k_ref, v=v, gk=g, beta=beta, scale=scale, cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, ) - out_cula = cula_chunk_kda_fwd_intra( + common_cula = dict( q=q, k=k, v=v, gk=g, beta=beta, scale=scale, cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, ) - # Compare first output (w) + + # Accuracy: run once and compare + out_fla = fla_chunk_kda_fwd_intra(**common_fla) + out_cula = cula_chunk_kda_fwd_intra(**common_cula) o_fla = out_fla[0] if isinstance(out_fla, (tuple, list)) else out_fla o_cula = out_cula[0] if isinstance(out_cula, (tuple, list)) else out_cula rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) # Performance - ms_fla = triton.testing.do_bench( - lambda: fla_chunk_kda_fwd_intra( - q=q_hv, k=k_hv, v=v, gk=g, beta=beta, scale=scale, - cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, - safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, - ), - ) - ms_cula = triton.testing.do_bench( - lambda: cula_chunk_kda_fwd_intra( - q=q, k=k, v=v, gk=g, beta=beta, scale=scale, - cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, - safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, - ), - ) + ms_fla = triton.testing.do_bench(lambda: fla_chunk_kda_fwd_intra(**common_fla)) + ms_cula = triton.testing.do_bench(lambda: cula_chunk_kda_fwd_intra(**common_cula)) speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") print( @@ -326,21 +135,22 @@ def benchmark_chunk_intra_gva_uniform(group_size: int): # ============================================================================== -# GVA varlen benchmark +# Unified varlen benchmark (handles both standard and GVA) # ============================================================================== -def benchmark_chunk_intra_gva_varlen(group_size: int): - """Varlen GVA benchmark: cuLA vs FLA Triton (k replicated to HV).""" +def benchmark_chunk_intra_varlen(): device = torch.device("cuda") chunk_size = BT HQK = H - HV = HQK * group_size + gva_mode = HV > HQK + group_size = HV // HQK total_len_vals = [8192, 16384, 32768, 65536] + gva_note = f"HQK={HQK} HV={HV} (group_size={group_size})" if gva_mode else f"H={HQK}" print() print("=" * 110) print( - f" GVA Varlen ChunkIntra Benchmark: cuLA vs FLA Triton " - f"NUM_SEQS={NUM_SEQS} HQK={HQK} HV={HV} (group_size={group_size}) D={D} disable_recompute={DISABLE_RECOMPUTE}" + f" Varlen ChunkIntra Benchmark: cuLA vs FLA Triton " + f"NUM_SEQS={NUM_SEQS} {gva_note} D={D} disable_recompute={DISABLE_RECOMPUTE}" ) print("=" * 110) print( @@ -353,43 +163,39 @@ def benchmark_chunk_intra_gva_varlen(group_size: int): T = total_len cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs_gva( - 1, T, HQK, HV, D, device, cu_seqlens=cu_seqlens - ) - - k_hv = k.repeat_interleave(group_size, dim=2).contiguous() - q_hv = q.repeat_interleave(group_size, dim=2).contiguous() - - # Accuracy - out_fla = fla_chunk_kda_fwd_intra( - q=q_hv, k=k_hv, v=v, gk=g, beta=beta, scale=scale, + if gva_mode: + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs_gva( + 1, T, HQK, HV, D, device, cu_seqlens=cu_seqlens + ) + k_ref = k.repeat_interleave(group_size, dim=2).contiguous() + q_ref = q.repeat_interleave(group_size, dim=2).contiguous() + else: + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( + 1, T, H, D, device, cu_seqlens=cu_seqlens + ) + q_ref, k_ref = q, k + + common_fla = dict( + q=q_ref, k=k_ref, v=v, gk=g, beta=beta, scale=scale, cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, ) - out_cula = cula_chunk_kda_fwd_intra( + common_cula = dict( q=q, k=k, v=v, gk=g, beta=beta, scale=scale, cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, ) + + # Accuracy + out_fla = fla_chunk_kda_fwd_intra(**common_fla) + out_cula = cula_chunk_kda_fwd_intra(**common_cula) o_fla = out_fla[0] if isinstance(out_fla, (tuple, list)) else out_fla o_cula = out_cula[0] if isinstance(out_cula, (tuple, list)) else out_cula rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) # Performance - ms_fla = triton.testing.do_bench( - lambda: fla_chunk_kda_fwd_intra( - q=q_hv, k=k_hv, v=v, gk=g, beta=beta, scale=scale, - cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, - safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, - ), - ) - ms_cula = triton.testing.do_bench( - lambda: cula_chunk_kda_fwd_intra( - q=q, k=k, v=v, gk=g, beta=beta, scale=scale, - cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, - safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, - ), - ) + ms_fla = triton.testing.do_bench(lambda: fla_chunk_kda_fwd_intra(**common_fla)) + ms_cula = triton.testing.do_bench(lambda: cula_chunk_kda_fwd_intra(**common_cula)) speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") print( @@ -407,11 +213,16 @@ def benchmark_chunk_intra_gva_varlen(group_size: int): help="Disable recompute in both FLA and cuLA (pre-compute QG)", ) parser.add_argument( - "--group_size", + "--heads", + type=int, + default=None, + help=f"Override number of QK heads (H). Default: {H}.", + ) + parser.add_argument( + "--hv", type=int, - default=1, - help="GVA group size: HV = group_size * H. 1 (default) runs the non-GVA benchmark. " - "Values > 1 run GVA benchmarks comparing cuLA (k in HQK space) vs FLA (k replicated to HV).", + default=None, + help=f"Override number of V heads (HV). Default: H ({H}, no GVA). Set HV > H to run in GVA mode.", ) args = parser.parse_args() @@ -419,12 +230,17 @@ def benchmark_chunk_intra_gva_varlen(group_size: int): DISABLE_RECOMPUTE = True print("[Disable recompute] pre-compute QG in forward") - GROUP_SIZE = args.group_size + if args.heads is not None: + H = args.heads + HV = H # reset HV to new H unless --hv is also provided + + if args.hv is not None: + if args.hv < H or args.hv % H != 0: + raise ValueError(f"--hv must be a positive multiple of H ({H}), got {args.hv}") + HV = args.hv + + if HV > H: + print(f"[GVA] HV={HV} (H={H}, group_size={HV // H}x)") - if GROUP_SIZE == 1: - benchmark_chunk_intra_uniform() - benchmark_chunk_intra_varlen() - else: - assert H % 1 == 0, "H must be divisible by group_size" - benchmark_chunk_intra_gva_uniform(GROUP_SIZE) - benchmark_chunk_intra_gva_varlen(GROUP_SIZE) + benchmark_chunk_intra_uniform() + benchmark_chunk_intra_varlen() From 853e0d2ab7d79fa2b4d31d70d8bfdb994d8cae6a Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Thu, 21 May 2026 23:13:18 +0800 Subject: [PATCH 13/14] benchmark --- benchmarks/bench_kda_chunk_intra.py | 36 +++++--------- benchmarks/bench_recompute_wu.py | 6 +-- benchmarks/utils.py | 75 ++++++++--------------------- 3 files changed, 35 insertions(+), 82 deletions(-) diff --git a/benchmarks/bench_kda_chunk_intra.py b/benchmarks/bench_kda_chunk_intra.py index 61b55ad..0a51ab9 100644 --- a/benchmarks/bench_kda_chunk_intra.py +++ b/benchmarks/bench_kda_chunk_intra.py @@ -36,7 +36,7 @@ from fla.ops.kda.chunk_intra import chunk_kda_fwd_intra as fla_chunk_kda_fwd_intra -from benchmarks.utils import SEED, exclusive_cumsum, generate_random_seq_lens, prepare_intra_inputs, prepare_intra_inputs_gva +from benchmarks.utils import SEED, exclusive_cumsum, generate_random_seq_lens, prepare_intra_inputs from cula.kda.chunk_intra import chunk_kda_fwd_intra as cula_chunk_kda_fwd_intra # Constant params @@ -92,17 +92,12 @@ def benchmark_chunk_intra_uniform(): seq_lens = [T] * B cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - if gva_mode: - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs_gva( - B, T, HQK, HV, D, device, cu_seqlens=cu_seqlens - ) - k_ref = k.repeat_interleave(group_size, dim=2).contiguous() - q_ref = q.repeat_interleave(group_size, dim=2).contiguous() - else: - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( - B, T, H, D, device, cu_seqlens=cu_seqlens - ) - q_ref, k_ref = q, k + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( + B, T, HQK, D, device, cu_seqlens=cu_seqlens, num_v_heads=HV + ) + # FLA reference: replicate q/k to HV heads when in GVA mode + q_ref = q.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else q + k_ref = k.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else k common_fla = dict( q=q_ref, k=k_ref, v=v, gk=g, beta=beta, scale=scale, @@ -163,17 +158,12 @@ def benchmark_chunk_intra_varlen(): T = total_len cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - if gva_mode: - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs_gva( - 1, T, HQK, HV, D, device, cu_seqlens=cu_seqlens - ) - k_ref = k.repeat_interleave(group_size, dim=2).contiguous() - q_ref = q.repeat_interleave(group_size, dim=2).contiguous() - else: - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( - 1, T, H, D, device, cu_seqlens=cu_seqlens - ) - q_ref, k_ref = q, k + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( + 1, T, HQK, D, device, cu_seqlens=cu_seqlens, num_v_heads=HV + ) + # FLA reference: replicate q/k to HV heads when in GVA mode + q_ref = q.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else q + k_ref = k.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else k common_fla = dict( q=q_ref, k=k_ref, v=v, gk=g, beta=beta, scale=scale, diff --git a/benchmarks/bench_recompute_wu.py b/benchmarks/bench_recompute_wu.py index c713a15..0b5a3e1 100644 --- a/benchmarks/bench_recompute_wu.py +++ b/benchmarks/bench_recompute_wu.py @@ -27,7 +27,7 @@ from fla.ops.kda.wy_fast import recompute_w_u_fwd as fla_recompute_w_u_fwd import cula.cudac as cula_cuda -from benchmarks.utils import SEED, exclusive_cumsum, generate_random_seq_lens, prepare_intra_inputs, prepare_intra_inputs_gva +from benchmarks.utils import SEED, exclusive_cumsum, generate_random_seq_lens, prepare_intra_inputs from cula.kda.chunk_intra import chunk_kda_fwd_intra as cula_chunk_kda_fwd_intra # Constant params @@ -122,8 +122,8 @@ def prepare_recompute_wu_inputs_gva(B, T, HQK, HV, D, device, cu_seqlens=None, c HV-head space (shape [B, T, HV, BT]). Both FLA (k replicated to HV) and cuLA (k compact in HQK) receive the same Akk. """ - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs_gva( - B, T, HQK, HV, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( + B, T, HQK, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size, num_v_heads=HV ) # Use cuLA GVA intra to produce Akk in HV space. diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 05198ab..12d4a34 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -324,82 +324,45 @@ def prepare_safe_gate_inputs( ) -def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED): +def prepare_intra_inputs( + batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED, num_v_heads=None +): """Prepare preprocessed inputs ready for chunk_kda_fwd_intra. - All tensors are flattened to (1, B*T, ...) for cu_seqlens compatibility. - """ - dtype = torch.bfloat16 - scale = D ** (-0.5) - - set_seed(seed) - - q = torch.randn(batch_size, T, H, D, dtype=dtype, device=device) - k = torch.randn(batch_size, T, H, D, dtype=dtype, device=device) - v = torch.randn(batch_size, T, H, D, dtype=dtype, device=device) - g_raw = torch.randn(batch_size, T, H, D, dtype=dtype, device=device) - beta = torch.randn(batch_size, T, H, dtype=torch.float, device=device).sigmoid() - - # l2norm q, k - q, _ = l2norm_fwd(q) - k, _ = l2norm_fwd(k) + Supports both standard (HV=H) and GVA (HV > H) layouts via ``num_v_heads``: - # flatten to batch_size=1 for cu_seqlens compatibility - if batch_size != 1: - q, k, v, g_raw, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g_raw, beta)) + q, k : (batch_size_flat, T, H, D) — Q/K head space (always compact) + v : (batch_size_flat, T, HV, D) — V head space + g : (batch_size_flat, T, HV, D) — gate in V head space (after cumsum) + beta : (batch_size_flat, T, HV) — beta in V head space - # gate preprocessing - A_log = torch.randn(H, dtype=torch.float, device=device) - dt_bias = torch.randn(H * D, dtype=torch.float, device=device) - - chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None - - g = kda_gate_chunk_cumsum( - g=g_raw, - A_log=A_log, - dt_bias=dt_bias, - scale=RCP_LN2, - chunk_size=chunk_size, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - lower_bound=-5.0, - ) - - return q, k, v, g, beta, scale, cu_seqlens, chunk_indices - - -def prepare_intra_inputs_gva( - batch_size, T, HQK, HV, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED -): - """Prepare preprocessed inputs for chunk_kda_fwd_intra with GVA (HV >= HQK). - - GVA layout: - q, k : (batch_size_flat, T, HQK, D) — Q/K head space - v : (batch_size_flat, T, HV, D) — V head space - g : (batch_size_flat, T, HV, D) — gate in V head space (after cumsum) - beta : (batch_size_flat, T, HV) — beta in V head space - - When HV == HQK (group_size == 1) this is identical to prepare_intra_inputs. - All tensors are flattened to batch_size=1 for cu_seqlens compatibility. + When ``num_v_heads`` is None or equal to H this matches the original non-GVA + behaviour exactly. All tensors are flattened to batch_size=1 for cu_seqlens + compatibility. """ - assert HV > 0 and HQK > 0 and HV % HQK == 0, f"HV ({HV}) must be a positive multiple of HQK ({HQK})" + HV = H if num_v_heads is None else num_v_heads + assert HV >= H and HV % H == 0, f"num_v_heads ({HV}) must be a positive multiple of H ({H})" + dtype = torch.bfloat16 scale = D ** (-0.5) set_seed(seed) - q = torch.randn(batch_size, T, HQK, D, dtype=dtype, device=device) - k = torch.randn(batch_size, T, HQK, D, dtype=dtype, device=device) + q = torch.randn(batch_size, T, H, D, dtype=dtype, device=device) + k = torch.randn(batch_size, T, H, D, dtype=dtype, device=device) v = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device) g_raw = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device) beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid() + # l2norm q, k q, _ = l2norm_fwd(q) k, _ = l2norm_fwd(k) + # flatten to batch_size=1 for cu_seqlens compatibility if batch_size != 1: q, k, v, g_raw, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g_raw, beta)) + # gate preprocessing — A_log / dt_bias live in HV head space A_log = torch.randn(HV, dtype=torch.float, device=device) dt_bias = torch.randn(HV * D, dtype=torch.float, device=device) From e46b178b71c5e6249a8800e0bf70b19c218abe6c Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Thu, 21 May 2026 23:18:32 +0800 Subject: [PATCH 14/14] benchmark --- benchmarks/bench_kda_chunk_intra.py | 12 +- benchmarks/bench_recompute_wu.py | 345 ++++++---------------------- 2 files changed, 66 insertions(+), 291 deletions(-) diff --git a/benchmarks/bench_kda_chunk_intra.py b/benchmarks/bench_kda_chunk_intra.py index 0a51ab9..c65c5f3 100644 --- a/benchmarks/bench_kda_chunk_intra.py +++ b/benchmarks/bench_kda_chunk_intra.py @@ -20,7 +20,7 @@ with compact q/k in HQK space. Usage: - python bench_kda_chunk_intra.py [--heads H] [--hv HV] [--disable_recompute] + python bench_kda_chunk_intra.py [--hv HV] [--disable_recompute] """ import argparse @@ -202,12 +202,6 @@ def benchmark_chunk_intra_varlen(): action="store_true", help="Disable recompute in both FLA and cuLA (pre-compute QG)", ) - parser.add_argument( - "--heads", - type=int, - default=None, - help=f"Override number of QK heads (H). Default: {H}.", - ) parser.add_argument( "--hv", type=int, @@ -220,10 +214,6 @@ def benchmark_chunk_intra_varlen(): DISABLE_RECOMPUTE = True print("[Disable recompute] pre-compute QG in forward") - if args.heads is not None: - H = args.heads - HV = H # reset HV to new H unless --hv is also provided - if args.hv is not None: if args.hv < H or args.hv % H != 0: raise ValueError(f"--hv must be a positive multiple of H ({H}), got {args.hv}") diff --git a/benchmarks/bench_recompute_wu.py b/benchmarks/bench_recompute_wu.py index 0b5a3e1..dbde241 100644 --- a/benchmarks/bench_recompute_wu.py +++ b/benchmarks/bench_recompute_wu.py @@ -12,6 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +bench_recompute_wu.py — Benchmark: cuLA vs FLA Triton for recompute_w_u + +Supports both standard (HV=H) and GVA (HV > H) modes via --hv / --heads flags. +In GVA mode the FLA reference replicates q/k to HV heads; cuLA operates natively +with compact q/k in HQK space. + +Usage: + python bench_recompute_wu.py [--hv HV] [--disable_recompute] +""" + import argparse import os import pathlib @@ -32,6 +43,7 @@ # Constant params B, H, D = 2, 64, 128 +HV = H # overridable via --hv; HV > H enables GVA mode BT = 64 # chunk size # Varlen benchmark params @@ -41,7 +53,6 @@ VARIANCE = 1.0 DISABLE_RECOMPUTE = False # Whether to disable recompute (compute QG in forward) -GROUP_SIZE = 1 # GVA group size: HV = GROUP_SIZE * H. 1 means no GVA. def accuracy_stats(a, b): @@ -56,42 +67,36 @@ def accuracy_stats(a, b): return rmse, rel_max, mean_diff -def prepare_recompute_wu_inputs(B, T, H, D, device, cu_seqlens=None, chunk_size=BT): - """Prepare inputs for recompute_w_u benchmarking. +def prepare_recompute_wu_inputs(B, T, device, cu_seqlens=None, chunk_size=BT): + """Prepare inputs for recompute_w_u benchmarking (handles both MHA and GVA). - Runs chunk_kda_fwd_intra (FLA) to produce Akk, then returns - all tensors needed for recompute_w_u_fwd / recompute_w_u_cuda. + Uses cuLA's GVA-aware chunk_kda_fwd_intra to produce Akk in HV head space, + which is valid for both MHA (HV=H) and GVA (HV>H) layouts. """ q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( - B, T, H, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size + B, T, H, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size, num_v_heads=HV ) - # Run FLA chunk_kda_fwd_intra to get Akk (shared input for both impls) - _, _, _, _, Aqk, Akk = fla_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=False, + _, _, _, _, _, Akk = cula_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=False, ) return q, k, v, g, beta, Akk, cu_seqlens, chunk_indices def run_fla_recompute_wu(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, disable_recompute): - """Run FLA recompute_w_u_fwd.""" + """FLA recompute_w_u reference (handles both MHA and GVA via q/k replication).""" + group_size = HV // H + k_ref = k.repeat_interleave(group_size, dim=2).contiguous() if group_size > 1 else k + q_ref = q.repeat_interleave(group_size, dim=2).contiguous() if group_size > 1 else q return fla_recompute_w_u_fwd( - k=k, + k=k_ref, v=v, beta=beta, A=Akk, - q=q if disable_recompute else None, + q=q_ref if disable_recompute else None, gk=gk, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, @@ -99,11 +104,12 @@ def run_fla_recompute_wu(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, disa def run_cula_recompute_wu(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, chunk_size, disable_recompute): - """Run cuLA recompute_w_u_cuda (MHA: all tensors share the same head dim).""" - w = torch.empty_like(k) + """cuLA recompute_w_u (handles both MHA and GVA; w/u/qg/kg allocated in HV head space).""" + B_flat, T, HV_out, Dv = v.shape + w = torch.empty_like(v) u = torch.empty_like(v) - qg = torch.empty_like(q) if disable_recompute else None - kg = torch.empty_like(k) if gk is not None else None + qg = torch.empty(B_flat, T, HV_out, Dv, device=q.device, dtype=q.dtype) if disable_recompute else None + kg = torch.empty_like(v) if gk is not None else None cula_cuda.recompute_w_u_cuda( k, v, beta, Akk, gk, cu_seqlens, chunk_indices, w, u, kg, chunk_size, q if disable_recompute else None, qg @@ -112,99 +118,32 @@ def run_cula_recompute_wu(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, chu # ============================================================================== -# GVA helpers -# ============================================================================== - -def prepare_recompute_wu_inputs_gva(B, T, HQK, HV, D, device, cu_seqlens=None, chunk_size=BT): - """Prepare GVA inputs for recompute_w_u benchmarking. - - Produces Akk via cuLA's GVA-aware chunk_kda_fwd_intra so the tensor lives in - HV-head space (shape [B, T, HV, BT]). Both FLA (k replicated to HV) and cuLA - (k compact in HQK) receive the same Akk. - """ - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( - B, T, HQK, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size, num_v_heads=HV - ) - - # Use cuLA GVA intra to produce Akk in HV space. - _, _, _, _, _, Akk = cula_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=False, - ) - - return q, k, v, g, beta, Akk, cu_seqlens, chunk_indices - - -def run_fla_recompute_wu_gva(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, disable_recompute, group_size): - """FLA reference for GVA recompute_w_u. - - FLA does not natively support GVA, so k and q are replicated to HV heads via - repeat_interleave before the call — mirroring the strategy in bench_kda_chunk_intra.py. - """ - k_hv = k.repeat_interleave(group_size, dim=2).contiguous() - q_hv = q.repeat_interleave(group_size, dim=2).contiguous() - return fla_recompute_w_u_fwd( - k=k_hv, - v=v, - beta=beta, - A=Akk, - q=q_hv if disable_recompute else None, - gk=gk, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - ) - - -def run_cula_recompute_wu_gva(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, chunk_size, disable_recompute): - """Run cuLA recompute_w_u_cuda with GVA layout. - - k/q live in HQK head space; v/gk/beta/Akk/w/u/kg/qg all live in HV head space. - """ - B_flat, T, HV, Dv = v.shape - w = torch.empty(B_flat, T, HV, Dv, device=k.device, dtype=k.dtype) - u = torch.empty_like(v) - qg = torch.empty(B_flat, T, HV, Dv, device=q.device, dtype=q.dtype) if disable_recompute else None - kg = torch.empty(B_flat, T, HV, Dv, device=k.device, dtype=k.dtype) if gk is not None else None - - cula_cuda.recompute_w_u_cuda( - k, v, beta, Akk, gk, cu_seqlens, chunk_indices, w, u, kg, chunk_size, q if disable_recompute else None, qg - ) - return w, u, qg, kg - - -# ============================================================================== -# Uniform seqlen benchmark +# Unified uniform seqlen benchmark (handles both standard and GVA) # ============================================================================== def benchmark_recompute_wu_uniform(): device = torch.device("cuda") chunk_size = BT + gva_mode = HV > H + gva_note = f"HQK={H} HV={HV} (group_size={HV // H})" if gva_mode else f"H={H}" T_vals = [512, 1024, 4096, 8192, 16384, 32768] - print("=" * 90) + print("=" * 100) print( - f" Uniform-Length RecomputeWU Benchmark: cuLA vs FLA Triton B={B} H={H} D={D} disable_recompute={DISABLE_RECOMPUTE}" + f" Uniform-Length RecomputeWU Benchmark: cuLA vs FLA Triton " + f"B={B} {gva_note} D={D} disable_recompute={DISABLE_RECOMPUTE}" ) - print("=" * 90) + print("=" * 100) print( f"{'B':>4} {'T':>7} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" ) - print("─" * 90) + print("─" * 100) for T in T_vals: seq_lens = [T] * B cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) q, k, v, g, beta, Akk, cu_seqlens, chunk_indices = prepare_recompute_wu_inputs( - B, T, H, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size + B, T, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size ) # Accuracy: run once and compare @@ -215,17 +154,12 @@ def benchmark_recompute_wu_uniform(): k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE ) - # Compare w, u, qg, kg stats = {} for name, t_fla, t_cula in [ - ("w", w_fla, w_cula), - ("u", u_fla, u_cula), - ("qg", qg_fla, qg_cula), - ("kg", kg_fla, kg_cula), + ("w", w_fla, w_cula), ("u", u_fla, u_cula), ("qg", qg_fla, qg_cula), ("kg", kg_fla, kg_cula), ]: if t_fla is not None and t_cula is not None: stats[name] = accuracy_stats(t_fla, t_cula) - # Use max across all outputs for display rmse = max(s[0] for s in stats.values()) rel_max = max(s[1] for s in stats.values()) mean_diff = max(s[2] for s in stats.values()) @@ -243,27 +177,30 @@ def benchmark_recompute_wu_uniform(): f"{B:>4} {T:>7} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" ) - print("─" * 90) + print("─" * 100) # ============================================================================== -# Varlen benchmark +# Unified varlen benchmark (handles both standard and GVA) # ============================================================================== def benchmark_recompute_wu_varlen(): device = torch.device("cuda") chunk_size = BT + gva_mode = HV > H + gva_note = f"HQK={H} HV={HV} (group_size={HV // H})" if gva_mode else f"H={H}" total_len_vals = [8192, 16384, 32768, 65536] print() - print("=" * 100) + print("=" * 110) print( - f" Varlen RecomputeWU Benchmark: cuLA vs FLA Triton NUM_SEQS={NUM_SEQS} H={H} D={D} disable_recompute={DISABLE_RECOMPUTE}" + f" Varlen RecomputeWU Benchmark: cuLA vs FLA Triton " + f"NUM_SEQS={NUM_SEQS} {gva_note} D={D} disable_recompute={DISABLE_RECOMPUTE}" ) - print("=" * 100) + print("=" * 110) print( f"{'total_len':>10} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" ) - print("─" * 100) + print("─" * 110) for total_len in total_len_vals: seq_lens = generate_random_seq_lens(NUM_SEQS, total_len, MIN_SEQ_LEN, VARIANCE, SEED) @@ -271,7 +208,7 @@ def benchmark_recompute_wu_varlen(): cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) q, k, v, g, beta, Akk, cu_seqlens, chunk_indices = prepare_recompute_wu_inputs( - 1, T, H, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size + 1, T, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size ) # Accuracy @@ -282,17 +219,12 @@ def benchmark_recompute_wu_varlen(): k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE ) - # Compare w, u, qg, kg stats = {} for name, t_fla, t_cula in [ - ("w", w_fla, w_cula), - ("u", u_fla, u_cula), - ("qg", qg_fla, qg_cula), - ("kg", kg_fla, kg_cula), + ("w", w_fla, w_cula), ("u", u_fla, u_cula), ("qg", qg_fla, qg_cula), ("kg", kg_fla, kg_cula), ]: if t_fla is not None and t_cula is not None: stats[name] = accuracy_stats(t_fla, t_cula) - # Use max across all outputs for display rmse = max(s[0] for s in stats.values()) rel_max = max(s[1] for s in stats.values()) mean_diff = max(s[2] for s in stats.values()) @@ -310,154 +242,6 @@ def benchmark_recompute_wu_varlen(): f"{total_len:>10} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" ) - print("─" * 100) - - -# ============================================================================== -# GVA uniform seqlen benchmark -# ============================================================================== -def benchmark_recompute_wu_gva_uniform(group_size: int): - """Benchmark GVA (HV > HQK) recompute_w_u: cuLA vs FLA Triton (k replicated to HV). - - FLA does not natively support GVA, so the reference replicates k/q along the - head axis to HV before calling recompute_w_u_fwd. - """ - device = torch.device("cuda") - chunk_size = BT - HQK = H - HV = HQK * group_size - T_vals = [512, 1024, 4096, 8192, 16384, 32768] - - print("=" * 100) - print( - f" GVA Uniform RecomputeWU Benchmark: cuLA vs FLA Triton " - f"B={B} HQK={HQK} HV={HV} (group_size={group_size}) D={D} disable_recompute={DISABLE_RECOMPUTE}" - ) - print("=" * 100) - print( - f"{'B':>4} {'T':>7} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" - ) - print("─" * 100) - - for T in T_vals: - seq_lens = [T] * B - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - - q, k, v, g, beta, Akk, cu_seqlens, chunk_indices = prepare_recompute_wu_inputs_gva( - B, T, HQK, HV, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size - ) - - # Accuracy: run once and compare - w_fla, u_fla, qg_fla, kg_fla = run_fla_recompute_wu_gva( - k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, DISABLE_RECOMPUTE, group_size - ) - w_cula, u_cula, qg_cula, kg_cula = run_cula_recompute_wu_gva( - k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE - ) - - stats = {} - for name, t_fla, t_cula in [ - ("w", w_fla, w_cula), - ("u", u_fla, u_cula), - ("qg", qg_fla, qg_cula), - ("kg", kg_fla, kg_cula), - ]: - if t_fla is not None and t_cula is not None: - stats[name] = accuracy_stats(t_fla, t_cula) - rmse = max(s[0] for s in stats.values()) - rel_max = max(s[1] for s in stats.values()) - mean_diff = max(s[2] for s in stats.values()) - - # Performance - ms_fla = triton.testing.do_bench( - lambda: run_fla_recompute_wu_gva( - k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, DISABLE_RECOMPUTE, group_size - ), - ) - ms_cula = triton.testing.do_bench( - lambda: run_cula_recompute_wu_gva( - k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE - ), - ) - speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - - print( - f"{B:>4} {T:>7} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" - ) - - print("─" * 100) - - -# ============================================================================== -# GVA varlen benchmark -# ============================================================================== -def benchmark_recompute_wu_gva_varlen(group_size: int): - """Varlen GVA benchmark for recompute_w_u: cuLA vs FLA Triton (k replicated to HV).""" - device = torch.device("cuda") - chunk_size = BT - HQK = H - HV = HQK * group_size - total_len_vals = [8192, 16384, 32768, 65536] - - print() - print("=" * 110) - print( - f" GVA Varlen RecomputeWU Benchmark: cuLA vs FLA Triton " - f"NUM_SEQS={NUM_SEQS} HQK={HQK} HV={HV} (group_size={group_size}) D={D} disable_recompute={DISABLE_RECOMPUTE}" - ) - print("=" * 110) - print( - f"{'total_len':>10} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" - ) - print("─" * 110) - - for total_len in total_len_vals: - seq_lens = generate_random_seq_lens(NUM_SEQS, total_len, MIN_SEQ_LEN, VARIANCE, SEED) - T = total_len - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - - q, k, v, g, beta, Akk, cu_seqlens, chunk_indices = prepare_recompute_wu_inputs_gva( - 1, T, HQK, HV, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size - ) - - # Accuracy - w_fla, u_fla, qg_fla, kg_fla = run_fla_recompute_wu_gva( - k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, DISABLE_RECOMPUTE, group_size - ) - w_cula, u_cula, qg_cula, kg_cula = run_cula_recompute_wu_gva( - k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE - ) - - stats = {} - for name, t_fla, t_cula in [ - ("w", w_fla, w_cula), - ("u", u_fla, u_cula), - ("qg", qg_fla, qg_cula), - ("kg", kg_fla, kg_cula), - ]: - if t_fla is not None and t_cula is not None: - stats[name] = accuracy_stats(t_fla, t_cula) - rmse = max(s[0] for s in stats.values()) - rel_max = max(s[1] for s in stats.values()) - mean_diff = max(s[2] for s in stats.values()) - - # Performance - ms_fla = triton.testing.do_bench( - lambda: run_fla_recompute_wu_gva( - k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, DISABLE_RECOMPUTE, group_size - ), - ) - ms_cula = triton.testing.do_bench( - lambda: run_cula_recompute_wu_gva( - k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE - ), - ) - speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - - print( - f"{total_len:>10} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" - ) - print("─" * 110) @@ -469,11 +253,10 @@ def benchmark_recompute_wu_gva_varlen(group_size: int): help="Disable recompute in both FLA and cuLA (pre-compute QG)", ) parser.add_argument( - "--group_size", + "--hv", type=int, - default=1, - help="GVA group size: HV = group_size * H. 1 (default) runs the non-GVA benchmark. " - "Values > 1 run GVA benchmarks comparing cuLA (k in HQK space) vs FLA (k replicated to HV).", + default=None, + help=f"Override number of V heads (HV). Default: H ({H}, no GVA). Set HV > H to run in GVA mode.", ) args = parser.parse_args() @@ -481,11 +264,13 @@ def benchmark_recompute_wu_gva_varlen(group_size: int): DISABLE_RECOMPUTE = True print("[Disable recompute] pre-compute QG in forward") - GROUP_SIZE = args.group_size + if args.hv is not None: + if args.hv < H or args.hv % H != 0: + raise ValueError(f"--hv must be a positive multiple of H ({H}), got {args.hv}") + HV = args.hv + + if HV > H: + print(f"[GVA] HV={HV} (H={H}, group_size={HV // H}x)") - if GROUP_SIZE == 1: - benchmark_recompute_wu_uniform() - benchmark_recompute_wu_varlen() - else: - benchmark_recompute_wu_gva_uniform(GROUP_SIZE) - benchmark_recompute_wu_gva_varlen(GROUP_SIZE) + benchmark_recompute_wu_uniform() + benchmark_recompute_wu_varlen()