diff --git a/benchmarks/bench_kda_chunk_intra.py b/benchmarks/bench_kda_chunk_intra.py index 29aa5ab..c65c5f3 100644 --- a/benchmarks/bench_kda_chunk_intra.py +++ b/benchmarks/bench_kda_chunk_intra.py @@ -12,6 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +bench_kda_chunk_intra.py — Benchmark: cuLA vs FLA Triton for chunk_kda_fwd_intra + +Supports both standard (HV=H) and GVA (HV > H) modes via --hv / --heads flags. +In GVA mode the FLA reference replicates q/k to HV heads; cuLA operates natively +with compact q/k in HQK space. + +Usage: + python bench_kda_chunk_intra.py [--hv HV] [--disable_recompute] +""" + import argparse import os import pathlib @@ -30,6 +41,7 @@ # Constant params B, H, D = 2, 64, 128 +HV = H # overridable via --hv; HV > H enables GVA mode BT = 64 # chunk size # Varlen benchmark params @@ -54,196 +66,133 @@ def accuracy_stats(a, b): # ============================================================================== -# Uniform seqlen benchmark +# Unified uniform seqlen benchmark (handles both standard and GVA) # ============================================================================== def benchmark_chunk_intra_uniform(): device = torch.device("cuda") chunk_size = BT + HQK = H + gva_mode = HV > HQK + group_size = HV // HQK T_vals = [512, 1024, 4096, 8192, 16384, 32768] - print("=" * 90) + gva_note = f"HQK={HQK} HV={HV} (group_size={group_size})" if gva_mode else f"H={HQK}" + print("=" * 100) print( - f" Uniform-Length ChunkIntra Benchmark: cuLA vs FLA Triton B={B} H={H} D={D} disable_recompute={DISABLE_RECOMPUTE}" + f" Uniform-Length ChunkIntra Benchmark: cuLA vs FLA Triton " + f"B={B} {gva_note} D={D} disable_recompute={DISABLE_RECOMPUTE}" ) - print("=" * 90) + print("=" * 100) print( f"{'B':>4} {'T':>7} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" ) - print("─" * 90) + print("─" * 100) for T in T_vals: seq_lens = [T] * B cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens) - - # Accuracy: run once and compare - out_fla = fla_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( + B, T, HQK, D, device, cu_seqlens=cu_seqlens, num_v_heads=HV ) - out_cula = cula_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, + # FLA reference: replicate q/k to HV heads when in GVA mode + q_ref = q.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else q + k_ref = k.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else k + + common_fla = dict( + q=q_ref, k=k_ref, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, ) - # Compare the first output tensor (o) + common_cula = dict( + q=q, k=k, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, + ) + + # Accuracy: run once and compare + out_fla = fla_chunk_kda_fwd_intra(**common_fla) + out_cula = cula_chunk_kda_fwd_intra(**common_cula) o_fla = out_fla[0] if isinstance(out_fla, (tuple, list)) else out_fla o_cula = out_cula[0] if isinstance(out_cula, (tuple, list)) else out_cula rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) # Performance - ms_fla = triton.testing.do_bench( - lambda: fla_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ), - ) - ms_cula = triton.testing.do_bench( - lambda: cula_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ), - ) + ms_fla = triton.testing.do_bench(lambda: fla_chunk_kda_fwd_intra(**common_fla)) + ms_cula = triton.testing.do_bench(lambda: cula_chunk_kda_fwd_intra(**common_cula)) speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") print( f"{B:>4} {T:>7} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" ) - print("─" * 90) + print("─" * 100) # ============================================================================== -# Varlen benchmark +# Unified varlen benchmark (handles both standard and GVA) # ============================================================================== def benchmark_chunk_intra_varlen(): device = torch.device("cuda") chunk_size = BT + HQK = H + gva_mode = HV > HQK + group_size = HV // HQK total_len_vals = [8192, 16384, 32768, 65536] + gva_note = f"HQK={HQK} HV={HV} (group_size={group_size})" if gva_mode else f"H={HQK}" print() - print("=" * 100) + print("=" * 110) print( - f" Varlen ChunkIntra Benchmark: cuLA vs FLA Triton NUM_SEQS={NUM_SEQS} H={H} D={D} disable_recompute={DISABLE_RECOMPUTE}" + f" Varlen ChunkIntra Benchmark: cuLA vs FLA Triton " + f"NUM_SEQS={NUM_SEQS} {gva_note} D={D} disable_recompute={DISABLE_RECOMPUTE}" ) - print("=" * 100) + print("=" * 110) print( f"{'total_len':>10} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" ) - print("─" * 100) + print("─" * 110) for total_len in total_len_vals: seq_lens = generate_random_seq_lens(NUM_SEQS, total_len, MIN_SEQ_LEN, VARIANCE, SEED) T = total_len cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens) - - # Accuracy - out_fla = fla_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( + 1, T, HQK, D, device, cu_seqlens=cu_seqlens, num_v_heads=HV ) - out_cula = cula_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, + # FLA reference: replicate q/k to HV heads when in GVA mode + q_ref = q.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else q + k_ref = k.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else k + + common_fla = dict( + q=q_ref, k=k_ref, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, ) + common_cula = dict( + q=q, k=k, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=DISABLE_RECOMPUTE, + ) + + # Accuracy + out_fla = fla_chunk_kda_fwd_intra(**common_fla) + out_cula = cula_chunk_kda_fwd_intra(**common_cula) o_fla = out_fla[0] if isinstance(out_fla, (tuple, list)) else out_fla o_cula = out_cula[0] if isinstance(out_cula, (tuple, list)) else out_cula rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) # Performance - ms_fla = triton.testing.do_bench( - lambda: fla_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ), - ) - ms_cula = triton.testing.do_bench( - lambda: cula_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=DISABLE_RECOMPUTE, - ), - ) + ms_fla = triton.testing.do_bench(lambda: fla_chunk_kda_fwd_intra(**common_fla)) + ms_cula = triton.testing.do_bench(lambda: cula_chunk_kda_fwd_intra(**common_cula)) speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") print( f"{total_len:>10} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" ) - print("─" * 100) + print("─" * 110) if __name__ == "__main__": @@ -253,11 +202,25 @@ def benchmark_chunk_intra_varlen(): action="store_true", help="Disable recompute in both FLA and cuLA (pre-compute QG)", ) + parser.add_argument( + "--hv", + type=int, + default=None, + help=f"Override number of V heads (HV). Default: H ({H}, no GVA). Set HV > H to run in GVA mode.", + ) args = parser.parse_args() if args.disable_recompute: DISABLE_RECOMPUTE = True print("[Disable recompute] pre-compute QG in forward") + if args.hv is not None: + if args.hv < H or args.hv % H != 0: + raise ValueError(f"--hv must be a positive multiple of H ({H}), got {args.hv}") + HV = args.hv + + if HV > H: + print(f"[GVA] HV={HV} (H={H}, group_size={HV // H}x)") + benchmark_chunk_intra_uniform() benchmark_chunk_intra_varlen() diff --git a/benchmarks/bench_recompute_wu.py b/benchmarks/bench_recompute_wu.py index c7d0873..dbde241 100644 --- a/benchmarks/bench_recompute_wu.py +++ b/benchmarks/bench_recompute_wu.py @@ -12,6 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +bench_recompute_wu.py — Benchmark: cuLA vs FLA Triton for recompute_w_u + +Supports both standard (HV=H) and GVA (HV > H) modes via --hv / --heads flags. +In GVA mode the FLA reference replicates q/k to HV heads; cuLA operates natively +with compact q/k in HQK space. + +Usage: + python bench_recompute_wu.py [--hv HV] [--disable_recompute] +""" + import argparse import os import pathlib @@ -28,9 +39,11 @@ import cula.cudac as cula_cuda from benchmarks.utils import SEED, exclusive_cumsum, generate_random_seq_lens, prepare_intra_inputs +from cula.kda.chunk_intra import chunk_kda_fwd_intra as cula_chunk_kda_fwd_intra # Constant params B, H, D = 2, 64, 128 +HV = H # overridable via --hv; HV > H enables GVA mode BT = 64 # chunk size # Varlen benchmark params @@ -54,42 +67,36 @@ def accuracy_stats(a, b): return rmse, rel_max, mean_diff -def prepare_recompute_wu_inputs(B, T, H, D, device, cu_seqlens=None, chunk_size=BT): - """Prepare inputs for recompute_w_u benchmarking. +def prepare_recompute_wu_inputs(B, T, device, cu_seqlens=None, chunk_size=BT): + """Prepare inputs for recompute_w_u benchmarking (handles both MHA and GVA). - Runs chunk_kda_fwd_intra (FLA) to produce Akk, then returns - all tensors needed for recompute_w_u_fwd / recompute_w_u_cuda. + Uses cuLA's GVA-aware chunk_kda_fwd_intra to produce Akk in HV head space, + which is valid for both MHA (HV=H) and GVA (HV>H) layouts. """ q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs( - B, T, H, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size + B, T, H, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size, num_v_heads=HV ) - # Run FLA chunk_kda_fwd_intra to get Akk (shared input for both impls) - _, _, _, _, Aqk, Akk = fla_chunk_kda_fwd_intra( - q=q, - k=k, - v=v, - gk=g, - beta=beta, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - chunk_indices=chunk_indices, - safe_gate=True, - disable_recompute=False, + _, _, _, _, _, Akk = cula_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=False, ) return q, k, v, g, beta, Akk, cu_seqlens, chunk_indices def run_fla_recompute_wu(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, disable_recompute): - """Run FLA recompute_w_u_fwd.""" + """FLA recompute_w_u reference (handles both MHA and GVA via q/k replication).""" + group_size = HV // H + k_ref = k.repeat_interleave(group_size, dim=2).contiguous() if group_size > 1 else k + q_ref = q.repeat_interleave(group_size, dim=2).contiguous() if group_size > 1 else q return fla_recompute_w_u_fwd( - k=k, + k=k_ref, v=v, beta=beta, A=Akk, - q=q if disable_recompute else None, + q=q_ref if disable_recompute else None, gk=gk, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, @@ -97,11 +104,12 @@ def run_fla_recompute_wu(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, disa def run_cula_recompute_wu(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, chunk_size, disable_recompute): - """Run cuLA recompute_w_u_cuda.""" - w = torch.empty_like(k) + """cuLA recompute_w_u (handles both MHA and GVA; w/u/qg/kg allocated in HV head space).""" + B_flat, T, HV_out, Dv = v.shape + w = torch.empty_like(v) u = torch.empty_like(v) - qg = torch.empty_like(q) if disable_recompute else None - kg = torch.empty_like(k) if gk is not None else None + qg = torch.empty(B_flat, T, HV_out, Dv, device=q.device, dtype=q.dtype) if disable_recompute else None + kg = torch.empty_like(v) if gk is not None else None cula_cuda.recompute_w_u_cuda( k, v, beta, Akk, gk, cu_seqlens, chunk_indices, w, u, kg, chunk_size, q if disable_recompute else None, qg @@ -110,29 +118,32 @@ def run_cula_recompute_wu(k, v, beta, Akk, q, gk, cu_seqlens, chunk_indices, chu # ============================================================================== -# Uniform seqlen benchmark +# Unified uniform seqlen benchmark (handles both standard and GVA) # ============================================================================== def benchmark_recompute_wu_uniform(): device = torch.device("cuda") chunk_size = BT + gva_mode = HV > H + gva_note = f"HQK={H} HV={HV} (group_size={HV // H})" if gva_mode else f"H={H}" T_vals = [512, 1024, 4096, 8192, 16384, 32768] - print("=" * 90) + print("=" * 100) print( - f" Uniform-Length RecomputeWU Benchmark: cuLA vs FLA Triton B={B} H={H} D={D} disable_recompute={DISABLE_RECOMPUTE}" + f" Uniform-Length RecomputeWU Benchmark: cuLA vs FLA Triton " + f"B={B} {gva_note} D={D} disable_recompute={DISABLE_RECOMPUTE}" ) - print("=" * 90) + print("=" * 100) print( f"{'B':>4} {'T':>7} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" ) - print("─" * 90) + print("─" * 100) for T in T_vals: seq_lens = [T] * B cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) q, k, v, g, beta, Akk, cu_seqlens, chunk_indices = prepare_recompute_wu_inputs( - B, T, H, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size + B, T, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size ) # Accuracy: run once and compare @@ -143,17 +154,12 @@ def benchmark_recompute_wu_uniform(): k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE ) - # Compare w, u, qg, kg stats = {} for name, t_fla, t_cula in [ - ("w", w_fla, w_cula), - ("u", u_fla, u_cula), - ("qg", qg_fla, qg_cula), - ("kg", kg_fla, kg_cula), + ("w", w_fla, w_cula), ("u", u_fla, u_cula), ("qg", qg_fla, qg_cula), ("kg", kg_fla, kg_cula), ]: if t_fla is not None and t_cula is not None: stats[name] = accuracy_stats(t_fla, t_cula) - # Use max across all outputs for display rmse = max(s[0] for s in stats.values()) rel_max = max(s[1] for s in stats.values()) mean_diff = max(s[2] for s in stats.values()) @@ -171,27 +177,30 @@ def benchmark_recompute_wu_uniform(): f"{B:>4} {T:>7} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" ) - print("─" * 90) + print("─" * 100) # ============================================================================== -# Varlen benchmark +# Unified varlen benchmark (handles both standard and GVA) # ============================================================================== def benchmark_recompute_wu_varlen(): device = torch.device("cuda") chunk_size = BT + gva_mode = HV > H + gva_note = f"HQK={H} HV={HV} (group_size={HV // H})" if gva_mode else f"H={H}" total_len_vals = [8192, 16384, 32768, 65536] print() - print("=" * 100) + print("=" * 110) print( - f" Varlen RecomputeWU Benchmark: cuLA vs FLA Triton NUM_SEQS={NUM_SEQS} H={H} D={D} disable_recompute={DISABLE_RECOMPUTE}" + f" Varlen RecomputeWU Benchmark: cuLA vs FLA Triton " + f"NUM_SEQS={NUM_SEQS} {gva_note} D={D} disable_recompute={DISABLE_RECOMPUTE}" ) - print("=" * 100) + print("=" * 110) print( f"{'total_len':>10} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}" ) - print("─" * 100) + print("─" * 110) for total_len in total_len_vals: seq_lens = generate_random_seq_lens(NUM_SEQS, total_len, MIN_SEQ_LEN, VARIANCE, SEED) @@ -199,7 +208,7 @@ def benchmark_recompute_wu_varlen(): cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) q, k, v, g, beta, Akk, cu_seqlens, chunk_indices = prepare_recompute_wu_inputs( - 1, T, H, D, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size + 1, T, device, cu_seqlens=cu_seqlens, chunk_size=chunk_size ) # Accuracy @@ -210,17 +219,12 @@ def benchmark_recompute_wu_varlen(): k, v, beta, Akk, q, g, cu_seqlens, chunk_indices, chunk_size, DISABLE_RECOMPUTE ) - # Compare w, u, qg, kg stats = {} for name, t_fla, t_cula in [ - ("w", w_fla, w_cula), - ("u", u_fla, u_cula), - ("qg", qg_fla, qg_cula), - ("kg", kg_fla, kg_cula), + ("w", w_fla, w_cula), ("u", u_fla, u_cula), ("qg", qg_fla, qg_cula), ("kg", kg_fla, kg_cula), ]: if t_fla is not None and t_cula is not None: stats[name] = accuracy_stats(t_fla, t_cula) - # Use max across all outputs for display rmse = max(s[0] for s in stats.values()) rel_max = max(s[1] for s in stats.values()) mean_diff = max(s[2] for s in stats.values()) @@ -238,7 +242,7 @@ def benchmark_recompute_wu_varlen(): f"{total_len:>10} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x" ) - print("─" * 100) + print("─" * 110) if __name__ == "__main__": @@ -248,11 +252,25 @@ def benchmark_recompute_wu_varlen(): action="store_true", help="Disable recompute in both FLA and cuLA (pre-compute QG)", ) + parser.add_argument( + "--hv", + type=int, + default=None, + help=f"Override number of V heads (HV). Default: H ({H}, no GVA). Set HV > H to run in GVA mode.", + ) args = parser.parse_args() if args.disable_recompute: DISABLE_RECOMPUTE = True print("[Disable recompute] pre-compute QG in forward") + if args.hv is not None: + if args.hv < H or args.hv % H != 0: + raise ValueError(f"--hv must be a positive multiple of H ({H}), got {args.hv}") + HV = args.hv + + if HV > H: + print(f"[GVA] HV={HV} (H={H}, group_size={HV // H}x)") + benchmark_recompute_wu_uniform() benchmark_recompute_wu_varlen() diff --git a/benchmarks/utils.py b/benchmarks/utils.py index bfd0761..12d4a34 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -324,11 +324,25 @@ def prepare_safe_gate_inputs( ) -def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED): +def prepare_intra_inputs( + batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED, num_v_heads=None +): """Prepare preprocessed inputs ready for chunk_kda_fwd_intra. - All tensors are flattened to (1, B*T, ...) for cu_seqlens compatibility. + Supports both standard (HV=H) and GVA (HV > H) layouts via ``num_v_heads``: + + q, k : (batch_size_flat, T, H, D) — Q/K head space (always compact) + v : (batch_size_flat, T, HV, D) — V head space + g : (batch_size_flat, T, HV, D) — gate in V head space (after cumsum) + beta : (batch_size_flat, T, HV) — beta in V head space + + When ``num_v_heads`` is None or equal to H this matches the original non-GVA + behaviour exactly. All tensors are flattened to batch_size=1 for cu_seqlens + compatibility. """ + HV = H if num_v_heads is None else num_v_heads + assert HV >= H and HV % H == 0, f"num_v_heads ({HV}) must be a positive multiple of H ({H})" + dtype = torch.bfloat16 scale = D ** (-0.5) @@ -336,9 +350,9 @@ def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_siz q = torch.randn(batch_size, T, H, D, dtype=dtype, device=device) k = torch.randn(batch_size, T, H, D, dtype=dtype, device=device) - v = torch.randn(batch_size, T, H, D, dtype=dtype, device=device) - g_raw = torch.randn(batch_size, T, H, D, dtype=dtype, device=device) - beta = torch.randn(batch_size, T, H, dtype=torch.float, device=device).sigmoid() + v = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device) + g_raw = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device) + beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid() # l2norm q, k q, _ = l2norm_fwd(q) @@ -348,9 +362,9 @@ def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_siz if batch_size != 1: q, k, v, g_raw, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g_raw, beta)) - # gate preprocessing - A_log = torch.randn(H, dtype=torch.float, device=device) - dt_bias = torch.randn(H * D, dtype=torch.float, device=device) + # gate preprocessing — A_log / dt_bias live in HV head space + A_log = torch.randn(HV, dtype=torch.float, device=device) + dt_bias = torch.randn(HV * D, dtype=torch.float, device=device) chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None diff --git a/csrc/api/kda_sm100.cu b/csrc/api/kda_sm100.cu index ac32411..ff89887 100644 --- a/csrc/api/kda_sm100.cu +++ b/csrc/api/kda_sm100.cu @@ -37,7 +37,33 @@ ChunkKDAFwdIntra( KDA_fwd_intra_params params; params.total_q_len = q.size(0) * q.size(1); params.b = cu_seqlens.size(0) - 1; - params.h = q.size(2); + // GVA: Q/K are in h_qk head space (from q.size(2)); g/beta/Aqk/Akk are in h_v head + // space (from g.size(2)). When HV == HQK, heads_per_group == 1 and behaviour matches + // the pre-GVA path. + params.h_qk = q.size(2); + params.h_v = g.size(2); + TORCH_CHECK( + k.size(2) == params.h_qk, + "ChunkKDAFwdIntra: k.size(2) (", + k.size(2), + ") must match q.size(2) (", + params.h_qk, + ") under GVA (Q/K share h_qk)."); + TORCH_CHECK( + beta.size(-1) == params.h_v, + "ChunkKDAFwdIntra: beta.size(-1) (", + beta.size(-1), + ") must equal h_v (", + params.h_v, + ")."); + TORCH_CHECK( + params.h_qk > 0 && params.h_v > 0 && params.h_v % params.h_qk == 0, + "ChunkKDAFwdIntra: h_v (", + params.h_v, + ") must be a positive multiple of h_qk (", + params.h_qk, + ")."); + params.heads_per_group = params.h_v / params.h_qk; params.d = q.size(3); params.chunk_size = chunk_size; params.scale = scale; @@ -56,13 +82,15 @@ ChunkKDAFwdIntra( params.chunk_indices_ptr = chunk_indices.data_ptr(); params.Aqk_out_ptr = Aqk_out.data_ptr(); params.Akk_out_ptr = Akk_out.data_ptr(); - params.shape_Akk = cute::make_shape(params.total_q_len, params.chunk_size, params.h); - params.stride_Akk = cute::make_stride(params.chunk_size * params.h, cute::_1{}, params.chunk_size); + // Akk is laid out per v-head: (total_len, chunk_size, h_v). + params.shape_Akk = cute::make_shape(params.total_q_len, params.chunk_size, params.h_v); + params.stride_Akk = cute::make_stride(params.chunk_size * params.h_v, cute::_1{}, params.chunk_size); int tile_num = chunk_indices.size(0); auto device_prop = at::cuda::getCurrentDeviceProperties(); params.num_sm = device_prop->multiProcessorCount; - params.tile_scheduler_params = - StaticPersistentTileScheduler::Params{tile_num, params.h, params.num_sm, (int*)tile_counter.data_ptr()}; + // Tiles are enumerated in v-head space. + params.tile_scheduler_params = StaticPersistentTileScheduler::Params{ + tile_num, params.h_v, params.heads_per_group, params.num_sm, (int*)tile_counter.data_ptr()}; kda::sm100::run_kda_fwd_intra_sm100(params, at::cuda::getCurrentCUDAStream()); } @@ -85,7 +113,31 @@ ChunkKDAFwdRecompWU( KDA_fwd_recomp_w_u_params params; params.total_len = k.size(0) * k.size(1); params.b = cu_seqlens.size(0) - 1; - params.h = k.size(2); + // GVA: K (and optional Q) live in h_qk space; V/G/beta/A/w/u/kg/qg live in h_v space. + params.h_qk = k.size(2); + params.h_v = v.size(2); + TORCH_CHECK( + g.size(2) == params.h_v, + "ChunkKDAFwdRecompWU: g.size(2) (", + g.size(2), + ") must equal v.size(2) (", + params.h_v, + ")."); + TORCH_CHECK( + beta.size(-1) == params.h_v, + "ChunkKDAFwdRecompWU: beta.size(-1) (", + beta.size(-1), + ") must equal h_v (", + params.h_v, + ")."); + TORCH_CHECK( + params.h_qk > 0 && params.h_v > 0 && params.h_v % params.h_qk == 0, + "ChunkKDAFwdRecompWU: h_v (", + params.h_v, + ") must be a positive multiple of h_qk (", + params.h_qk, + ")."); + params.heads_per_group = params.h_v / params.h_qk; params.d = k.size(3); params.chunk_size = chunk_size; TORCH_CHECK( @@ -108,14 +160,32 @@ ChunkKDAFwdRecompWU( TORCH_CHECK( has_q == has_qg_out, "ChunkKDAFwdRecompWU: q and qg_out must either both be provided or both be omitted."); params.store_qg = has_q && has_qg_out; + if (params.store_qg) { + TORCH_CHECK( + q->size(2) == params.h_qk, + "ChunkKDAFwdRecompWU: q.size(2) (", + q->size(2), + ") must equal h_qk (", + params.h_qk, + ")."); + TORCH_CHECK( + qg_out->size(2) == params.h_v, + "ChunkKDAFwdRecompWU: qg_out.size(2) (", + qg_out->size(2), + ") must equal h_v (", + params.h_v, + ")."); + } params.q_ptr = params.store_qg ? q->data_ptr() : nullptr; params.qg_out_ptr = params.store_qg ? qg_out->data_ptr() : nullptr; - params.shape_wukg = cute::make_shape(params.total_len, params.d, params.h); - params.stride_wukg = cute::make_stride(params.d * params.h, cute::_1{}, params.d); + // w/u/kg/qg are per v-head: (total_len, d, h_v). + params.shape_wukg = cute::make_shape(params.total_len, params.d, params.h_v); + params.stride_wukg = cute::make_stride(params.d * params.h_v, cute::_1{}, params.d); int tile_num = chunk_indices.size(0); auto device_prop = at::cuda::getCurrentDeviceProperties(); params.num_sm = device_prop->multiProcessorCount; - params.tile_scheduler_params = StaticPersistentTileScheduler::Params{tile_num, params.h, params.num_sm, nullptr}; + params.tile_scheduler_params = StaticPersistentTileScheduler::Params{ + tile_num, params.h_v, params.heads_per_group, params.num_sm, nullptr}; kda::sm100::run_kda_fwd_recomp_w_u_sm100(params, at::cuda::getCurrentCUDAStream()); } \ No newline at end of file diff --git a/csrc/kda/sm100/kda_config.hpp b/csrc/kda/sm100/kda_config.hpp index 6f96529..67b496a 100644 --- a/csrc/kda/sm100/kda_config.hpp +++ b/csrc/kda/sm100/kda_config.hpp @@ -17,12 +17,18 @@ #include "kda/sm100/tile_scheduler.hpp" struct KDA_fwd_intra_params { - using GmemShapeAkk = cute::Shape; // (seqlen_kv, seqlen_kv, h) + // Akk shape is (total_seqlen, chunk_size, num_v_heads). Under GVA (num_v_heads > num_qk_heads), + // Aqk and Akk are produced per v-head because g/beta/Akk scaling all live in v-head space. + using GmemShapeAkk = cute::Shape; // (seqlen_kv, chunk_size, h_v) using GmemStrideAkk = cute::Stride; int total_q_len; int b; - int h; + // GVA: Q/K are sized by num_qk_heads; V, g, beta are sized by num_v_heads; Aqk/Akk are per v-head. + // When num_v_heads == num_qk_heads, heads_per_group == 1 and behaviour matches the pre-GVA path. + int h_qk; + int h_v; + int heads_per_group; // = h_v / h_qk, precomputed on host int d; int chunk_size; float scale; @@ -30,12 +36,12 @@ struct KDA_fwd_intra_params { bool unified_gref; bool is_beta_bf16; - void* __restrict__ q_ptr; //[b, t, h, d] - void* __restrict__ k_ptr; //[b, t, h, d] - void* __restrict__ g_ptr; //[b, t, h, d] - void* __restrict__ beta_ptr; //[b, t, h] - void* __restrict__ Aqk_out_ptr; //[b, t, h, BT] - void* __restrict__ Akk_out_ptr; //[b, t, h, BT] + void* __restrict__ q_ptr; //[b, t, h_qk, d] + void* __restrict__ k_ptr; //[b, t, h_qk, d] + void* __restrict__ g_ptr; //[b, t, h_v, d] + void* __restrict__ beta_ptr; //[b, t, h_v] + void* __restrict__ Aqk_out_ptr; //[b, t, h_v, BT] + void* __restrict__ Akk_out_ptr; //[b, t, h_v, BT] void* __restrict__ cu_seqlens_ptr; //[b + 1] void* __restrict__ chunk_indices_ptr; //[(b * t) / chunk_size, 2] @@ -48,28 +54,32 @@ struct KDA_fwd_intra_params { }; struct KDA_fwd_recomp_w_u_params { - using GmemShapeWUKg = cute::Shape; // (seqlen_kv, seqlen_kv, h) + // w/u/kg/qg all have shape (total_seqlen, d, num_v_heads) under GVA. + using GmemShapeWUKg = cute::Shape; // (seqlen_kv, d, h_v) using GmemStrideWUKg = cute::Stride; int total_len; int b; - int h; + // GVA: K and (optional) Q are sized by num_qk_heads; V/G/beta/Akk/w/u/kg/qg are per v-head. + int h_qk; + int h_v; + int heads_per_group; // = h_v / h_qk, precomputed on host int d; int chunk_size; bool is_beta_bf16; - void* __restrict__ k_ptr; //[b, t, h, d] - void* __restrict__ v_ptr; //[b, t, h, d] - void* __restrict__ q_ptr; //[b, t, h, d] (optional, for StoreQG) - void* __restrict__ beta_ptr; //[b, t, h] - void* __restrict__ A_ptr; //[b. t, h, BT] - void* __restrict__ g_ptr; //[b, t, h, d] + void* __restrict__ k_ptr; //[b, t, h_qk, d] + void* __restrict__ v_ptr; //[b, t, h_v, d] + void* __restrict__ q_ptr; //[b, t, h_qk, d] (optional, for StoreQG) + void* __restrict__ beta_ptr; //[b, t, h_v] + void* __restrict__ A_ptr; //[b, t, h_v, BT] + void* __restrict__ g_ptr; //[b, t, h_v, d] void* __restrict__ cu_seqlens_ptr; //[b + 1] void* __restrict__ chunk_indices_ptr; //[(b * t) / chunk_size, 2] - void* __restrict__ w_out_ptr; //[b, t, h, d] - void* __restrict__ u_out_ptr; //[b, t, h, d] - void* __restrict__ kg_out_ptr; //[b, t, h, d] - void* __restrict__ qg_out_ptr; //[b, t, h, d] (optional, for StoreQG) + void* __restrict__ w_out_ptr; //[b, t, h_v, d] + void* __restrict__ u_out_ptr; //[b, t, h_v, d] + void* __restrict__ kg_out_ptr; //[b, t, h_v, d] + void* __restrict__ qg_out_ptr; //[b, t, h_v, d] (optional, for StoreQG) bool store_qg; diff --git a/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp b/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp index f314723..f928616 100644 --- a/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp @@ -53,8 +53,8 @@ struct KdaChunkFwdIntraKernelSm100 { using SmemLayoutInputFP32 = typename Mainloop::SmemLayoutInputFP32; // TMA params (for host launcher) - template - using TmaParams = typename Mainloop::template TmaParams; + template + using TmaParams = typename Mainloop::template TmaParams; // Pipeline types (for construction in operator()) using PipelineQKG = typename Mainloop::PipelineQKG; @@ -318,29 +318,40 @@ __launch_bounds__(512, 1, 1) kda_fwd_intra_sm100_kernel_entry( template inline void run_kda_fwd_intra_sm100_impl_dispatch(KDA_fwd_intra_params& params, cudaStream_t stream) { - auto shape_QKG = make_shape(params.total_q_len, params.d, params.h); - auto stride_QKG = make_stride(params.h * params.d, _1{}, params.d); + // GVA: Q/K are sized by `h_qk`; G is sized by `h_v`. When HV == HQK + // (heads_per_group == 1), shape_QK and shape_VG coincide with the + // pre-GVA shape_QKG and behaviour is unchanged. + auto shape_QK = make_shape(params.total_q_len, params.d, params.h_qk); + auto stride_QK = make_stride(params.h_qk * params.d, _1{}, params.d); + auto shape_VG = make_shape(params.total_q_len, params.d, params.h_v); + auto stride_VG = make_stride(params.h_v * params.d, _1{}, params.d); // --- Build TMA descriptors --- auto tma_Q = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((ku::bf16*)params.q_ptr), make_layout(shape_QKG, stride_QKG)), + make_tensor(make_gmem_ptr((ku::bf16*)params.q_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); auto tma_K = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((ku::bf16*)params.k_ptr), make_layout(shape_QKG, stride_QKG)), + make_tensor(make_gmem_ptr((ku::bf16*)params.k_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); auto tma_G = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_QKG, stride_QKG)), + make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_VG, stride_VG)), typename Kernel::SmemLayoutInputFP32{}); // --- Pack TMA params --- - typename Kernel::template TmaParams + typename Kernel::template TmaParams< + decltype(shape_QK), + decltype(shape_VG), + decltype(tma_Q), + decltype(tma_K), + decltype(tma_G)> tma_params = { - shape_QKG, + shape_QK, + shape_VG, tma_Q, tma_K, tma_G, diff --git a/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp b/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp index 849e910..7b624eb 100644 --- a/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp @@ -226,9 +226,13 @@ struct KdaChunkFwdIntraMainloopSm100 { }; // ===================== TMA Params ===================== - template + // GVA: Q/K live in h_qk head space (shape_qk), while G lives in h_v + // head space (shape_vg). When h_v == h_qk both shapes coincide and the + // TMA descriptors degrade to the pre-GVA behaviour. + template struct TmaParams { - ShapeQKG shape_qkg; + ShapeQK shape_qk; + ShapeVG shape_vg; TMA_Q tma_q; TMA_K tma_k; TMA_G tma_g; @@ -317,7 +321,10 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // head_idx here is the v-head index (Aqk/Akk/beta/g live in v-head space). + // qk_head_idx is only consumed by the TMA load warp for Q/K slicing. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -501,7 +508,8 @@ struct KdaChunkFwdIntraMainloopSm100 { int token_offset = cu_seqlens_ptr[batch_idx]; int row = idx_in_warpgroup % 64; int BT = TileT; - int H = params.h; + // Aqk is laid out per v-head: row-stride is h_v * BT, head slot offset is head_idx * BT. + int H = params.h_v; __nv_bfloat16* Aqk_base = reinterpret_cast<__nv_bfloat16*>(params.Aqk_out_ptr); __nv_bfloat16* qk_out_row = Aqk_base + static_cast(token_offset + tile_idx * TileT + row) * H * BT + head_idx * BT; @@ -567,7 +575,10 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // MMA loop does not actually consume head_idx, but we decode to advance the + // same tile space as the other warps (num_blocks * num_v_heads). + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -702,21 +713,24 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - // Decode tile coordinates - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Decode tile coordinates. head_idx is the v-head index (used for G), + // and qk_head_idx is the companion Q/K head (computed from heads_per_group). + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); - int head_idx = get<1>(blk_coord); + int head_idx = get<1>(blk_coord); // v-head index int tile_idx = get<2>(blk_coord); + int qk_head_idx = get<3>(blk_coord); // == head_idx / heads_per_group int token_offset = cu_seqlens_ptr[batch_idx]; int seq_len = cu_seqlens_ptr[batch_idx + 1] - cu_seqlens_ptr[batch_idx]; int sub_seq_len = min(TileT, seq_len - tile_idx * TileT); Tensor mQ = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_q.get_tma_tensor(tma_params.shape_qkg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_q.get_tma_tensor(tma_params.shape_qk)); Tensor mK = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_qkg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_qk)); Tensor mG = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_qkg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_vg)); // TMA load body (Q, K, G — unified pipeline, single barrier per stage) CUTE_NO_UNROLL @@ -726,12 +740,13 @@ struct KdaChunkFwdIntraMainloopSm100 { Tensor sK = make_tensor(make_smem_ptr(shared_plan->k[buf_idx].data()), SmemLayoutInputBF16{}); Tensor sG = make_tensor(make_smem_ptr(shared_plan->g[buf_idx].data()), SmemLayoutInputFP32{}); + // GVA: K and Q are sliced by qk_head_idx; G is sliced by head_idx (v-head). Tensor gK = local_tile( - mK(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); + mK(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); Tensor gG = local_tile( mG(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); Tensor gQ = local_tile( - mQ(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); + mQ(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); // Single acquire for all three TMA copies qkg_load_pipeline.producer_acquire(qkg_load_pipe_state_write); @@ -768,7 +783,9 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Akk is laid out per v-head (params.shape_Akk uses h_v), so we index by head_idx. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -881,7 +898,9 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // beta is per v-head: layout (total_seqlen, h_v), row stride = h_v. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -895,7 +914,7 @@ struct KdaChunkFwdIntraMainloopSm100 { shared_plan->beta_smem[beta_pipe_state_write.index()][thread_idx] = (thread_idx < sub_seq_len) ? float(reinterpret_cast( - params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]) + params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h_v + head_idx]) : float(0); } fence_view_async_shared(); diff --git a/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp b/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp index 73cb408..2cae6c0 100644 --- a/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp @@ -41,14 +41,16 @@ struct KdaChunkFwdRecompWUKernelSm100 { // TMA params (for host launcher) template < - typename ShapeKVG, + typename ShapeQK, + typename ShapeVG, typename ShapeAkk, typename TMA_V, typename TMA_K, typename TMA_G, typename TMA_Akk, typename TMA_Q = int> - using TmaParams = typename Mainloop::template TmaParams; + using TmaParams = + typename Mainloop::template TmaParams; // Pipeline types (for construction in operator()) using PipelineA = typename Mainloop::PipelineA; @@ -431,25 +433,29 @@ __launch_bounds__(384, 1, 1) kda_fwd_recomp_w_u_sm100_kernel_entry( template inline void run_kda_fwd_recomp_w_u_sm100_impl_dispatch(KDA_fwd_recomp_w_u_params& params, cudaStream_t stream) { - auto shape_KVG = make_shape(params.total_len, params.d, params.h); - auto stride_KVG = make_stride(params.h * params.d, _1{}, params.d); - auto shape_Akk = make_shape(params.total_len, params.chunk_size, params.h); - auto stride_Akk = make_stride(params.h * params.chunk_size, _1{}, params.chunk_size); + // GVA: K and (optional) Q are sized by h_qk; V and G are sized by h_v. + // Akk lives in v-head space (BT x BT per v-head). + auto shape_QK = make_shape(params.total_len, params.d, params.h_qk); + auto stride_QK = make_stride(params.h_qk * params.d, _1{}, params.d); + auto shape_VG = make_shape(params.total_len, params.d, params.h_v); + auto stride_VG = make_stride(params.h_v * params.d, _1{}, params.d); + auto shape_Akk = make_shape(params.total_len, params.chunk_size, params.h_v); + auto stride_Akk = make_stride(params.h_v * params.chunk_size, _1{}, params.chunk_size); // --- Build TMA descriptors --- auto tma_V = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((bf16*)params.v_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((bf16*)params.v_ptr), make_layout(shape_VG, stride_VG)), typename Kernel::SmemLayoutInputBF16{}); auto tma_K = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((bf16*)params.k_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((bf16*)params.k_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); auto tma_G = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_VG, stride_VG)), typename Kernel::SmemLayoutInputFP32{}); auto tma_Akk = cute::make_tma_copy( @@ -457,12 +463,12 @@ run_kda_fwd_recomp_w_u_sm100_impl_dispatch(KDA_fwd_recomp_w_u_params& params, cu make_tensor(make_gmem_ptr((bf16*)params.A_ptr), make_layout(shape_Akk, stride_Akk)), typename Kernel::SmemLayoutInputAkkBF16{}); - // Q TMA descriptor (only meaningful when StoreQG=true) + // Q TMA descriptor (only meaningful when StoreQG=true). Q lives in h_qk head space. auto tma_Q = [&]() { if constexpr (Kernel::StoreQG) { return cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((bf16*)params.q_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((bf16*)params.q_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); } else { return 0; // placeholder, not used @@ -471,14 +477,15 @@ run_kda_fwd_recomp_w_u_sm100_impl_dispatch(KDA_fwd_recomp_w_u_params& params, cu // --- Pack TMA params --- typename Kernel::template TmaParams< - decltype(shape_KVG), + decltype(shape_QK), + decltype(shape_VG), decltype(shape_Akk), decltype(tma_V), decltype(tma_K), decltype(tma_G), decltype(tma_Akk), decltype(tma_Q)> - tma_params = {shape_KVG, shape_Akk, tma_V, tma_K, tma_G, tma_Akk, tma_Q}; + tma_params = {shape_QK, shape_VG, shape_Akk, tma_V, tma_K, tma_G, tma_Akk, tma_Q}; // --- Launch config --- auto kernel_fn = &kda_fwd_recomp_w_u_sm100_kernel_entry; diff --git a/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp b/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp index 0f6a66f..53f482c 100644 --- a/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp @@ -190,8 +190,11 @@ struct KdaChunkFwdRecompWUMainloopSm100 { }; // ===================== TMA Params ===================== + // GVA: K and (optional) Q live in h_qk head space (shape_qk), while V + // and G live in h_v head space (shape_vg). Akk is per v-head. template < - typename ShapeKVG, + typename ShapeQK, + typename ShapeVG, typename ShapeAkk, typename TMA_V, typename TMA_K, @@ -199,7 +202,8 @@ struct KdaChunkFwdRecompWUMainloopSm100 { typename TMA_Akk, typename TMA_Q = int> struct TmaParams { - ShapeKVG shape_kvg; + ShapeQK shape_qk; + ShapeVG shape_vg; ShapeAkk shape_Akk; TMA_V tma_v; TMA_K tma_k; @@ -255,7 +259,10 @@ struct KdaChunkFwdRecompWUMainloopSm100 { CUTE_NO_UNROLL for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Prologue touches K (h_qk) and G (h_v) + beta (h_v) + optional Q (h_qk). + // head_idx is the v-head index; qk_head_idx is derived via heads_per_group. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -635,7 +642,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { CUTE_NO_UNROLL for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Epilogue consumes V/beta (both h_v) and writes w/u/kg/qg (all h_v). + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -735,9 +744,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { // each thread processes one row of W/U (TileK columns) int row = (idx_in_wg / 32) * 16 + (idx_in_wg % 16); - // GMEM output address: layout [total_len, d, h], stride [d*h, 1, d] + // GMEM output address: layout [total_len, d, h_v], stride [d*h_v, 1, d] __nv_bfloat16* out_row_base = - out_ptr_base + (token_offset_cur + row) * params.d * params.h + head_idx * params.d; + out_ptr_base + (token_offset_cur + row) * params.d * params.h_v + head_idx * params.d; constexpr int QuarK = TileK / 4; @@ -799,7 +808,8 @@ struct KdaChunkFwdRecompWUMainloopSm100 { CUTE_NO_UNROLL for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { // int tid = tile_scheduler.get_current_tile_id(); - // auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h_v, params.heads_per_group, + // chunk_indices_ptr, cu_seqlens_ptr); // ============================================================ // Once per WU: Wait for Akk in SMEM (from Load warp) @@ -879,31 +889,36 @@ struct KdaChunkFwdRecompWUMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - // Decode tile coordinates - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Decode tile coordinates. head_idx is the v-head (used for V/G/Akk + // TMA loads); qk_head_idx (= head_idx / heads_per_group) is used for + // K/Q TMA loads under GVA. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); - int head_idx = get<1>(blk_coord); + int head_idx = get<1>(blk_coord); // v-head int tile_idx = get<2>(blk_coord); + int qk_head_idx = get<3>(blk_coord); // qk-head int token_offset = cu_seqlens_ptr[batch_idx]; int seq_len = cu_seqlens_ptr[batch_idx + 1] - cu_seqlens_ptr[batch_idx]; int sub_seq_len = min(TileT, seq_len - tile_idx * TileT); // Build GMEM tensor views (with domain offset for batch) + // K and Q live in h_qk head space (shape_qk); V, G and Akk live in h_v space. Tensor mK = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_kvg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_qk)); Tensor mV = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_v.get_tma_tensor(tma_params.shape_kvg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_v.get_tma_tensor(tma_params.shape_vg)); Tensor mG = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_kvg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_vg)); Tensor mA = domain_offset( make_coord(token_offset, _0{}, _0{}), tma_params.tma_akk.get_tma_tensor(tma_params.shape_Akk)); - // Q GMEM tensor (only used when StoreQG=true) + // Q GMEM tensor (only used when StoreQG=true). Q lives in h_qk space. [[maybe_unused]] auto mQ = [&]() { if constexpr (StoreQG) { return domain_offset( make_coord(token_offset, _0{}, _0{}), - tma_params.tma_q.get_tma_tensor(tma_params.shape_kvg)); + tma_params.tma_q.get_tma_tensor(tma_params.shape_qk)); } else { return 0; // unused placeholder } @@ -936,8 +951,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { Tensor sG = make_tensor( make_smem_ptr(shared_plan->g[g_pipe_state_write.index()].data()), SmemLayoutInputFP32{}); + // GVA slicing: K uses qk_head_idx; V and G use the v-head index. Tensor gK = local_tile( - mK(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); + mK(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); Tensor gV = local_tile( mV(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); Tensor gG = local_tile( @@ -963,8 +979,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { Tensor sQ = make_tensor( make_smem_ptr(shared_plan->q_buf.q[q_pipe_state_write.index()].data()), SmemLayoutInputBF16{}); + // Q (StoreQG) lives in h_qk space → slice by qk_head_idx. Tensor gQ = local_tile( - mQ(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); + mQ(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); q_pipeline.producer_acquire(q_pipe_state_write); ku::launch_tma_copy( tma_params.tma_q, gQ, sQ, *q_pipeline.producer_get_barrier(q_pipe_state_write)); @@ -997,7 +1014,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // LoadAux: beta is per v-head (row stride = h_v). + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -1013,7 +1032,7 @@ struct KdaChunkFwdRecompWUMainloopSm100 { float beta_val = (thread_idx < sub_seq_len) ? float(reinterpret_cast( - params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]) + params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h_v + head_idx]) : float(0); shared_plan->beta_smem[beta_pipe_state_write.index()][thread_idx] = beta_val; } diff --git a/csrc/kda/sm100/tile_scheduler.hpp b/csrc/kda/sm100/tile_scheduler.hpp index 47044aa..695bb26 100644 --- a/csrc/kda/sm100/tile_scheduler.hpp +++ b/csrc/kda/sm100/tile_scheduler.hpp @@ -26,11 +26,20 @@ // No smem synchronization needed — every CTA processes tiles starting // at blockIdx.x and striding by gridDim.x. All warps within a CTA // independently maintain the same tile_id, so no tile pipeline is needed. +// +// GVA (Grouped V-head Attention) support: +// Q/K are sized by `num_qk_heads`; V, g, beta, O and state tensors are +// sized by `num_v_heads`. We enumerate tiles by `num_v_heads` so that +// each v-head is scheduled independently, and derive the companion +// `qk_head_idx = v_head_idx / heads_per_group` on the device side. +// `heads_per_group = num_v_heads / num_qk_heads` is precomputed on the +// host to avoid a per-tile integer division. // =================================================================== struct StaticPersistentTileScheduler { struct Params { - int num_blocks; // number of sequence chunks (from chunk_indices) - int num_heads; + int num_blocks; // number of sequence chunks (from chunk_indices) + int num_heads; // == num_v_heads; tiles are enumerated by v-head + int heads_per_group; // == num_v_heads / num_qk_heads, precomputed on host int num_sm; int* tile_counter; // unused @@ -77,14 +86,22 @@ struct StaticPersistentTileScheduler { return current_tile_id < total_tiles(); } + // Decode tile_id -> (batch_idx, v_head_idx, seq_idx, qk_head_idx). + // `num_v_heads` is the number of V/O/g/beta heads; tile enumeration is + // done in v-head space. `heads_per_group` (= num_v_heads/num_qk_heads) + // is used to derive the companion Q/K head index for GVA. + // For backward compatibility, when HV == HQK, `heads_per_group == 1` + // and `qk_head_idx == v_head_idx`. CUTLASS_DEVICE static auto - decode_tile_coord(int tile_id, int num_heads, int* chunk_indices_ptr, int* cu_seqlens_ptr) { + decode_tile_coord( + int tile_id, int num_v_heads, int heads_per_group, int* chunk_indices_ptr, int* /*cu_seqlens_ptr*/) { using namespace cute; - int tile_idx_raw = tile_id / num_heads; - int head_idx = tile_id % num_heads; + int tile_idx_raw = tile_id / num_v_heads; + int v_head_idx = tile_id % num_v_heads; + int qk_head_idx = v_head_idx / heads_per_group; int batch_idx = chunk_indices_ptr[tile_idx_raw * 2]; int seq_idx = chunk_indices_ptr[tile_idx_raw * 2 + 1]; - return make_coord(batch_idx, head_idx, seq_idx, 0); + return make_coord(batch_idx, v_head_idx, seq_idx, qk_head_idx); } }; \ No newline at end of file diff --git a/cula/kda/chunk_intra.py b/cula/kda/chunk_intra.py index 0703638..aeb063f 100644 --- a/cula/kda/chunk_intra.py +++ b/cula/kda/chunk_intra.py @@ -759,7 +759,12 @@ def chunk_kda_fwd_intra( unified_gref: bool = False, # Set True for ~5% extra perf (slightly lower precision) ): assert safe_gate, "Only safe_gate=True is supported in chunk_kda_fwd_intra for now" - B, T, H, K = k.shape + B, T, H_QK, K = k.shape + # GVA: g/beta/v live in h_v head space; q/k live in h_qk head space. + H_V = v.size(2) + assert H_QK > 0 and H_V > 0 and H_V % H_QK == 0, ( + f"HV ({H_V}) must be a positive multiple of HQK ({H_QK})" + ) BT = chunk_size if cu_seqlens is None: @@ -773,18 +778,20 @@ def chunk_kda_fwd_intra( "cu_seqlens and chunk_indices must be int32 for cuda impl" ) - Aqk = torch.empty(B, T, H, BT, device=k.device, dtype=k.dtype) - Akk = torch.empty(B, T, H, BT, device=k.device, dtype=k.dtype) + # Aqk and Akk are produced per v-head by the intra kernel. + Aqk = torch.empty(B, T, H_V, BT, device=k.device, dtype=k.dtype) + Akk = torch.empty(B, T, H_V, BT, device=k.device, dtype=k.dtype) tile_counter = torch.zeros(1, dtype=torch.int32, device=q.device) cula_cuda.chunk_kda_fwd_intra_cuda( q, k, gk, beta, cu_seqlens, chunk_indices, Aqk, Akk, tile_counter, scale, chunk_size, use_tf32_inverse, unified_gref ) - w = torch.empty_like(k) + # w, u, kg, qg all live in h_v head space. + w = torch.empty_like(v) u = torch.empty_like(v) - qg = torch.empty_like(q) if disable_recompute else None - kg = torch.empty_like(k) if gk is not None else None + qg = torch.empty(B, T, H_V, K, device=q.device, dtype=q.dtype) if disable_recompute else None + kg = torch.empty(B, T, H_V, K, device=k.device, dtype=k.dtype) if gk is not None else None cula_cuda.recompute_w_u_cuda( k, v, beta, Akk, gk, cu_seqlens, chunk_indices, w, u, kg, chunk_size, q if disable_recompute else None, qg @@ -857,4 +864,4 @@ def chunk_kda_bwd_intra( db = db2.sum(0).add_(db) dg = dg2 - return dq, dk, db, dg + return dq, dk, db, dg \ No newline at end of file diff --git a/tests/test_kda_gva_intra_sm100.py b/tests/test_kda_gva_intra_sm100.py new file mode 100644 index 0000000..8082946 --- /dev/null +++ b/tests/test_kda_gva_intra_sm100.py @@ -0,0 +1,386 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for SM100 KDA GVA (HV > HQK) support in chunk_kda_fwd_intra. + +The SM100 kernels (kda_fwd_intra / kda_fwd_recomp_w_u) now accept: + * q, k with head-dim ``HQK`` + * v, g, beta with head-dim ``HV`` where ``HV = group_size * HQK`` (group_size >= 1) + +This file verifies that the cuLA GVA path produces numerically matching results +compared to the FLA Triton reference, where the FLA reference does not natively +support GVA and therefore receives ``k`` replicated along the head axis to +``HV`` heads. Both uniform-length and varlen layouts are covered, and an +additional degeneracy test asserts that ``HV == HQK`` (group_size == 1) keeps +the non-GVA behaviour untouched. +""" + +from __future__ import annotations + +import pytest +import torch +from einops import rearrange +from fla.modules.l2norm import l2norm_fwd +from fla.ops.kda.chunk_intra import chunk_kda_fwd_intra as fla_chunk_kda_fwd_intra +from fla.ops.kda.gate import kda_gate_chunk_cumsum +from fla.ops.utils.constant import RCP_LN2 +from fla.ops.utils.index import prepare_chunk_indices +from fla.utils import assert_close, device + +from cula.kda.chunk_intra import chunk_kda_fwd_intra as cula_chunk_kda_fwd_intra +from cula.utils import prepare_uniform_cu_seqlens + +pytestmark = pytest.mark.sm100_only + + +# ========================================================================= +# Helpers +# ========================================================================= + +def _repeat_head(x: torch.Tensor, group_size: int, head_dim: int = 2) -> torch.Tensor: + """Replicate ``x`` along the head axis by ``group_size``. + + Mirrors GVA's broadcasting semantics: each QK head is paired with + ``group_size`` consecutive V heads, so ``k[..., h_qk, :]`` is used by + ``v[..., h_qk * group_size : (h_qk + 1) * group_size, :]``. + """ + return x.repeat_interleave(group_size, dim=head_dim).contiguous() + + +def _make_gva_inputs( + B: int, + T: int, + HQK: int, + HV: int, + D: int, + chunk_size: int, + cu_seqlens: torch.Tensor | None = None, + dtype: torch.dtype = torch.bfloat16, + seed: int = 42, +): + """Construct inputs for chunk_kda_fwd_intra in GVA layout. + + Returns: + q, k : (B, T, HQK, D) dtype + v : (B, T, HV, D) dtype + g : (B, T, HV, D) float32, after kda_gate_chunk_cumsum + beta : (B, T, HV) float32 in (0, 1) + scale : float + cu_seqlens : (N+1,) int32 or None + chunk_indices: (NT, 2) int32 or None + """ + assert HV % HQK == 0 and HV >= HQK, f"invalid HV/HQK: {HV}/{HQK}" + + torch.manual_seed(seed) + scale = D ** (-0.5) + + # QK are in HQK head space; V / gates / beta live in HV space. + q = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + k = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + v = torch.randn(B, T, HV, D, dtype=dtype, device=device) + g_raw = torch.randn(B, T, HV, D, dtype=dtype, device=device) + beta = torch.randn(B, T, HV, dtype=torch.float, device=device).sigmoid() + + # l2-normalise q/k so that scale/gate ranges match production use. + q, _ = l2norm_fwd(q) + k, _ = l2norm_fwd(k) + + # FLA gate cumsum only supports packed batch (B=1) when cu_seqlens is set. + if B != 1: + q, k, v, g_raw, beta = map( + lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), + (q, k, v, g_raw, beta), + ) + + # Per-HV gate preprocessing (cumsum inside chunks). + A_log = torch.randn(HV, dtype=torch.float, device=device) + dt_bias = torch.randn(HV * D, dtype=torch.float, device=device) + + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + ) + g = kda_gate_chunk_cumsum( + g=g_raw, + A_log=A_log, + dt_bias=dt_bias, + scale=RCP_LN2, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=-5.0, + ) + return q, k, v, g, beta, scale, cu_seqlens, chunk_indices + + +def _run_fla_ref(q, k_hqk, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, group_size, disable_recompute): + """Reference: replicate k along head axis to HV, then call FLA intra. + + FLA's chunk_kda_fwd_intra assumes H == HQK == HV (no GVA), so we construct + the HV-head view of k and q before invoking it. + """ + k_hv = _repeat_head(k_hqk, group_size) + q_hv = _repeat_head(q, group_size) + return fla_chunk_kda_fwd_intra( + q=q_hv, + k=k_hv, + v=v, + gk=g, + beta=beta, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + chunk_indices=chunk_indices, + safe_gate=True, + disable_recompute=disable_recompute, + ) + + +def _run_cula_gva(q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute): + return cula_chunk_kda_fwd_intra( + q=q, + k=k, + v=v, + gk=g, + beta=beta, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + chunk_indices=chunk_indices, + safe_gate=True, + disable_recompute=disable_recompute, + ) + + +def _assert_intra_outputs_match(ref, tri, disable_recompute: bool) -> None: + """Compare cuLA vs FLA on user-visible intra outputs. + + We intentionally skip ``Aqk``: the cuLA SM100 fused kernel does not + materialise every off-diagonal slot that FLA's multi-kernel path writes, + and the FLA reference can contain NaNs in unused ``Aqk`` entries. The + downstream tensors ``w`` / ``u`` / ``kg`` (and ``Akk``) are the meaningful + correctness signals and match the benchmark's comparison strategy. + """ + w_r, u_r, qg_r, kg_r, _Aqk_r, Akk_r = ref + w_c, u_c, qg_c, kg_c, _Aqk_c, Akk_c = tri + + assert Akk_c.shape == Akk_r.shape, (Akk_c.shape, Akk_r.shape) + assert w_c.shape == w_r.shape, (w_c.shape, w_r.shape) + assert u_c.shape == u_r.shape, (u_c.shape, u_r.shape) + assert kg_c.shape == kg_r.shape, (kg_c.shape, kg_r.shape) + + assert_close("Akk", Akk_r, Akk_c, 0.008) + assert_close("w", w_r, w_c, 0.008) + assert_close("u", u_r, u_c, 0.008) + assert_close("kg", kg_r, kg_c, 0.005) + + if disable_recompute: + assert qg_c is not None and qg_r is not None + assert qg_c.shape == qg_r.shape, (qg_c.shape, qg_r.shape) + assert_close("qg", qg_r, qg_c, 0.005) + else: + assert qg_c is None, "cuLA must not materialise qg when disable_recompute=False" + + +# ========================================================================= +# Uniform-length tests +# ========================================================================= + +@pytest.mark.parametrize("disable_recompute", [False, True], ids=["recomp", "no_recomp"]) +@pytest.mark.parametrize( + ("B", "T", "HQK", "group_size", "D"), + [ + pytest.param(*cfg, id="B{}-T{}-HQK{}-gs{}-D{}".format(*cfg)) + for cfg in [ + # group_size == 2: classic GVA 2:1 + (1, 256, 2, 2, 128), + (2, 512, 4, 2, 128), + # group_size == 4: wider grouping + (1, 1024, 2, 4, 128), + (2, 1024, 4, 4, 128), + # Non-multiple-of-BT sequence length to stress boundary handling. + (1, 500, 2, 2, 128), + (1, 1000, 4, 2, 128), + ] + ], +) +def test_gva_intra_uniform(B, T, HQK, group_size, D, disable_recompute): + """cuLA GVA path must match FLA(k-replicated-to-HV) for uniform seqlens.""" + HV = HQK * group_size + chunk_size = 64 + + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = _make_gva_inputs( + B=B, T=T, HQK=HQK, HV=HV, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens, + ) + + # cuLA GVA path (k in HQK head space). + w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute, + ) + + # FLA reference (k replicated to HV). + w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r = _run_fla_ref( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, group_size, disable_recompute, + ) + + _assert_intra_outputs_match( + (w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r), + (w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c), + disable_recompute, + ) + + +# ========================================================================= +# Varlen tests +# ========================================================================= + +@pytest.mark.parametrize("disable_recompute", [False, True], ids=["recomp", "no_recomp"]) +@pytest.mark.parametrize( + ("HQK", "group_size", "D", "cu_seqlens"), + [ + pytest.param(*cfg, id="HQK{}-gs{}-D{}-ns{}".format(cfg[0], cfg[1], cfg[2], len(cfg[3]) - 1)) + for cfg in [ + (2, 2, 128, [0, 256, 500, 1000]), + (4, 2, 128, [0, 100, 300, 1200, 2000]), + (2, 4, 128, [0, 15, 100, 300, 1200, 2048]), + # Simulated realistic trace. + ( + 4, 2, 128, + [0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096], + ), + ] + ], +) +def test_gva_intra_varlen(HQK, group_size, D, cu_seqlens, disable_recompute): + """GVA correctness under variable-length (packed) inputs.""" + HV = HQK * group_size + chunk_size = 64 + + cu_seqlens_t = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + T = int(cu_seqlens_t[-1].item()) + # Packed layout uses B=1 and a flat time axis. + q, k, v, g, beta, scale, cu_seqlens_t, chunk_indices = _make_gva_inputs( + B=1, T=T, HQK=HQK, HV=HV, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens_t, + ) + + w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens_t, chunk_indices, chunk_size, disable_recompute, + ) + w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r = _run_fla_ref( + q, k, v, g, beta, scale, cu_seqlens_t, chunk_indices, chunk_size, group_size, disable_recompute, + ) + + _assert_intra_outputs_match( + (w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r), + (w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c), + disable_recompute, + ) + + +# ========================================================================= +# Degeneracy: HV == HQK must match the non-GVA (same-shape) reference +# ========================================================================= + +@pytest.mark.parametrize("disable_recompute", [False, True], ids=["recomp", "no_recomp"]) +@pytest.mark.parametrize( + ("B", "T", "H", "D"), + [ + pytest.param(*cfg, id="B{}-T{}-H{}-D{}".format(*cfg)) + for cfg in [ + (1, 512, 4, 128), + (2, 1024, 4, 128), + ] + ], +) +def test_gva_intra_degenerate_equals_non_gva(B, T, H, D, disable_recompute): + """When HV == HQK, the GVA code path must be byte-for-byte equivalent + to the non-GVA path that existed before this change. + + We do not have a separate "non-GVA" entrypoint, but we can assert the + cuLA path matches FLA with *no* head replication (group_size=1), which + exercises the ``HV == HQK`` fast-path inside the new kernels. + """ + chunk_size = 64 + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = _make_gva_inputs( + B=B, T=T, HQK=H, HV=H, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens, + ) + + w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute, + ) + # group_size=1 → no replication; identical input shape to cuLA. + w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r = fla_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=disable_recompute, + ) + + _assert_intra_outputs_match( + (w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r), + (w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c), + disable_recompute, + ) + + +# ========================================================================= +# Shape / contract sanity checks (run even without a reference) +# ========================================================================= + +@pytest.mark.parametrize("group_size", [1, 2, 4]) +def test_gva_intra_output_shapes(group_size): + """All outputs of chunk_kda_fwd_intra must live in HV-head space.""" + B, T, HQK, D = 1, 256, 2, 128 + HV = HQK * group_size + chunk_size = 64 + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = _make_gva_inputs( + B=B, T=T, HQK=HQK, HV=HV, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens, + ) + w, u, qg, kg, Aqk, Akk = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute=True, + ) + + assert Aqk.shape == (B, T, HV, chunk_size), Aqk.shape + assert Akk.shape == (B, T, HV, chunk_size), Akk.shape + assert w.shape == (B, T, HV, D), w.shape + assert u.shape == (B, T, HV, D), u.shape + assert kg.shape == (B, T, HV, D), kg.shape + assert qg is not None and qg.shape == (B, T, HV, D), (None if qg is None else qg.shape) + + +# ========================================================================= +# Negative / assertion tests +# ========================================================================= + +def test_gva_intra_rejects_non_multiple_ratio(): + """HV must be a positive integer multiple of HQK.""" + B, T, HQK, HV, D = 1, 128, 3, 5, 128 # 5 % 3 != 0 + chunk_size = 64 + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + # We intentionally do not use _make_gva_inputs because the assert fires + # before kernel launch on the python side. + dtype = torch.bfloat16 + q = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + k = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + v = torch.randn(B, T, HV, D, dtype=dtype, device=device) + g = torch.randn(B, T, HV, D, dtype=torch.float, device=device) + beta = torch.randn(B, T, HV, dtype=torch.float, device=device).sigmoid() + + with pytest.raises((AssertionError, RuntimeError), match=r"multiple|h_v"): + cula_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=D ** -0.5, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, + safe_gate=True, disable_recompute=False, + )