diff --git a/benchmarks/bench_chunk_delta_h.py b/benchmarks/bench_chunk_delta_h.py index f14c866..a0cf054 100644 --- a/benchmarks/bench_chunk_delta_h.py +++ b/benchmarks/bench_chunk_delta_h.py @@ -104,20 +104,20 @@ def bench_non_varlen(configs): print("=" * 80) results = [] - for B, T, H, use_gk, use_h0, store_ht, save_vnew in configs: + for B, T, H, HV, use_gk, use_h0, store_ht, save_vnew in configs: torch.manual_seed(42) torch.cuda.empty_cache() k = torch.randn(B, T, H, K, device=device, dtype=dtype) * 0.1 - w = torch.randn(B, T, H, K, device=device, dtype=dtype) * 0.1 - u = torch.randn(B, T, H, V, device=device, dtype=dtype) * 0.1 + w = torch.randn(B, T, HV, K, device=device, dtype=dtype) * 0.1 + u = torch.randn(B, T, HV, V, device=device, dtype=dtype) * 0.1 gk = None h0 = None if use_gk: - gk = -torch.abs(torch.randn(B, T, H, K, device=device, dtype=torch.float32) * 0.1).cumsum(dim=1) + gk = -torch.abs(torch.randn(B, T, HV, K, device=device, dtype=torch.float32) * 0.1).cumsum(dim=1) if use_h0: - h0 = torch.randn(B, H, K, V, device=device, dtype=torch.float32) * 0.01 + h0 = torch.randn(B, HV, K, V, device=device, dtype=torch.float32) * 0.01 # ---- FLA baseline ---- fla_result = fla_fwd_h( @@ -192,10 +192,13 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0): flags.append("vn") flag_str = f" [{','.join(flags)}]" if flags else "" + hv_str = f"/{HV}" if HV != H else "" r = { "B": B, "T": T, "H": H, + "HV": HV, + "hv_str": hv_str, "flags": flag_str, "max_diff": max_diff, "mean_diff": mean_diff, @@ -205,7 +208,7 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0): } results.append(r) print( - f" B={B:2d} T={T:5d} H={H:3d}{flag_str:<16s} | " + f" B={B:2d} T={T:5d} H={H:3d}{hv_str:<4s}{flag_str:<16s} | " f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f} | " f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | " f"speedup={speedup:.2f}x" @@ -243,7 +246,7 @@ def bench_varlen(configs): print("=" * 80) results = [] - for num_seqs, total_T, H, ratio, use_gk, use_h0, store_ht, save_vnew in 
configs: + for num_seqs, total_T, H, HV, ratio, use_gk, use_h0, store_ht, save_vnew in configs: seq_lens = generate_seq_lens(num_seqs, total_T, ratio) cu_seqlens_list = [0] for sl in seq_lens: @@ -261,22 +264,22 @@ def bench_varlen(configs): torch.manual_seed(42) torch.cuda.empty_cache() - # Both FLA and CuTe DSL use [1, total_T, H, ...] (4D with B=1) + # Both FLA and CuTe DSL use [1, total_T, H/HV, ...] (4D with B=1) k = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1 - w = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1 - u = torch.randn(1, total_T, H, V, device=device, dtype=dtype) * 0.1 + w = torch.randn(1, total_T, HV, K, device=device, dtype=dtype) * 0.1 + u = torch.randn(1, total_T, HV, V, device=device, dtype=dtype) * 0.1 gk = None h0 = None if use_gk: - gk_raw = torch.randn(1, total_T, H, K, device=device, dtype=torch.float32) * 0.1 + gk_raw = torch.randn(1, total_T, HV, K, device=device, dtype=torch.float32) * 0.1 gk = torch.zeros_like(gk_raw) for i in range(num_seqs): bos = cu_seqlens[i].item() eos = cu_seqlens[i + 1].item() gk[:, bos:eos] = -torch.abs(gk_raw[:, bos:eos]).cumsum(dim=1) if use_h0: - h0 = torch.randn(num_seqs, H, K, V, device=device, dtype=torch.float32) * 0.01 + h0 = torch.randn(num_seqs, HV, K, V, device=device, dtype=torch.float32) * 0.01 # ---- FLA baseline ---- fla_result = fla_fwd_h( @@ -359,10 +362,13 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0, cu=cu_seqlens): flags.append("vn") flag_str = f" [{','.join(flags)}]" if flags else "" + hv_str = f"/{HV}" if HV != H else "" r = { "tag": tag, "T_total": total_T, "H": H, + "HV": HV, + "hv_str": hv_str, "n_seqs": num_seqs, "flags": flag_str, "max_diff": max_diff, @@ -373,7 +379,7 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0, cu=cu_seqlens): } results.append(r) print( - f" {tag:40s} H={H:3d}{flag_str:<16s} | " + f" {tag:40s} H={H:3d}{hv_str:<4s}{flag_str:<16s} | " f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f} | " f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms 
| " f"speedup={speedup:.2f}x" @@ -406,7 +412,7 @@ def print_report(nv_results, vl_results): ) print(f" {'─' * 100}") for r in nv_results: - label = f"B={r['B']:2d} T={r['T']:5d} H={r['H']:3d}{r['flags']}" + label = f"B={r['B']:2d} T={r['T']:5d} H={r['H']:3d}{r['hv_str']}{r['flags']}" print( f" {label:<35s} │ " f"{r['max_diff']:10.6f} {r['mean_diff']:12.8f} │ " @@ -426,7 +432,7 @@ def print_report(nv_results, vl_results): ) print(f" {'─' * 115}") for r in vl_results: - label = f"{r['tag']} H={r['H']:3d}{r['flags']}" + label = f"{r['tag']} H={r['H']:3d}{r['hv_str']}{r['flags']}" print( f" {label:>55s} │ " f"{r['max_diff']:10.6f} {r['mean_diff']:12.8f} │ " @@ -452,6 +458,18 @@ def main(): choices=["non-varlen", "varlen", "both"], help="Which benchmark mode to run (default: both)", ) + parser.add_argument( + "--heads", + type=int, + default=64, + help="Number of QK heads H (default: 64)", + ) + parser.add_argument( + "--hv", + type=int, + default=None, + help="Number of value heads HV (default: same as --heads, i.e. 
no GVA)", + ) parser.add_argument( "--ncu", action="store_true", @@ -464,22 +482,25 @@ def main(): NCU_MODE = True print("[NCU mode] warmup=1, iters=1") - # (B, T, H, use_gk, use_h0, store_ht, save_vnew) + H = args.heads + HV = args.hv if args.hv is not None else H + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" + + # (B, T, H, HV, use_gk, use_h0, store_ht, save_vnew) non_varlen_configs = [ - # Sweep B × H with all features (gk, h0, ht, vnew) - (1, 8192, 64, True, True, True, True), - (2, 8192, 64, True, True, True, True), - (4, 8192, 64, True, True, True, True), - (8, 8192, 64, True, True, True, True), + (1, 8192, H, HV, True, True, True, True), + (2, 8192, H, HV, True, True, True, True), + (4, 8192, H, HV, True, True, True, True), + (8, 8192, H, HV, True, True, True, True), ] - # (num_seqs, total_T, H, ratio, use_gk, use_h0, store_ht, save_vnew) + # (num_seqs, total_T, H, HV, ratio, use_gk, use_h0, store_ht, save_vnew) varlen_configs = [ - (20, 8192, 64, 2.0, True, True, True, True), - (25, 8192, 64, 3.0, True, True, True, True), - (20, 8192, 64, 4.0, True, True, True, True), - (20, 32768, 64, 2.0, True, True, True, True), - (25, 32768, 64, 3.0, True, True, True, True), + (20, 8192, H, HV, 2.0, True, True, True, True), + (25, 8192, H, HV, 3.0, True, True, True, True), + (20, 8192, H, HV, 4.0, True, True, True, True), + (20, 32768, H, HV, 2.0, True, True, True, True), + (25, 32768, H, HV, 3.0, True, True, True, True), ] nv_res, vl_res = [], [] diff --git a/benchmarks/bench_fwd_o.py b/benchmarks/bench_fwd_o.py index 7aa40b0..31f64da 100644 --- a/benchmarks/bench_fwd_o.py +++ b/benchmarks/bench_fwd_o.py @@ -108,17 +108,17 @@ def bench_non_varlen(configs): print("=" * 80) results = [] - for B, T, H in configs: + for B, T, H, HV in configs: scale = K**-0.5 NT = (T + BT - 1) // BT torch.manual_seed(42) torch.cuda.empty_cache() q = torch.randn(B, T, H, K, dtype=dtype, device=device) - v = torch.randn(B, T, H, V, dtype=dtype, 
device=device) - g = torch.randn(B, T, H, K, dtype=torch.float32, device=device) * 0.1 - h = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 - A = torch.randn(B, T, H, BT, dtype=dtype, device=device) * 0.1 + v = torch.randn(B, T, HV, V, dtype=dtype, device=device) + g = torch.randn(B, T, HV, K, dtype=torch.float32, device=device) * 0.1 + h = torch.randn(B, NT, HV, K, V, dtype=dtype, device=device) * 0.01 + A = torch.randn(B, T, HV, BT, dtype=dtype, device=device) * 0.1 # ---- FLA baseline (accuracy) ---- o_fla = chunk_gla_fwd_o_gk( @@ -133,7 +133,7 @@ def bench_non_varlen(configs): ) # ---- CuTe DSL (accuracy) ---- - o_cute_t = torch.zeros(B, T, H, V, dtype=dtype, device=device) + o_cute_t = torch.zeros(B, T, HV, V, dtype=dtype, device=device) # Warmup / first call triggers compilation via cache chunk_gla_fwd_o( @@ -183,10 +183,13 @@ def run_cute(q=q, v=v, g=g, h=h, o=o_cute_t, A=A, scale=scale): ms_cute = time_kernel(run_cute) speedup = ms_fla / ms_cute if ms_cute > 0 else float("inf") + hv_str = f"/{HV}" if HV != H else "" r = { "B": B, "T": T, "H": H, + "HV": HV, + "hv_str": hv_str, "max_diff": max_diff, "rel_max_diff": rel_max_diff, "mean_diff": mean_diff, @@ -196,7 +199,7 @@ def run_cute(q=q, v=v, g=g, h=h, o=o_cute_t, A=A, scale=scale): } results.append(r) print( - f" B={B:2d} T={T:5d} H={H:2d} | " + f" B={B:2d} T={T:5d} H={H:2d}{hv_str:<4s} | " f"max_diff={max_diff:.6f} rel_max={rel_max_diff:.6f} mean_diff={mean_diff:.8f} | " f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | " f"speedup={speedup:.2f}x" @@ -232,7 +235,7 @@ def bench_varlen(configs): print("=" * 80) results = [] - for seq_lens, H in configs: + for seq_lens, H, HV in configs: scale = K**-0.5 T_total = sum(seq_lens) cu_seqlens_list = [0] @@ -244,12 +247,12 @@ def bench_varlen(configs): torch.cuda.empty_cache() # Flat token-indexed tensors (shared data for both kernels) - # 4D with B=1: [1, T_total, H, *] + # q uses H (QK heads), g/v/h/A/o use HV (value heads) q_flat = torch.randn(1, 
T_total, H, K, dtype=dtype, device=device) - v_flat = torch.randn(1, T_total, H, V, dtype=dtype, device=device) - g_flat = torch.randn(1, T_total, H, K, dtype=torch.float32, device=device) * 0.1 - h_flat = torch.randn(1, total_nt_val, H, K, V, dtype=dtype, device=device) * 0.01 - A_flat = torch.randn(1, T_total, H, BT, dtype=dtype, device=device) * 0.1 + v_flat = torch.randn(1, T_total, HV, V, dtype=dtype, device=device) + g_flat = torch.randn(1, T_total, HV, K, dtype=torch.float32, device=device) * 0.1 + h_flat = torch.randn(1, total_nt_val, HV, K, V, dtype=dtype, device=device) * 0.01 + A_flat = torch.randn(1, T_total, HV, BT, dtype=dtype, device=device) * 0.1 # ---- FLA baseline (needs [1, T_total, H, *] + cu_seqlens int64) ---- cu_fla = torch.tensor(cu_seqlens_list, dtype=torch.long, device=device) @@ -267,7 +270,7 @@ def bench_varlen(configs): ) # ---- CuTe DSL varlen ---- - o_cute_flat = torch.zeros(1, T_total, H, V, dtype=dtype, device=device) + o_cute_flat = torch.zeros(1, T_total, HV, V, dtype=dtype, device=device) cu_cute = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device) ci_cute = build_chunk_indices(seq_lens, BT=BT, device=device) @@ -339,10 +342,13 @@ def run_cute( min_l, max_l = min(seq_lens), max(seq_lens) avg_l = T_total // n_seqs tag = f"{n_seqs}seqs T={T_total} [{min_l}..{max_l}] avg={avg_l}" + hv_str = f"/{HV}" if HV != H else "" r = { "tag": tag, "T_total": T_total, "H": H, + "HV": HV, + "hv_str": hv_str, "n_seqs": n_seqs, "max_diff": max_diff, "rel_max_diff": rel_max_diff, @@ -353,7 +359,7 @@ def run_cute( } results.append(r) print( - f" {tag:45s} H={H:2d} | " + f" {tag:45s} H={H:2d}{hv_str:<4s} | " f"max_diff={max_diff:.6f} rel_max={rel_max_diff:.6f} mean_diff={mean_diff:.8f} | " f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | " f"speedup={speedup:.2f}x" @@ -380,15 +386,16 @@ def print_report(nv_results, vl_results): if nv_results: print("\n [Non-Varlen]") hdr = ( - f" {'B':>3s} {'T':>5s} {'H':>3s} │ {'max_diff':>10s} 
{'rel_max':>10s} {'mean_diff':>12s}" + f" {'B':>3s} {'T':>5s} {'H':>7s} │ {'max_diff':>10s} {'rel_max':>10s} {'mean_diff':>12s}" f" │ {'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s}" ) print(f" {'─' * 90}") print(hdr) print(f" {'─' * 90}") for r in nv_results: + h_label = f"{r['H']}{r['hv_str']}" print( - f" {r['B']:3d} {r['T']:5d} {r['H']:3d} │ " + f" {r['B']:3d} {r['T']:5d} {h_label:>7s} │ " f"{r['max_diff']:10.6f} {r['rel_max_diff']:10.6f} {r['mean_diff']:12.8f} │ " f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x" ) @@ -397,15 +404,16 @@ def print_report(nv_results, vl_results): if vl_results: print("\n [Varlen]") hdr = ( - f" {'Config':>45s} {'H':>3s} │ {'max_diff':>10s} {'rel_max':>10s} {'mean_diff':>12s}" + f" {'Config':>45s} {'H':>7s} │ {'max_diff':>10s} {'rel_max':>10s} {'mean_diff':>12s}" f" │ {'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s}" ) print(f" {'─' * 117}") print(hdr) print(f" {'─' * 117}") for r in vl_results: + h_label = f"{r['H']}{r['hv_str']}" print( - f" {r['tag']:>45s} {r['H']:3d} │ " + f" {r['tag']:>45s} {h_label:>7s} │ " f"{r['max_diff']:10.6f} {r['rel_max_diff']:10.6f} {r['mean_diff']:12.8f} │ " f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x" ) @@ -431,6 +439,18 @@ def main(): action="store_true", help="NCU profiling mode: warmup=1, iters=1", ) + parser.add_argument( + "--heads", + type=int, + default=64, + help="Number of QK heads H (default: 64)", + ) + parser.add_argument( + "--hv", + type=int, + default=None, + help="Number of value heads HV (default: same as --heads, i.e. 
no GVA)", + ) args = parser.parse_args() global NCU_MODE @@ -438,21 +458,24 @@ def main(): NCU_MODE = True print("[NCU mode] warmup=1, iters=1") + H = args.heads + HV = args.hv if args.hv is not None else H + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" + non_varlen_configs = [ - # (B, T, H) - (2, 8192, 64), - (2, 32768, 64), - (4, 8192, 64), - (4, 32768, 64), + # (B, T, H, HV) + (2, 8192, H, HV), + (2, 32768, H, HV), + (4, 8192, H, HV), + (4, 32768, H, HV), ] varlen_configs = [ - # (seq_lens, H) — realistic serving scenarios - # ~20-25 seqs, total 8k/32k, lengths vary 2-3x, H=64 - (gen_varlen_seqs(8192, 20, seed=1), 64), - (gen_varlen_seqs(8192, 25, seed=2), 64), - (gen_varlen_seqs(32768, 20, seed=3), 64), - (gen_varlen_seqs(32768, 25, seed=4), 64), + # (seq_lens, H, HV) + (gen_varlen_seqs(8192, 20, seed=1), H, HV), + (gen_varlen_seqs(8192, 25, seed=2), H, HV), + (gen_varlen_seqs(32768, 20, seed=3), H, HV), + (gen_varlen_seqs(32768, 25, seed=4), H, HV), ] nv_res, vl_res = [], [] diff --git a/cula/ops/chunk_delta_h.py b/cula/ops/chunk_delta_h.py index b2a61de..67399d6 100644 --- a/cula/ops/chunk_delta_h.py +++ b/cula/ops/chunk_delta_h.py @@ -192,7 +192,7 @@ def _plan_tmem_offsets(tiled_mma_wh, tile_wh, tiled_mma_kv, tile_kv, state_tmem_ ) return wh_off, state_off, vnew_off, kv_off, total - def _compute_grid(self, B, H, V): + def _compute_grid(self, B, HV, V): num_v_tiles = (V + self.BV - 1) // self.BV if self.is_varlen: if self.persistent: @@ -202,26 +202,26 @@ def _compute_grid(self, B, H, V): return (sm_count, 1, 1) else: # Non-persistent: one CTA per work unit, free HW scheduling - total_work_units = num_v_tiles * H * B + total_work_units = num_v_tiles * HV * B return (total_work_units, 1, 1) - return (num_v_tiles, H, B) + return (num_v_tiles, HV, B) @cute.jit def __call__( self, k_in: cute.Tensor, # [B, T, H, K] or [T_total, H, K] - w_in: cute.Tensor, # [B, T, H, K] or [T_total, H, K] - u_in: cute.Tensor, # [B, T, H, V] or 
[T_total, H, V] - g_in: cute.Tensor, # [B, T, H] or [T_total, H] (fp32, unused currently) - gk_in: cute.Tensor, # [B, T, H, K] or [T_total, H, K] (fp32) - h_out_in: cute.Tensor, # [B, NT, H, K, V] or [NT_total, H, K, V] - v_new_in: cute.Tensor, # [B, T, H, V] or [T_total, H, V] - h0_in: cute.Tensor, # [B, H, K, V] (fp32) - ht_in: cute.Tensor, # [B, H, K, V] + w_in: cute.Tensor, # [B, T, HV, K] or [T_total, HV, K] + u_in: cute.Tensor, # [B, T, HV, V] or [T_total, HV, V] + g_in: cute.Tensor, # [B, T, HV] or [T_total, HV] (fp32, unused currently) + gk_in: cute.Tensor, # [B, T, HV, K] or [T_total, HV, K] (fp32) + h_out_in: cute.Tensor, # [B, NT, HV, K, V] or [NT_total, HV, K, V] + v_new_in: cute.Tensor, # [B, T, HV, V] or [T_total, HV, V] + h0_in: cute.Tensor, # [B, HV, K, V] (fp32) + ht_in: cute.Tensor, # [B, HV, K, V] cu_seqlens_in: cute.Tensor, # [N+1] int32 chunk_offsets_in: cute.Tensor, # [N+1] int32 workspace_in: cute.Tensor, # workspace buffer - problem_size: tuple[Int32, Int32, Int32, Int32, Int32], + problem_size: tuple[Int32, Int32, Int32, Int32, Int32, Int32], total_nt: Int32, use_g: Int32, use_gk: Int32, @@ -243,7 +243,7 @@ def __call__( chunk_offsets_ptr = chunk_offsets_in.iterator workspace_ptr = workspace_in.iterator - B, T, H, K, V = problem_size + B, T, H, HV, K, V = problem_size # For varlen: B=num_seqs, T=total_tokens, data tensors use data_B=1. # For non-varlen: data_B=B, NT=ceil(T/BT). 
@@ -259,34 +259,34 @@ def __call__( kt_layout = cute.make_layout((K, T, (H, data_B)), stride=(1, H * K, (K, T * H * K))) kt = cute.make_tensor(k_ptr, kt_layout) - w_layout = cute.make_layout((T, K, (H, data_B)), stride=(H * K, 1, (K, T * H * K))) + w_layout = cute.make_layout((T, K, (HV, data_B)), stride=(HV * K, 1, (K, T * HV * K))) w = cute.make_tensor(w_ptr, w_layout) - u_layout = cute.make_layout((T, V, (H, data_B)), stride=(H * V, 1, (V, T * H * V))) + u_layout = cute.make_layout((T, V, (HV, data_B)), stride=(HV * V, 1, (V, T * HV * V))) u = cute.make_tensor(u_ptr, u_layout) v_new = cute.make_tensor(v_new_ptr, u_layout) # h_out: for varlen, NT=total_chunks and data_B=1; for non-varlen, NT=per-seq chunks and data_B=B h_out_T_layout = cute.make_layout( - (V, K, (NT, H, data_B)), - stride=(1, V, (H * K * V, K * V, NT * H * K * V)), + (V, K, (NT, HV, data_B)), + stride=(1, V, (HV * K * V, K * V, NT * HV * K * V)), ) h_out_T = cute.make_tensor(h_out_ptr, h_out_T_layout) # h0/ht always use B=num_seqs (same for both varlen and non-varlen) - h0_layout = cute.make_layout((K, V, (H, B)), stride=(V, 1, (K * V, H * K * V))) + h0_layout = cute.make_layout((K, V, (HV, B)), stride=(V, 1, (K * V, HV * K * V))) h0 = cute.make_tensor(h0_ptr, h0_layout) - ht_T_layout = cute.make_layout((V, K, (H, B)), stride=(1, V, (K * V, H * K * V))) + ht_T_layout = cute.make_layout((V, K, (HV, B)), stride=(1, V, (K * V, HV * K * V))) ht_T = cute.make_tensor(ht_ptr, ht_T_layout) # gk K-first view for TMA: (K, T, (H, data_B)) with K contiguous - gk_K_layout = cute.make_layout((K, T, (H, data_B)), stride=(1, H * K, (K, T * H * K))) + gk_K_layout = cute.make_layout((K, T, (HV, data_B)), stride=(1, HV * K, (K, T * HV * K))) gk_K = cute.make_tensor(gk_ptr, gk_K_layout) # Transposed U view: (V, T, (H, data_B)) to match WH acc shape (M=BV, N=BT) - u_T_layout = cute.make_layout((V, T, (H, data_B)), stride=(1, H * V, (V, T * H * V))) + u_T_layout = cute.make_layout((V, T, (HV, data_B)), stride=(1, HV * 
V, (V, T * HV * V))) u_T = cute.make_tensor(u_ptr, u_T_layout) self.k_dtype = kt.element_type @@ -432,8 +432,8 @@ def __call__( # v_new transposed GMEM view: (V, T, (H, data_B)) for TMA store v_new_T_layout = cute.make_layout( - (V, T, (H, data_B)), - stride=(1, H * V, (V, T * H * V)), + (V, T, (HV, data_B)), + stride=(1, HV * V, (V, T * HV * V)), ) v_new_T = cute.make_tensor(v_new_ptr, v_new_T_layout) @@ -557,7 +557,7 @@ class SharedStorage: sched_consumed_mbar: cute.struct.MemRange[Int64, 2] self.shared_storage = SharedStorage - self.grid = self._compute_grid(B, H, V) + self.grid = self._compute_grid(B, HV, V) self.kernel( wh_tiled_mma, @@ -642,7 +642,7 @@ def kernel( cu_seqlens: cute.Tensor, chunk_offsets: cute.Tensor, workspace_iter: cute.Pointer, - problem_size: tuple[Int32, Int32, Int32, Int32, Int32], + problem_size: tuple[Int32, Int32, Int32, Int32, Int32, Int32], use_gk: Int32, use_initial_state: Int32, store_final_state: Int32, @@ -819,7 +819,7 @@ def kernel( tCtAccKV = cute.make_tensor(tmem_ptr + self.tmem_kv_off, tCtAccKV_fake.layout) # ===================== Block indices ===================== - B, T, H, K, V = problem_size + B, T, H, HV, K, V = problem_size BT = self.BT if cutlass.const_expr(self.is_varlen): @@ -828,7 +828,7 @@ def kernel( block_idx_x = cute.arch.block_idx()[0] grid_dim_x = cute.arch.grid_dim()[0] num_v_tiles = (V + self.BV - 1) // self.BV - total_work_units = num_v_tiles * H * B + total_work_units = num_v_tiles * HV * B if cutlass.const_expr(self.persistent): # Dynamic scheduling: while loop uses work_idx < total_work_units num_iters = Int32(0) # not used, while loop controls iteration @@ -838,6 +838,7 @@ def kernel( work_idx = Int32(0) v_tile_idx = Int32(0) hidx = Int32(0) + i_h = Int32(0) bidx = Int32(0) tok_offset = Int32(0) seq_len = Int32(0) @@ -846,6 +847,7 @@ def kernel( chunk_off = Int32(0) else: (v_tile_idx, hidx, bidx) = cute.arch.block_idx() + i_h = hidx // (HV // H) tok_offset = Int32(0) seq_len = T NT = (T + BT - 1) // BT 
@@ -907,8 +909,9 @@ def kernel( work_idx = block_idx_x + wu_iter * grid_dim_x v_tile_idx = work_idx % num_v_tiles temp_work = work_idx // num_v_tiles - hidx = temp_work % H - bidx = temp_work // H + hidx = temp_work % HV + bidx = temp_work // HV + i_h = hidx // (HV // H) tok_offset = cu_seqlens[bidx] seq_len = cu_seqlens[bidx + 1] - tok_offset NT = (seq_len + BT - 1) // BT @@ -946,7 +949,7 @@ def kernel( self.kv_mma_tiler, kv_tiled_mma, data_bidx, - hidx, + i_h, ) # U TMA load partition (non-MMA, epilog-style) @@ -1087,7 +1090,7 @@ def kernel( if cutlass.const_expr(self.is_varlen): if cutlass.const_expr(not self.persistent): work_idx = block_idx_x + wu_iter * grid_dim_x - bidx_mma = (work_idx // num_v_tiles) // H + bidx_mma = (work_idx // num_v_tiles) // HV tok_off_mma = cu_seqlens[bidx_mma] NT = (cu_seqlens[bidx_mma + 1] - tok_off_mma + BT - 1) // BT if cutlass.const_expr(PRINT_DEBUG): @@ -1280,8 +1283,9 @@ def kernel( work_idx = block_idx_x + wu_iter * grid_dim_x v_tile_idx = work_idx % num_v_tiles temp_work = work_idx // num_v_tiles - hidx = temp_work % H - bidx = temp_work // H + hidx = temp_work % HV + bidx = temp_work // HV + i_h = hidx // (HV // H) tok_offset = cu_seqlens[bidx] seq_len = cu_seqlens[bidx + 1] - tok_offset NT = (seq_len + BT - 1) // BT @@ -1494,8 +1498,9 @@ def kernel( work_idx = block_idx_x + wu_iter * grid_dim_x v_tile_idx = work_idx % num_v_tiles temp_work = work_idx // num_v_tiles - hidx = temp_work % H - bidx = temp_work // H + hidx = temp_work % HV + bidx = temp_work // HV + i_h = hidx // (HV // H) tok_offset = cu_seqlens[bidx] seq_len = cu_seqlens[bidx + 1] - tok_offset NT = (seq_len + BT - 1) // BT @@ -1582,7 +1587,7 @@ def kernel( # Construct GMEM tile for this chunk vnew_chunk_raw = ( v_new_tensor.iterator - + (tok_offset + chunk_idx * BT) * H * V + + (tok_offset + chunk_idx * BT) * HV * V + hidx * V + v_tile_idx * self.BV ) @@ -1593,7 +1598,7 @@ def kernel( assumed_align=16, ) vnew_stride_t = cute.assume( - H * V, + HV * V, 
divby=128 // self.io_dtype.width, ) gVnew_chunk = cute.make_tensor( @@ -1792,11 +1797,11 @@ def reference_bf16_roundtrip(k, w, u, g=None, gk=None, h0=None, chunk_size=64): # Compile cache + TVM-FFI API # --------------------------------------------------------------------------- -# Internal cache: maps (is_varlen, persistent, H, K, V, chunk_size) → compiled_fn +# Internal cache: maps (is_varlen, persistent, H, HV, K, V, chunk_size) → compiled_fn _delta_h_kernel_cache: dict = {} -def _compile_delta_h_variant(is_varlen, persistent, H, K, V, chunk_size, use_fast_math): +def _compile_delta_h_variant(is_varlen, persistent, H, HV, K, V, chunk_size, use_fast_math): """Compile one ChunkDeltaRuleFwdH kernel variant. Returns the compiled TVM-FFI callable. Uses make_fake_compact_tensor and make_fake_stream for compilation with @@ -1825,7 +1830,7 @@ def _compile_delta_h_variant(is_varlen, persistent, H, K, V, chunk_size, use_fas sym_ns = cute.sym_int() # num_seqs (varlen h0/ht) or B (non-varlen, == sym_a) if is_varlen: - # varlen: data tensors are [T_total, H, ...] (3D) + # varlen: data tensors are [T_total, H/HV, ...] 
(3D) k_fake = make_fake_compact_tensor( cutlass.BFloat16, (sym_a, H, K), @@ -1834,42 +1839,42 @@ def _compile_delta_h_variant(is_varlen, persistent, H, K, V, chunk_size, use_fas ) w_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, H, K), + (sym_a, HV, K), stride_order=(2, 1, 0), assumed_align=128, ) u_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, H, V), + (sym_a, HV, V), stride_order=(2, 1, 0), assumed_align=128, ) g_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_a, H), + (sym_a, HV), stride_order=(1, 0), assumed_align=128, ) gk_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_a, H, K), + (sym_a, HV, K), stride_order=(2, 1, 0), assumed_align=128, ) v_new_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, H, V), + (sym_a, HV, V), stride_order=(2, 1, 0), assumed_align=128, ) h_out_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_nt, H, K, V), + (sym_nt, HV, K, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) else: - # non-varlen: data tensors are [B, T, H, ...] (4D) + # non-varlen: data tensors are [B, T, H/HV, ...] 
(4D) k_fake = make_fake_compact_tensor( cutlass.BFloat16, (sym_a, sym_b, H, K), @@ -1878,52 +1883,52 @@ def _compile_delta_h_variant(is_varlen, persistent, H, K, V, chunk_size, use_fas ) w_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, K), + (sym_a, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) u_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, V), + (sym_a, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) g_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_a, sym_b, H), + (sym_a, sym_b, HV), stride_order=(2, 1, 0), assumed_align=128, ) gk_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_a, sym_b, H, K), + (sym_a, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) v_new_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, V), + (sym_a, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) h_out_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_nt, H, K, V), + (sym_a, sym_nt, HV, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128, ) - # h0/ht use [B, H, K, V] (non-varlen) or [num_seqs, H, K, V] (varlen) + # h0/ht use [B, HV, K, V] (non-varlen) or [num_seqs, HV, K, V] (varlen) # In varlen mode, num_seqs != T_total, so use a separate sym_ns h0_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_ns, H, K, V), + (sym_ns, HV, K, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) ht_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_ns, H, K, V), + (sym_ns, HV, K, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) @@ -1958,7 +1963,7 @@ def _compile_delta_h_variant(is_varlen, persistent, H, K, V, chunk_size, use_fas cu_fake, co_fake, ws_fake, - (Int32(1), Int32(1), Int32(H), Int32(K), Int32(V)), + (Int32(1), Int32(1), Int32(H), Int32(HV), Int32(K), Int32(V)), Int32(1), # total_nt dummy Int32(0), # use_g Int32(0), # use_gk @@ -1971,7 +1976,7 @@ def _compile_delta_h_variant(is_varlen, persistent, H, K, V, 
chunk_size, use_fas return compiled_fn -def _get_compiled_delta_h(is_varlen, persistent, H, K, V, chunk_size): +def _get_compiled_delta_h(is_varlen, persistent, H, HV, K, V, chunk_size): """Get a compiled ChunkDeltaRuleFwdH kernel with on-demand (lazy) compilation. Each variant is compiled exactly once and cached. Compilation is deferred @@ -1980,14 +1985,15 @@ def _get_compiled_delta_h(is_varlen, persistent, H, K, V, chunk_size): where a subsequent cute.compile can invalidate previously compiled but not-yet-executed functions. - Cache key: (is_varlen, persistent, H, K, V, chunk_size, USE_FAST_MATH) + Cache key: (is_varlen, persistent, H, HV, K, V, chunk_size, USE_FAST_MATH) """ - key = (is_varlen, persistent, H, K, V, chunk_size, USE_FAST_MATH) + key = (is_varlen, persistent, H, HV, K, V, chunk_size, USE_FAST_MATH) if key not in _delta_h_kernel_cache: _delta_h_kernel_cache[key] = _compile_delta_h_variant( is_varlen, persistent, H, + HV, K, V, chunk_size, @@ -2016,13 +2022,17 @@ def chunk_gated_delta_rule_fwd_h( Interface aligned with FLA's chunk_gated_delta_rule_fwd_h for fair benchmarking. Allocates output tensors internally and returns (h, v_new, final_state). + GVA (Grouped Value Attention): k uses H (QK) heads; w, u, g, gk, h, h0, ht + use HV (value) heads. HV is inferred from u.shape[2]. When H == HV this + reduces to standard (non-GVA) behavior. 
+ Args: - k: key tensor [B, T, H, K] bf16 - w: decay weight tensor [B, T, H, K] bf16 - u: value tensor [B, T, H, V] bf16 - g: scalar gate [B, T, H] fp32, or None - gk: key gate [B, T, H, K] fp32, or None - initial_state: h0 [N, H, K, V] fp32, or None + k: key tensor [B, T, H, K] bf16 + w: decay weight tensor [B, T, HV, K] bf16 + u: value tensor [B, T, HV, V] bf16 + g: scalar gate [B, T, HV] fp32, or None + gk: key gate [B, T, HV, K] fp32, or None + initial_state: h0 [N, HV, K, V] fp32, or None output_final_state: whether to return final_state chunk_size: chunk size (default 64) save_new_value: whether to return v_new @@ -2032,11 +2042,12 @@ def chunk_gated_delta_rule_fwd_h( Returns: (h, v_new, final_state) — same as FLA - h: [B, NT, H, K, V] bf16 (or [1, NT_total, H, K, V] for varlen) - v_new: [B, T, H, V] bf16 (or None if save_new_value=False) - final_state: [N, H, K, V] fp32 (or None if output_final_state=False) + h: [B, NT, HV, K, V] bf16 (or [1, NT_total, HV, K, V] for varlen) + v_new: [B, T, HV, V] bf16 (or None if save_new_value=False) + final_state: [N, HV, K, V] fp32 (or None if output_final_state=False) """ B, T, H, K_dim = k.shape + HV = u.shape[2] V_dim = u.shape[3] BT = chunk_size is_varlen = cu_seqlens is not None @@ -2069,29 +2080,29 @@ def chunk_gated_delta_rule_fwd_h( w_kern = w[0] u_kern = u[0] # Use torch.empty for dummies the kernel won't read (flag-gated) - g_kern = g[0] if g is not None else torch.empty(T, H, device=k.device, dtype=torch.float32) - gk_kern = gk[0] if gk is not None else torch.empty(T, H, K_dim, device=k.device, dtype=torch.float32) + g_kern = g[0] if g is not None else torch.empty(T, HV, device=k.device, dtype=torch.float32) + gk_kern = gk[0] if gk is not None else torch.empty(T, HV, K_dim, device=k.device, dtype=torch.float32) # Allocate outputs (3D for kernel) - h_out_kern = k.new_empty(total_nt, H, K_dim, V_dim) # bf16 + h_out_kern = k.new_empty(total_nt, HV, K_dim, V_dim) # bf16 v_new_kern = torch.empty_like(u_kern) # 
always allocate; kernel checks save_v_new flag h0_kern = ( initial_state if initial_state is not None - else torch.empty(N, H, K_dim, V_dim, device=k.device, dtype=torch.float32) + else torch.empty(N, HV, K_dim, V_dim, device=k.device, dtype=torch.float32) ) # ht is purely an output (kernel writes all elements when store_final_state=1); # use empty instead of zeros to skip the zero-fill kernel launch. # NOTE: Ensure final output is zeros # vLLM will use padding for CUDA Graph - ht_kern = torch.zeros(N, H, K_dim, V_dim, device=k.device, dtype=torch.float32) + ht_kern = torch.zeros(N, HV, K_dim, V_dim, device=k.device, dtype=torch.float32) # Workspace: first 4 bytes used as atomic counter for dynamic scheduling workspace = torch.zeros(max(N * 128, 4), dtype=torch.uint8, device=k.device) - ps = (Int32(N), Int32(T), Int32(H), Int32(K_dim), Int32(V_dim)) + ps = (Int32(N), Int32(T), Int32(H), Int32(HV), Int32(K_dim), Int32(V_dim)) - compiled_fn = _get_compiled_delta_h(True, persistent, H, K_dim, V_dim, chunk_size) + compiled_fn = _get_compiled_delta_h(True, persistent, H, HV, K_dim, V_dim, chunk_size) compiled_fn( k_kern, w_kern, @@ -2125,29 +2136,29 @@ def chunk_gated_delta_rule_fwd_h( N = B # Allocate outputs - h = k.new_empty(B, NT, H, K_dim, V_dim) # bf16 + h = k.new_empty(B, NT, HV, K_dim, V_dim) # bf16 v_new_out = torch.empty_like(u) # always allocate; kernel checks save_v_new flag # Use torch.empty for dummies the kernel won't read (flag-gated) h0 = ( initial_state if initial_state is not None - else torch.empty(B, H, K_dim, V_dim, device=k.device, dtype=torch.float32) + else torch.empty(B, HV, K_dim, V_dim, device=k.device, dtype=torch.float32) ) # ht must share sym_ns (first dim) with h0, so always use B - ht = k.new_zeros(B, H, K_dim, V_dim, dtype=torch.float32) + ht = k.new_zeros(B, HV, K_dim, V_dim, dtype=torch.float32) # Dummy tensors for unused optional gate inputs (kernel checks flags) - g_kern = g if g is not None else torch.empty(B, T, H, 
device=k.device, dtype=torch.float32) - gk_kern = gk if gk is not None else torch.empty(B, T, H, K_dim, device=k.device, dtype=torch.float32) + g_kern = g if g is not None else torch.empty(B, T, HV, device=k.device, dtype=torch.float32) + gk_kern = gk if gk is not None else torch.empty(B, T, HV, K_dim, device=k.device, dtype=torch.float32) # Dummy cu_seqlens / chunk_offsets / workspace (kernel requires them) cu_dummy = torch.empty(2, dtype=torch.int32, device=k.device) co_dummy = torch.empty(2, dtype=torch.int32, device=k.device) ws_dummy = torch.empty(128, dtype=torch.uint8, device=k.device) - ps = (Int32(B), Int32(T), Int32(H), Int32(K_dim), Int32(V_dim)) + ps = (Int32(B), Int32(T), Int32(H), Int32(HV), Int32(K_dim), Int32(V_dim)) - compiled_fn = _get_compiled_delta_h(False, persistent, H, K_dim, V_dim, chunk_size) + compiled_fn = _get_compiled_delta_h(False, persistent, H, HV, K_dim, V_dim, chunk_size) compiled_fn( k, w, @@ -2181,21 +2192,25 @@ def main(): parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--seq_len", type=int, default=256) parser.add_argument("--num_heads", type=int, default=1) + parser.add_argument( + "--num_v_heads", type=int, default=None, help="Number of value heads (default: num_heads, i.e. 
no GVA)" + ) parser.add_argument("--head_dim_k", type=int, default=128) parser.add_argument("--head_dim_v", type=int, default=128) parser.add_argument("--chunk_size", type=int, default=64) args = parser.parse_args() B, T, H, K, V = args.batch_size, args.seq_len, args.num_heads, args.head_dim_k, args.head_dim_v + HV = args.num_v_heads if args.num_v_heads is not None else H BT = args.chunk_size NT = (T + BT - 1) // BT - print(f"V2 Test: B={B}, T={T}, H={H}, K={K}, V={V}, BT={BT}, NT={NT}") + print(f"V2 Test: B={B}, T={T}, H={H}, HV={HV}, K={K}, V={V}, BT={BT}, NT={NT}") torch.manual_seed(42) k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16) * 0.1 - w = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16) * 0.1 - u = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16) * 0.1 + w = torch.randn(B, T, HV, K, device="cuda", dtype=torch.bfloat16) * 0.1 + u = torch.randn(B, T, HV, V, device="cuda", dtype=torch.bfloat16) * 0.1 def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, store_ht, do_save_vnew=0): h_out, v_new, ht = chunk_gated_delta_rule_fwd_h( @@ -2212,24 +2227,28 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st torch.cuda.synchronize() # Ensure consistent return shapes for backward compat with manual tests if h_out is None: - h_out = torch.zeros(B, NT, H, K, V, device="cuda", dtype=torch.bfloat16) + h_out = torch.zeros(B, NT, HV, K, V, device="cuda", dtype=torch.bfloat16) if v_new is None: - v_new = torch.zeros(B, T, H, V, device="cuda", dtype=torch.bfloat16) + v_new = torch.zeros(B, T, HV, V, device="cuda", dtype=torch.bfloat16) if ht is None: - ht = torch.zeros(B, H, K, V, device="cuda", dtype=torch.float32) + ht = torch.zeros(B, HV, K, V, device="cuda", dtype=torch.float32) return h_out, v_new, ht all_pass = True + # For GVA (H != HV), expand k to HV heads for reference comparison + G = HV // H + k_ref = k.repeat_interleave(G, dim=2) if G > 1 else k + # ===== Test 1: No 
gating, no h0 ===== print("\n" + "=" * 60) print("Test 1: No gating, no h0") - g_z = torch.zeros(B, T, H, device="cuda", dtype=torch.float32) - gk_z = torch.zeros(B, T, H, K, device="cuda", dtype=torch.float32) - h0_z = torch.zeros(B, H, K, V, device="cuda", dtype=torch.float32) + g_z = torch.zeros(B, T, HV, device="cuda", dtype=torch.float32) + gk_z = torch.zeros(B, T, HV, K, device="cuda", dtype=torch.float32) + h0_z = torch.zeros(B, HV, K, V, device="cuda", dtype=torch.float32) h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_z, h0_z, 0, 0, 0, 0) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, h0=None, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, h0=None, chunk_size=BT) max_diff = 0.0 for t in range(min(NT - 1, len(h_ref_bf16))): @@ -2243,15 +2262,14 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st # ===== Test 2: With gk + h0 ===== print("\n" + "=" * 60) print("Test 2: With gk + h0") - gk_val = torch.randn(B, T, H, K, device="cuda", dtype=torch.float32) * 0.1 + gk_val = torch.randn(B, T, HV, K, device="cuda", dtype=torch.float32) * 0.1 gk_val = -torch.abs(gk_val) gk_val = gk_val.cumsum(dim=1) - # Pre-scale by RCP_LN2 to match KDA convention (kernel does exp2 directly) gk_val = gk_val * INV_LN2 - h0_val = torch.randn(B, H, K, V, device="cuda", dtype=torch.float32) * 0.01 + h0_val = torch.randn(B, HV, K, V, device="cuda", dtype=torch.float32) * 0.01 h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_val, h0_val, 0, 1, 1, 0) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, gk=gk_val, h0=h0_val, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, gk=gk_val, h0=h0_val, chunk_size=BT) max_diff = 0.0 for t in range(min(NT - 1, len(h_ref_bf16))): @@ -2265,14 +2283,13 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st # ===== Test 3: With gk gating ===== print("\n" + "=" * 60) print("Test 3: With gk gating") - gk_val = torch.randn(B, T, H, K, device="cuda", 
dtype=torch.float32) * 0.1 + gk_val = torch.randn(B, T, HV, K, device="cuda", dtype=torch.float32) * 0.1 gk_val = -torch.abs(gk_val) gk_val = gk_val.cumsum(dim=1) - # Pre-scale by RCP_LN2 to match KDA convention (kernel does exp2 directly) gk_val = gk_val * INV_LN2 h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_val, h0_z, 0, 1, 0, 0) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, gk=gk_val, h0=None, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, gk=gk_val, h0=None, chunk_size=BT) max_diff = 0.0 for t in range(min(NT - 1, len(h_ref_bf16))): @@ -2286,10 +2303,10 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st # ===== Test 4: With h0 initial state ===== print("\n" + "=" * 60) print("Test 4: With h0 initial state") - h0_val = torch.randn(B, H, K, V, device="cuda", dtype=torch.float32) * 0.01 + h0_val = torch.randn(B, HV, K, V, device="cuda", dtype=torch.float32) * 0.01 h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_z, h0_val, 0, 0, 1, 0) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, h0=h0_val, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, h0=h0_val, chunk_size=BT) # h_out[0] should be h0 (bf16 rounded) h0_bf16 = h0_val.to(torch.bfloat16) @@ -2310,12 +2327,9 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st print("Test 5: store_final_state") h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_z, h0_z, 0, 0, 0, 1) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, h0=None, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, h0=None, chunk_size=BT) - # ht should match the last h_ref (after all chunks) - ht_ref = h_ref_bf16[-1] # last chunk's state - # ht layout: (B, H, K, V) but kernel writes in transposed (V, K) format - # Compare ht[0, 0] with ht_ref + ht_ref = h_ref_bf16[-1] d_ht = (ht[0, 0].float() - ht_ref.float()).abs().max().item() print(f" ht vs ref: {d_ht:.6f}") t5_pass = d_ht < 0.5 @@ -2327,7 +2341,7 @@ def 
run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st print("Test 6: gk + h0 + ht (all features)") h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_val, h0_val, 0, 1, 1, 1) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, gk=gk_val, h0=h0_val, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, gk=gk_val, h0=h0_val, chunk_size=BT) max_diff = 0.0 for t in range(min(NT - 1, len(h_ref_bf16))): @@ -2375,7 +2389,7 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st print("Test 8: v_new output (no gating)") h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_z, h0_z, 0, 0, 0, 0, do_save_vnew=1) - vnew_ref, _ = reference_bf16_roundtrip(k, w, u, h0=None, chunk_size=BT) + vnew_ref, _ = reference_bf16_roundtrip(k_ref, w, u, h0=None, chunk_size=BT) d_vnew = (v_new.float() - vnew_ref.float()).abs().max().item() print(f" v_new max diff: {d_vnew:.6f}") @@ -2388,7 +2402,7 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st print("Test 9: v_new output (with gk gating)") h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_val, h0_z, 0, 1, 0, 0, do_save_vnew=1) - vnew_ref, _ = reference_bf16_roundtrip(k, w, u, gk=gk_val, h0=None, chunk_size=BT) + vnew_ref, _ = reference_bf16_roundtrip(k_ref, w, u, gk=gk_val, h0=None, chunk_size=BT) d_vnew = (v_new.float() - vnew_ref.float()).abs().max().item() print(f" v_new max diff: {d_vnew:.6f}") @@ -2418,12 +2432,13 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st # ===== Benchmark ===== print("\n" + "=" * 60) - print("Benchmark: B=4, T=4096, H=64, K=128, V=128") - Bb, Tb, Hb = 4, 4096, 64 + hv_tag = f"/{HV}" if HV != H else "" + print(f"Benchmark: B=4, T=4096, H={H}{hv_tag}, K=128, V=128") + Bb, Tb = 4, 4096 torch.manual_seed(999) - kb = torch.randn(Bb, Tb, Hb, K, device="cuda", dtype=torch.bfloat16) * 0.1 - wb = torch.randn(Bb, Tb, Hb, K, device="cuda", dtype=torch.bfloat16) * 0.1 - ub = torch.randn(Bb, Tb, Hb, V, 
device="cuda", dtype=torch.bfloat16) * 0.1 + kb = torch.randn(Bb, Tb, H, K, device="cuda", dtype=torch.bfloat16) * 0.1 + wb = torch.randn(Bb, Tb, HV, K, device="cuda", dtype=torch.bfloat16) * 0.1 + ub = torch.randn(Bb, Tb, HV, V, device="cuda", dtype=torch.bfloat16) * 0.1 def run_bench(): chunk_gated_delta_rule_fwd_h( diff --git a/cula/ops/fwd_o.py b/cula/ops/fwd_o.py index 2c820aa..9d4bcb4 100644 --- a/cula/ops/fwd_o.py +++ b/cula/ops/fwd_o.py @@ -198,7 +198,7 @@ def __init__( ) self.buffer_align_bytes = 1024 - def _compute_grid(self, B, T, H, V, total_nt=None): + def _compute_grid(self, B, T, HV, V, total_nt=None): """Compute grid dimensions for kernel launch.""" num_v_tiles = (V + self.BV - 1) // self.BV if self.persistent: @@ -210,10 +210,10 @@ def _compute_grid(self, B, T, H, V, total_nt=None): return (sm_count, 1, 1) elif self.is_varlen: # Non-persistent varlen: one CTA per work unit. - total_work_units = num_v_tiles * total_nt * H + total_work_units = num_v_tiles * total_nt * HV return (total_work_units, 1, 1) NT = (T + self.BT - 1) // self.BT - return (num_v_tiles, NT, B * H) + return (num_v_tiles, NT, B * HV) @staticmethod def _plan_tmem_offsets( @@ -260,14 +260,14 @@ def _plan_tmem_offsets( def __call__( self, q_in: cute.Tensor, # [B, T, H, K] (B=1 for varlen) - v_in: cute.Tensor, # [B, T, H, V] (B=1 for varlen) - g_in: cute.Tensor, # [B, T, H, K] fp32 (B=1 for varlen) - h_in: cute.Tensor, # [B, NT, H, K, V] (B=1 for varlen) - o_in: cute.Tensor, # [B, T, H, V] (B=1 for varlen) - A_in: cute.Tensor, # [B, T, H, BT] (B=1 for varlen) + v_in: cute.Tensor, # [B, T, HV, V] (B=1 for varlen) + g_in: cute.Tensor, # [B, T, HV, K] fp32 (B=1 for varlen) + h_in: cute.Tensor, # [B, NT, HV, K, V] (B=1 for varlen) + o_in: cute.Tensor, # [B, T, HV, V] (B=1 for varlen) + A_in: cute.Tensor, # [B, T, HV, BT] (B=1 for varlen) cu_seqlens_in: cute.Tensor, # [N+1] int32 chunk_indices_in: cute.Tensor, # [NT, 2] int32 - problem_size: tuple[Int32, Int32, Int32, Int32, Int32], + 
problem_size: tuple[Int32, Int32, Int32, Int32, Int32, Int32], total_nt: Int32, # total chunks across all seqs (varlen) stream, ): @@ -281,7 +281,7 @@ def __call__( cu_seqlens_ptr = cu_seqlens_in.iterator chunk_indices_ptr = chunk_indices_in.iterator - B, T, H, K, V = problem_size + B, T, H, HV, K, V = problem_size BT = self.BT # For varlen: B=num_seqs, T=max_seqlen (or total_tokens), data_B=1 @@ -303,17 +303,17 @@ def __call__( ) q = cute.make_tensor(q_ptr, q_layout) - # g layout: token-indexed (T, K, (H, data_B)) — fp32 (separate from q) + # g layout: token-indexed (T, K, (HV, data_B)) — fp32 g_layout = cute.make_layout( - (T, K, (H, data_B)), - stride=(H * K, 1, (K, T * H * K)), + (T, K, (HV, data_B)), + stride=(HV * K, 1, (K, T * HV * K)), ) g = cute.make_tensor(g_ptr, g_layout) - # o: row-major (T, V, (H, data_B)) — token-indexed for direct GMEM write (varlen) + # o: row-major (T, V, (HV, data_B)) — token-indexed for direct GMEM write (varlen) o_layout = cute.make_layout( - (T, V, (H, data_B)), - stride=(H * V, 1, (V, T * H * V)), + (T, V, (HV, data_B)), + stride=(HV * V, 1, (V, T * HV * V)), ) o = cute.make_tensor(o_ptr, o_layout) @@ -323,8 +323,8 @@ def __call__( # TMA descriptor collapses the degenerate H dim; keeping batch # at coord-2 guarantees it always maps to an existing TMA dim. v_T_layout = cute.make_layout( - (V, T, (data_B, H)), - stride=(1, H * V, (T * H * V, V)), + (V, T, (data_B, HV)), + stride=(1, HV * V, (T * HV * V, V)), ) v_T = cute.make_tensor(v_ptr, v_T_layout) @@ -337,15 +337,15 @@ def __call__( h_nt_total = B * NT # NOTE: Mode 2 uses (batch, H) order — see v_T comment above. 
h_T_layout = cute.make_layout( - (V, K, (h_nt_total, H)), - stride=(1, V, (H * K * V, K * V)), + (V, K, (h_nt_total, HV)), + stride=(1, V, (HV * K * V, K * V)), ) h_T = cute.make_tensor(h_ptr, h_T_layout) - # A layout: token-indexed (T, BT, (H, data_B)) + # A layout: token-indexed (T, BT, (HV, data_B)) a_layout = cute.make_layout( - (T, BT, (H, data_B)), - stride=(H * BT, 1, (BT, T * H * BT)), + (T, BT, (HV, data_B)), + stride=(HV * BT, 1, (BT, T * HV * BT)), ) A = cute.make_tensor(A_ptr, a_layout) @@ -570,7 +570,7 @@ class SharedStorage: ) # ===================== Grid ===================== - grid = self._compute_grid(B, T, H, V, total_nt=total_nt) + grid = self._compute_grid(B, T, HV, V, total_nt=total_nt) # ===================== cu_seqlens / chunk_indices tensors ===================== cu_seqlens = cute.make_tensor(cu_seqlens_ptr, cute.make_layout((B + 1,))) @@ -683,7 +683,7 @@ def kernel( problem_size, total_nt, ): - B, T, H, K, V = problem_size + B, T, H, HV, K, V = problem_size BT = self.BT # ===================== Work decode ===================== @@ -693,12 +693,13 @@ def kernel( # Persistent kernel: 1D grid, work decoded inside each warp's loop block_idx_x = cute.arch.block_idx()[0] grid_dim_x = cute.arch.grid_dim()[0] - total_work_units = num_v_tiles * total_nt * H + total_work_units = num_v_tiles * total_nt * HV num_iters = (total_work_units - block_idx_x + grid_dim_x - 1) // grid_dim_x # Pre-initialize persistent loop variables (CuTe DSL requirement) i_v = Int32(0) chunk_global_idx = Int32(0) i_h = Int32(0) + i_qh = Int32(0) i_b = Int32(0) i_t = Int32(0) tok_offset = Int32(0) @@ -713,8 +714,9 @@ def kernel( i_v = cute.arch.block_idx()[0] i_t = cute.arch.block_idx()[1] i_bh = cute.arch.block_idx()[2] - i_b = i_bh // H - i_h = i_bh % H + i_b = i_bh // HV + i_h = i_bh % HV + i_qh = i_h // (HV // H) tok_offset = i_b * T seq_len = T data_bidx = i_b @@ -874,6 +876,7 @@ def kernel( temp_work = work_idx // num_v_tiles chunk_flat = temp_work % total_nt i_h = 
temp_work // total_nt + i_qh = i_h // (HV // H) if cutlass.const_expr(self.is_varlen): i_b = chunk_indices[(chunk_flat, 0)] i_t = chunk_indices[(chunk_flat, 1)] @@ -901,7 +904,7 @@ def kernel( # --- Unconditional TMA partitions --- bSG_sQ, bSG_gQ = self._epilog_partition_varlen( tma_atom_q, - tma_q_v[None, None, (i_h, data_bidx)], + tma_q_v[None, None, (i_qh, data_bidx)], (self.BT, self.BK), sQ_epi, ) @@ -1001,7 +1004,7 @@ def kernel( # Bulk prefetch: SMEM → registers (all 256 bf16 at once) cute.autovec_copy(tOsO, tOrO) - o_chunk_raw = o_tensor.iterator + (tok_offset + i_t * BT) * H * V + i_h * V + i_v * self.BV + o_chunk_raw = o_tensor.iterator + (tok_offset + i_t * BT) * HV * V + i_h * V + i_v * self.BV o_chunk_ptr = cute.make_ptr( self.io_dtype, o_chunk_raw.toint(), @@ -1009,7 +1012,7 @@ def kernel( assumed_align=16, ) o_stride_bt = cute.assume( - H * V, + HV * V, divby=128 // self.io_dtype.width, ) gO_chunk = cute.make_tensor( @@ -1580,7 +1583,7 @@ def reference_chunk_gla_fwd_o(q, v, g, h, A, scale, chunk_size=64): # Compile cache + TVM-FFI API # --------------------------------------------------------------------------- -# Internal cache: maps (is_varlen, persistent, H, K, V, scale, chunk_size) → compiled_fn +# Internal cache: maps (is_varlen, persistent, H, HV, K, V, scale, chunk_size) → compiled_fn _fwd_o_kernel_cache: dict = {} # Pre-allocated dummy tensors for non-varlen path (avoid per-call torch.zeros) @@ -1588,7 +1591,7 @@ def reference_chunk_gla_fwd_o(q, v, g, h, A, scale, chunk_size=64): _fwd_o_dummy_chunk_indices: torch.Tensor = None -def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, use_fast_math): +def _compile_fwd_o_variant(is_varlen, persistent, H, HV, K, V, scale, chunk_size, use_fast_math): """Compile one ChunkGlaFwdO kernel variant. Returns the compiled TVM-FFI callable. 
Uses make_fake_compact_tensor and make_fake_stream for compilation with @@ -1615,8 +1618,8 @@ def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, us BT = chunk_size if is_varlen: - # varlen: tensors are [1, T_total, H, ...] (4D with B=1) - # This avoids squeeze(0) CPU overhead at the call site. + # varlen: tensors are [1, T_total, H/HV, ...] (4D with B=1) + # q uses H (QK heads), g/v/o/A use HV (value heads) q_fake = make_fake_compact_tensor( cutlass.BFloat16, (1, sym_b, H, K), @@ -1625,30 +1628,31 @@ def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, us ) v_fake = make_fake_compact_tensor( cutlass.BFloat16, - (1, sym_b, H, V), + (1, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) g_fake = make_fake_compact_tensor( cutlass.Float32, - (1, sym_b, H, K), + (1, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) o_fake = make_fake_compact_tensor( cutlass.BFloat16, - (1, sym_b, H, V), + (1, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) A_fake = make_fake_compact_tensor( cutlass.BFloat16, - (1, sym_b, H, BT), + (1, sym_b, HV, BT), stride_order=(3, 2, 1, 0), assumed_align=128, ) else: - # non-varlen: tensors are [B, T, H, ...] (4D) + # non-varlen: tensors are [B, T, H/HV, ...] 
(4D) + # q uses H (QK heads), g/v/o/A use HV (value heads) q_fake = make_fake_compact_tensor( cutlass.BFloat16, (sym_a, sym_b, H, K), @@ -1657,42 +1661,42 @@ def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, us ) v_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, V), + (sym_a, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) g_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_a, sym_b, H, K), + (sym_a, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) o_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, V), + (sym_a, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) A_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, BT), + (sym_a, sym_b, HV, BT), stride_order=(3, 2, 1, 0), assumed_align=128, ) if is_varlen: - # varlen: h is [1, NT_total, H, K, V] (5D with B=1) + # varlen: h is [1, NT_total, HV, K, V] (5D with B=1) h_fake = make_fake_compact_tensor( cutlass.BFloat16, - (1, sym_nt, H, K, V), + (1, sym_nt, HV, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128, ) else: - # non-varlen: h is [B, NT, H, K, V] (5D) + # non-varlen: h is [B, NT, HV, K, V] (5D) h_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_nt, H, K, V), + (sym_a, sym_nt, HV, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128, ) @@ -1720,7 +1724,7 @@ def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, us A_fake, cu_fake, ci_fake, - (Int32(1), Int32(1), Int32(H), Int32(K), Int32(V)), + (Int32(1), Int32(1), Int32(H), Int32(HV), Int32(K), Int32(V)), Int32(1), stream_fake, options=COMPILE_OPTIONS, @@ -1728,7 +1732,7 @@ def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, us return compiled_fn -def _get_compiled_fwd_o(is_varlen, persistent, H, K, V, scale, chunk_size): +def _get_compiled_fwd_o(is_varlen, persistent, H, HV, K, V, scale, chunk_size): """Get a compiled ChunkGlaFwdO kernel with 
on-demand (lazy) compilation. Each variant is compiled exactly once and cached. Compilation is deferred @@ -1737,14 +1741,15 @@ def _get_compiled_fwd_o(is_varlen, persistent, H, K, V, scale, chunk_size): where a subsequent cute.compile can invalidate previously compiled but not-yet-executed functions. - Cache key: (is_varlen, persistent, H, K, V, scale, chunk_size, USE_FAST_MATH) + Cache key: (is_varlen, persistent, H, HV, K, V, scale, chunk_size, USE_FAST_MATH) """ - key = (is_varlen, persistent, H, K, V, scale, chunk_size, USE_FAST_MATH) + key = (is_varlen, persistent, H, HV, K, V, scale, chunk_size, USE_FAST_MATH) if key not in _fwd_o_kernel_cache: _fwd_o_kernel_cache[key] = _compile_fwd_o_variant( is_varlen, persistent, H, + HV, K, V, scale, @@ -1778,15 +1783,15 @@ def chunk_gla_fwd_o( sym_int() is used for B, T, NT so a single compilation handles all batch-size / sequence-length combinations. - Cache key: (is_varlen, persistent, H, K, V, scale, chunk_size) + Cache key: (is_varlen, persistent, H, HV, K, V, scale, chunk_size) Args: - q: query tensor — [B, T, H, K] bf16 (both non-varlen and varlen with B=1) - v: value tensor — [B, T, H, V] bf16 (both non-varlen and varlen with B=1) - g: gate tensor — [B, T, H, K] fp32 (both non-varlen and varlen with B=1) - h: state tensor — [B, NT, H, K, V] bf16 (B=1 for varlen) - o: output tensor (pre-allocated) — same shape as q but with V dim - A: attention matrix — [B, T, H, BT] bf16 (both non-varlen and varlen with B=1) + q: query tensor — [B, T, H, K] bf16 (H = QK heads) + v: value tensor — [B, T, HV, V] bf16 (HV = value heads, HV >= H) + g: gate tensor — [B, T, HV, K] fp32 + h: state tensor — [B, NT, HV, K, V] bf16 (B=1 for varlen) + o: output tensor (pre-allocated) — [B, T, HV, V] bf16 + A: attention matrix — [B, T, HV, BT] bf16 scale: attention scale factor chunk_size: chunk size (default: 64) cu_seqlens: cumulative sequence lengths [N+1] int32 (varlen only) @@ -1802,20 +1807,22 @@ def chunk_gla_fwd_o( "cu_seqlens and 
chunk_indices are required for varlen mode" ) assert q.dim() == 4 and q.shape[0] == 1, f"varlen mode expects [1, T_total, H, K] input, got shape {q.shape}" - assert h.dim() == 5 and h.shape[0] == 1, f"varlen mode expects [1, NT_total, H, K, V] for h, got shape {h.shape}" + assert h.dim() == 5 and h.shape[0] == 1, f"varlen mode expects [1, NT_total, HV, K, V] for h, got shape {h.shape}" T_total = q.shape[1] H = q.shape[2] + HV = v.shape[2] K = q.shape[3] V = v.shape[3] num_seqs = cu_seqlens.shape[0] - 1 total_nt_val = chunk_indices.shape[0] - ps = (Int32(num_seqs), Int32(T_total), Int32(H), Int32(K), Int32(V)) + ps = (Int32(num_seqs), Int32(T_total), Int32(H), Int32(HV), Int32(K), Int32(V)) else: B, T, H, K = q.shape + HV = v.shape[2] V = v.shape[3] NT = (T + chunk_size - 1) // chunk_size total_nt_val = B * NT - ps = (Int32(B), Int32(T), Int32(H), Int32(K), Int32(V)) + ps = (Int32(B), Int32(T), Int32(H), Int32(HV), Int32(K), Int32(V)) if cu_seqlens is None: global _fwd_o_dummy_cu_seqlens if _fwd_o_dummy_cu_seqlens is None or _fwd_o_dummy_cu_seqlens.device != q.device: @@ -1831,6 +1838,7 @@ def chunk_gla_fwd_o( is_varlen, persistent, H, + HV, K, V, scale, @@ -1864,6 +1872,7 @@ def main(): parser.add_argument("--B", type=int, default=2) parser.add_argument("--T", type=int, default=256) parser.add_argument("--H", type=int, default=4) + parser.add_argument("--HV", type=int, default=None, help="Number of value heads (default: same as --H)") parser.add_argument("--K", type=int, default=128) parser.add_argument("--V", type=int, default=128) parser.add_argument("--scale", type=float, default=None) @@ -1873,12 +1882,16 @@ def main(): if args.scale is None: args.scale = args.K**-0.5 B, T, H, K, V = args.B, args.T, args.H, args.K, args.V + HV = args.HV if args.HV is not None else H + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" + G = HV // H BT = args.chunk_size scale = args.scale NT = (T + BT - 1) // BT dtype, device = torch.bfloat16, 
"cuda" - print(f"Config: B={B}, T={T}, H={H}, K={K}, V={V}, BT={BT}, scale={scale:.4f}") + hv_str = f"/{HV}" if HV != H else "" + print(f"Config: B={B}, T={T}, H={H}{hv_str}, K={K}, V={V}, BT={BT}, scale={scale:.4f}") print(f" Chunks per seq: {NT}, Total chunks: {B * NT}") if args.test in ("correctness", "both"): @@ -1888,13 +1901,14 @@ def main(): print("\n=== Non-Varlen Correctness Test ===") torch.manual_seed(42) q_nv = torch.randn(B, T, H, K, dtype=dtype, device=device) - v_nv = torch.randn(B, T, H, V, dtype=dtype, device=device) - g_nv = torch.randn(B, T, H, K, dtype=torch.float32, device=device) * 0.1 - h_nv = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 - A_nv = torch.randn(B, T, H, BT, dtype=dtype, device=device) * 0.1 + v_nv = torch.randn(B, T, HV, V, dtype=dtype, device=device) + g_nv = torch.randn(B, T, HV, K, dtype=torch.float32, device=device) * 0.1 + h_nv = torch.randn(B, NT, HV, K, V, dtype=dtype, device=device) * 0.01 + A_nv = torch.randn(B, T, HV, BT, dtype=dtype, device=device) * 0.1 - o_ref_nv = reference_chunk_gla_fwd_o(q_nv, v_nv, g_nv, h_nv, A_nv, scale, BT) - o_nv = torch.zeros(B, T, H, V, dtype=dtype, device=device) + q_ref = q_nv.repeat_interleave(G, dim=2) + o_ref_nv = reference_chunk_gla_fwd_o(q_ref, v_nv, g_nv, h_nv, A_nv, scale, BT) + o_nv = torch.zeros(B, T, HV, V, dtype=dtype, device=device) chunk_gla_fwd_o( q=q_nv, @@ -1943,13 +1957,14 @@ def main(): ci_t = build_chunk_indices(seq_lens, BT=BT, device=device) q_flat = torch.randn(1, T_total, H, K, dtype=dtype, device=device) - v_flat = torch.randn(1, T_total, H, V, dtype=dtype, device=device) - g_flat = torch.randn(1, T_total, H, K, dtype=torch.float32, device=device) * 0.1 - h_flat = torch.randn(1, total_nt_val, H, K, V, dtype=dtype, device=device) * 0.01 - A_flat = torch.randn(1, T_total, H, BT, dtype=dtype, device=device) * 0.1 - o_flat = torch.zeros(1, T_total, H, V, dtype=dtype, device=device) + v_flat = torch.randn(1, T_total, HV, V, dtype=dtype, device=device) 
+ g_flat = torch.randn(1, T_total, HV, K, dtype=torch.float32, device=device) * 0.1 + h_flat = torch.randn(1, total_nt_val, HV, K, V, dtype=dtype, device=device) * 0.01 + A_flat = torch.randn(1, T_total, HV, BT, dtype=dtype, device=device) * 0.1 + o_flat = torch.zeros(1, T_total, HV, V, dtype=dtype, device=device) # Reference per-sequence + q_ref_flat = q_flat[:, :, :, :].repeat_interleave(G, dim=2) o_ref_flat = torch.zeros_like(o_flat) for seq_idx, sl in enumerate(seq_lens): s = cu_seqlens_list[seq_idx] @@ -1957,7 +1972,13 @@ def main(): co = chunk_offsets_list[seq_idx] nt_seq = (sl + BT - 1) // BT o_seq = reference_chunk_gla_fwd_o( - q_flat[:, s:e], v_flat[:, s:e], g_flat[:, s:e], h_flat[:, co : co + nt_seq], A_flat[:, s:e], scale, BT + q_ref_flat[:, s:e], + v_flat[:, s:e], + g_flat[:, s:e], + h_flat[:, co : co + nt_seq], + A_flat[:, s:e], + scale, + BT, ) o_ref_flat[:, s:e] = o_seq @@ -1995,12 +2016,13 @@ def main(): for i in range(3): torch.manual_seed(i * 100) q_cr = torch.randn(B, T, H, K, dtype=dtype, device=device) - v_cr = torch.randn(B, T, H, V, dtype=dtype, device=device) - g_cr = torch.randn(B, T, H, K, dtype=torch.float32, device=device) * 0.1 - h_cr = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 - A_cr = torch.randn(B, T, H, BT, dtype=dtype, device=device) * 0.1 - o_cr = torch.zeros(B, T, H, V, dtype=dtype, device=device) - o_ref_cr = reference_chunk_gla_fwd_o(q_cr, v_cr, g_cr, h_cr, A_cr, scale, BT) + v_cr = torch.randn(B, T, HV, V, dtype=dtype, device=device) + g_cr = torch.randn(B, T, HV, K, dtype=torch.float32, device=device) * 0.1 + h_cr = torch.randn(B, NT, HV, K, V, dtype=dtype, device=device) * 0.01 + A_cr = torch.randn(B, T, HV, BT, dtype=dtype, device=device) * 0.1 + o_cr = torch.zeros(B, T, HV, V, dtype=dtype, device=device) + q_ref_cr = q_cr.repeat_interleave(G, dim=2) + o_ref_cr = reference_chunk_gla_fwd_o(q_ref_cr, v_cr, g_cr, h_cr, A_cr, scale, BT) chunk_gla_fwd_o( q=q_cr, @@ -2027,11 +2049,11 @@ def main(): for 
bench_T in [1024, 2048, 4096]: bench_NT = (bench_T + BT - 1) // BT q_b = torch.randn(B, bench_T, H, K, dtype=dtype, device=device) - v_b = torch.randn(B, bench_T, H, V, dtype=dtype, device=device) - g_b = torch.randn(B, bench_T, H, K, dtype=torch.float32, device=device) * 0.1 - h_b = torch.randn(B, bench_NT, H, K, V, dtype=dtype, device=device) * 0.01 - A_b = torch.randn(B, bench_T, H, BT, dtype=dtype, device=device) * 0.1 - o_b = torch.zeros(B, bench_T, H, V, dtype=dtype, device=device) + v_b = torch.randn(B, bench_T, HV, V, dtype=dtype, device=device) + g_b = torch.randn(B, bench_T, HV, K, dtype=torch.float32, device=device) * 0.1 + h_b = torch.randn(B, bench_NT, HV, K, V, dtype=dtype, device=device) * 0.01 + A_b = torch.randn(B, bench_T, HV, BT, dtype=dtype, device=device) * 0.1 + o_b = torch.zeros(B, bench_T, HV, V, dtype=dtype, device=device) # Warmup (also triggers lazy compilation if needed) for _ in range(3): diff --git a/tests/test_kda.py b/tests/test_kda.py index fadadbf..d2b17d3 100644 --- a/tests/test_kda.py +++ b/tests/test_kda.py @@ -35,6 +35,7 @@ "B", "T", "H", + "HV", "D", "gate_logit_normalizer", "mask_p", @@ -46,17 +47,21 @@ [ pytest.param( *test, - id="B{}-T{}-H{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-{}".format(*test), + id="B{}-T{}-H{}-HV{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-{}".format(*test), ) for test in [ - (1, 63, 1, 128, 1, 0, False, False, True, torch.bfloat16), - (2, 500, 3, 128, 1, 0, False, False, True, torch.bfloat16), - (2, 1000, 3, 128, 1, 0.5, False, False, True, torch.bfloat16), - (3, 1024, 4, 128, 0.1, 0, False, False, True, torch.bfloat16), - (4, 1024, 4, 128, 1, 0, False, False, True, torch.bfloat16), - (4, 1024, 4, 128, 1, 0, True, False, True, torch.bfloat16), - (2, 1500, 4, 128, 10, 0, False, True, True, torch.bfloat16), - (4, 2048, 8, 128, 1, 0, False, True, True, torch.bfloat16), + (1, 63, 1, 1, 128, 1, 0, False, False, True, torch.bfloat16), + (2, 500, 3, 3, 128, 1, 0, False, False, True, 
torch.bfloat16), + (2, 1000, 3, 3, 128, 1, 0.5, False, False, True, torch.bfloat16), + (3, 1024, 4, 4, 128, 0.1, 0, False, False, True, torch.bfloat16), + (4, 1024, 4, 4, 128, 1, 0, False, False, True, torch.bfloat16), + (4, 1024, 4, 4, 128, 1, 0, True, False, True, torch.bfloat16), + (2, 1500, 4, 4, 128, 10, 0, False, True, True, torch.bfloat16), + (4, 2048, 8, 8, 128, 1, 0, False, True, True, torch.bfloat16), + # GVA cases: HV > H + (2, 1024, 4, 8, 128, 1, 0, True, False, True, torch.bfloat16), + (2, 1500, 2, 4, 128, 10, 0, False, True, True, torch.bfloat16), + (2, 2048, 4, 8, 128, 1, 0, False, True, True, torch.bfloat16), ] ], ) @@ -64,6 +69,7 @@ def test_safe_gate_chunk( B: int, T: int, H: int, + HV: int, D: int, gate_logit_normalizer: float, mask_p: float, @@ -77,11 +83,11 @@ def test_safe_gate_chunk( torch.manual_seed(42) q = torch.rand(B, T, H, D, dtype=dtype) k = torch.rand(B, T, H, D, dtype=dtype) - v = torch.rand(B, T, H, D, dtype=dtype) - g = torch.randn(B, T, H, D, dtype=torch.float if not use_gate_in_kernel else dtype) + v = torch.rand(B, T, HV, D, dtype=dtype) + g = torch.randn(B, T, HV, D, dtype=torch.float if not use_gate_in_kernel else dtype) if use_gate_in_kernel: - A_log = torch.randn(H, dtype=torch.float) - dt_bias = torch.randn(H * D, dtype=torch.float) + A_log = torch.randn(HV, dtype=torch.float) + dt_bias = torch.randn(HV * D, dtype=torch.float) else: g = F.logsigmoid(g) / gate_logit_normalizer g = g * (torch.rand_like(g) > mask_p) @@ -94,8 +100,8 @@ def test_safe_gate_chunk( lower_bound = None naive_kda_gate_fn = naive_kda_gate - beta = torch.randn(B, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) - h0 = torch.randn(B, H, D, D, dtype=torch.float32) + beta = torch.randn(B, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype) + h0 = torch.randn(B, HV, D, D, dtype=torch.float32) if use_gate_in_kernel: A_log, dt_bias = map(lambda x: x.to(device).requires_grad_(True), (A_log, dt_bias)) q, k, v, g, beta, h0 = map(lambda x: 
x.to(device).requires_grad_(True), (q, k, v, g, beta, h0)) @@ -158,17 +164,18 @@ def test_safe_gate_chunk( @pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) @pytest.mark.parametrize("disable_recompute", [True, False], ids=["no_recomp", "recomp"]) @pytest.mark.parametrize( - ("H", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate"), + ("H", "HV", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate"), [ - pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}".format(*test)) + pytest.param(*test, id="H{}-HV{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}".format(*test)) for test in [ - (4, 128, 0.1, [0, 15], torch.bfloat16, True), - (4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True), - (4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True), - (4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True), - (4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True), + (4, 4, 128, 0.1, [0, 15], torch.bfloat16, True), + (4, 4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True), + (4, 4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True), + (4, 4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True), + (4, 4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True), # ======Varlen test with simulated trace======= ( + 32, 32, 128, 0, @@ -177,6 +184,7 @@ def test_safe_gate_chunk( True, ), ( + 32, 32, 128, 0, @@ -185,6 +193,7 @@ def test_safe_gate_chunk( True, ), ( + 32, 32, 128, 0, @@ -193,6 +202,20 @@ def test_safe_gate_chunk( True, ), ( + 32, + 32, + 128, + 0, + [0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192], + torch.bfloat16, + True, + ), + # ======GVA varlen cases: HV > H======= + (2, 4, 128, 0.1, [0, 15], torch.bfloat16, True), + (4, 8, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True), + (4, 8, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True), + ( + 8, 32, 128, 0, @@ -205,6 +228,7 @@ def test_safe_gate_chunk( ) 
def test_safe_gate_chunk_varlen( H: int, + HV: int, D: int, mask_p: float, cu_seqlens: list[int], @@ -221,15 +245,15 @@ def test_safe_gate_chunk_varlen( q = torch.randn((1, T, H, D), dtype=dtype) k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) - v = torch.randn((1, T, H, D), dtype=dtype) - g = F.logsigmoid(torch.randn(1, T, H, D, dtype=torch.float)) + v = torch.randn((1, T, HV, D), dtype=dtype) + g = F.logsigmoid(torch.randn(1, T, HV, D, dtype=torch.float)) mask = torch.rand_like(g) > mask_p g = g * mask + (~mask) * (-1000) if safe_gate: g = g.clamp(-5, 0) - beta = torch.randn(1, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) - h0 = torch.randn((N, H, D, D), dtype=torch.float32) + beta = torch.randn(1, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype) + h0 = torch.randn((N, HV, D, D), dtype=torch.float32) q, k, v, g, beta, h0 = map(lambda x: x.to(device).requires_grad_(), (q, k, v, g, beta, h0)) do = torch.randn_like(v) diff --git a/tests/test_kda_compare_fla.py b/tests/test_kda_compare_fla.py index 9d5e08d..c88a40a 100644 --- a/tests/test_kda_compare_fla.py +++ b/tests/test_kda_compare_fla.py @@ -34,6 +34,7 @@ "B", "T", "H", + "HV", "D", "gate_logit_normalizer", "mask_p", @@ -45,17 +46,21 @@ [ pytest.param( *test, - id="B{}-T{}-H{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-{}".format(*test), + id="B{}-T{}-H{}-HV{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-{}".format(*test), ) for test in [ - (1, 63, 1, 128, 1, 0, False, False, True, torch.bfloat16), - (2, 500, 3, 128, 1, 0, False, False, True, torch.bfloat16), - (2, 1000, 3, 128, 1, 0.5, False, False, True, torch.bfloat16), - (3, 1024, 4, 128, 0.1, 0, False, False, True, torch.bfloat16), - (4, 1024, 4, 128, 1, 0, False, False, True, torch.bfloat16), - (4, 1024, 4, 128, 1, 0, True, False, True, torch.bfloat16), - (2, 1500, 4, 128, 10, 0, False, True, True, torch.bfloat16), - (4, 2048, 8, 128, 1, 0, False, True, True, torch.bfloat16), + (1, 63, 1, 1, 128, 1, 0, 
False, False, True, torch.bfloat16), + (2, 500, 3, 3, 128, 1, 0, False, False, True, torch.bfloat16), + (2, 1000, 3, 3, 128, 1, 0.5, False, False, True, torch.bfloat16), + (3, 1024, 4, 4, 128, 0.1, 0, False, False, True, torch.bfloat16), + (4, 1024, 4, 4, 128, 1, 0, False, False, True, torch.bfloat16), + (4, 1024, 4, 4, 128, 1, 0, True, False, True, torch.bfloat16), + (2, 1500, 4, 4, 128, 10, 0, False, True, True, torch.bfloat16), + (4, 2048, 8, 8, 128, 1, 0, False, True, True, torch.bfloat16), + # GVA cases: HV > H + (2, 1024, 4, 8, 128, 1, 0, True, False, True, torch.bfloat16), + (2, 1500, 2, 4, 128, 10, 0, False, True, True, torch.bfloat16), + (2, 2048, 4, 8, 128, 1, 0, False, True, True, torch.bfloat16), ] ], ) @@ -63,6 +68,7 @@ def test_safe_gate_chunk( B: int, T: int, H: int, + HV: int, D: int, gate_logit_normalizer: float, mask_p: float, @@ -76,11 +82,11 @@ def test_safe_gate_chunk( torch.manual_seed(42) q = torch.rand(B, T, H, D, dtype=dtype) k = torch.rand(B, T, H, D, dtype=dtype) - v = torch.rand(B, T, H, D, dtype=dtype) - g = torch.randn(B, T, H, D, dtype=torch.float if not use_gate_in_kernel else dtype) + v = torch.rand(B, T, HV, D, dtype=dtype) + g = torch.randn(B, T, HV, D, dtype=torch.float if not use_gate_in_kernel else dtype) if use_gate_in_kernel: - A_log = torch.randn(H, dtype=torch.float) - dt_bias = torch.randn(H * D, dtype=torch.float) + A_log = torch.randn(HV, dtype=torch.float) + dt_bias = torch.randn(HV * D, dtype=torch.float) else: g = F.logsigmoid(g) / gate_logit_normalizer g = g * (torch.rand_like(g) > mask_p) @@ -91,8 +97,8 @@ def test_safe_gate_chunk( else: lower_bound = None - beta = torch.randn(B, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) - h0 = torch.randn(B, H, D, D, dtype=torch.float32) + beta = torch.randn(B, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype) + h0 = torch.randn(B, HV, D, D, dtype=torch.float32) if use_gate_in_kernel: A_log, dt_bias = map(lambda x: x.to(device).requires_grad_(True), (A_log, dt_bias)) 
q, k, v, g, beta, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, g, beta, h0)) @@ -162,17 +168,18 @@ def test_safe_gate_chunk( @pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) @pytest.mark.parametrize("disable_recompute", [True, False], ids=["no_recomp", "recomp"]) @pytest.mark.parametrize( - ("H", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate"), + ("H", "HV", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate"), [ - pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}".format(*test)) + pytest.param(*test, id="H{}-HV{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}".format(*test)) for test in [ - (4, 128, 0.1, [0, 15], torch.bfloat16, True), - (4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True), - (4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True), - (4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True), - (4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True), + (4, 4, 128, 0.1, [0, 15], torch.bfloat16, True), + (4, 4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True), + (4, 4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True), + (4, 4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True), + (4, 4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True), # ======Varlen test with simulated trace======= ( + 32, 32, 128, 0, @@ -181,6 +188,7 @@ def test_safe_gate_chunk( True, ), ( + 32, 32, 128, 0, @@ -189,6 +197,7 @@ def test_safe_gate_chunk( True, ), ( + 32, 32, 128, 0, @@ -197,6 +206,20 @@ def test_safe_gate_chunk( True, ), ( + 32, + 32, + 128, + 0, + [0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192], + torch.bfloat16, + True, + ), + # ======GVA varlen cases: HV > H======= + (2, 4, 128, 0.1, [0, 15], torch.bfloat16, True), + (4, 8, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True), + (4, 8, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True), + ( + 8, 32, 128, 0, @@ -209,6 
+232,7 @@ def test_safe_gate_chunk( ) def test_safe_gate_chunk_varlen( H: int, + HV: int, D: int, mask_p: float, cu_seqlens: list[int], @@ -225,15 +249,15 @@ def test_safe_gate_chunk_varlen( q = torch.randn((1, T, H, D), dtype=dtype) k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) - v = torch.randn((1, T, H, D), dtype=dtype) - g = F.logsigmoid(torch.randn(1, T, H, D, dtype=torch.float)) + v = torch.randn((1, T, HV, D), dtype=dtype) + g = F.logsigmoid(torch.randn(1, T, HV, D, dtype=torch.float)) mask = torch.rand_like(g) > mask_p g = g * mask + (~mask) * (-1000) if safe_gate: g = g.clamp(-5, 0) - beta = torch.randn(1, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) - h0 = torch.randn((N, H, D, D), dtype=torch.float32) + beta = torch.randn(1, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype) + h0 = torch.randn((N, HV, D, D), dtype=torch.float32) q, k, v, g, beta, h0 = map(lambda x: x.to(device).requires_grad_(), (q, k, v, g, beta, h0)) do = torch.randn_like(v)