diff --git a/benchmarks/bench_chunk_delta_h_bwd_sm90.py b/benchmarks/bench_chunk_delta_h_bwd_sm90.py
new file mode 100644
index 0000000..ccd4ae9
--- /dev/null
+++ b/benchmarks/bench_chunk_delta_h_bwd_sm90.py
@@ -0,0 +1,691 @@
+#!/usr/bin/env python3
+# Copyright 2025-2026 Ant Group Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+bench_chunk_delta_h_bwd_sm90.py - Benchmark: SM90 CuTe DSL bwd_dhu kernel
+                                  vs FLA Triton baseline.
+
+This mirrors benchmarks/bench_chunk_delta_h.py as closely as the backward API
+allows:
+  - non-varlen and varlen modes
+  - K=128, V=128, BT=64, dtype=bf16
+  - same default B/T/H and varlen sequence-count ranges as fwd
+  - dht/dh0 map to fwd initial_state/output_final_state
+
+Usage:
+  python benchmarks/bench_chunk_delta_h_bwd_sm90.py --mode both
+  python benchmarks/bench_chunk_delta_h_bwd_sm90.py --preset focused --mode non-varlen
+"""
+
+import argparse
+import math
+import os
+import pathlib
+import sys
+
+# Environment defaults are set before the fla/cula imports below (presumably read at import time).
+os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
+os.environ.setdefault("FLA_USE_FAST_OPS", os.getenv("CULA_USE_FAST_MATH", "1"))
+
+# Make the repository root importable when the script is run from a checkout.
+sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent))
+
+import numpy as np
+import torch
+from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu as fla_bwd_dhu
+from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets
+
+import cula.ops.chunk_delta_h_bwd as bwd_mod
+
+chunk_gated_delta_rule_bwd_dhu_sm90 = bwd_mod.chunk_gated_delta_rule_bwd_dhu_sm90
+
+if hasattr(sys.stdout, "reconfigure"):
+    sys.stdout.reconfigure(line_buffering=True)
+
+
+K, V, BT = 128, 128, 64
+dtype = torch.bfloat16
+device = "cuda"
+
+WARMUP = 10
+N_ITERS = 100
+NCU_MODE = False
+
+
+def time_kernel(fn, warmup=None, n_iters=None):
+    """Return the mean elapsed time of ``fn`` in milliseconds, measured between two CUDA events.
+
+    Args:
+        fn: Zero-argument callable that issues the work to time.
+        warmup: Untimed calls made first; defaults to ``WARMUP`` (1 in NCU mode).
+        n_iters: Timed calls; defaults to ``N_ITERS`` (1 in NCU mode).
+
+    Raises:
+        ValueError: If ``n_iters`` is smaller than 1.
+    """
+    if warmup is None:
+        warmup = 1 if NCU_MODE else WARMUP
+    if n_iters is None:
+        n_iters = 1 if NCU_MODE else N_ITERS
+    if n_iters < 1:
+        # The mean below divides by n_iters; fail with a clear message instead.
+        raise ValueError(f"n_iters must be >= 1, got {n_iters}")
+    for _ in range(warmup):
+        fn()
+    torch.cuda.synchronize()
+    start_evt = torch.cuda.Event(enable_timing=True)
+    end_evt = torch.cuda.Event(enable_timing=True)
+    start_evt.record()
+    for _ in range(n_iters):
+        fn()
+    end_evt.record()
+    torch.cuda.synchronize()
+    return start_evt.elapsed_time(end_evt) / n_iters
+
+
+def accuracy_stats(ref, out):
+    """Compare ``out`` against ``ref`` in fp32.
+
+    Returns:
+        ``(max_abs, mean_abs, rel_linf)``; ``rel_linf`` is ``max_abs`` divided by
+        the largest reference magnitude, floored at 1e-6.
+    """
+    diff = (ref.float() - out.float()).abs()
+    max_abs = diff.max().item()
+    rel_linf = max_abs / max(ref.float().abs().max().item(), 1e-6)
+    return max_abs, diff.mean().item(), rel_linf
+
+
+def bwd_accuracy_stats(ref_result, cute_result):
+    """Summarize errors between two ``(dh, dh0, dv2)`` result tuples.
+
+    ``dh0`` statistics are reported as 0.0 when the reference ``dh0`` is None.
+    """
+    ref_dh, ref_dh0, ref_dv2 = ref_result
+    got_dh, got_dh0, got_dv2 = cute_result
+    dh_max, dh_mean, dh_rel = accuracy_stats(ref_dh, got_dh)
+    dv2_max, dv2_mean, dv2_rel = accuracy_stats(ref_dv2, got_dv2)
+    dh0_max, dh0_mean, dh0_rel = 0.0, 0.0, 0.0
+    if ref_dh0 is not None:
+        dh0_max, dh0_mean, dh0_rel = accuracy_stats(ref_dh0, got_dh0)
+    return {
+        "dh_max": dh_max,
+        "dh_mean": dh_mean,
+        "dh_rel": dh_rel,
+        "dh0_max": dh0_max,
+        "dh0_mean": dh0_mean,
+        "dh0_rel": dh0_rel,
+        "dv2_max": dv2_max,
+        "dv2_mean": dv2_mean,
+        "dv2_rel": dv2_rel,
+        "max_diff": max(dh_max, dh0_max, dv2_max),
+        "mean_diff": max(dh_mean, dh0_mean, dv2_mean),
+        "max_rel": max(dh_rel, dh0_rel, dv2_rel),
+    }
+
+
+def make_non_varlen_inputs(B, T, H, use_g, use_gk, use_dht, use_dh0, transpose_state=False, seed=42):
+    """Build seeded random inputs for the fixed-length case.
+
+    Returns:
+        ``(q, k, w, do, dv, g, gk, dht, dh0)``; ``g``, ``gk``, ``dht`` and ``dh0``
+        are None when the matching ``use_*`` flag is False.
+    """
+    torch.manual_seed(seed)
+    torch.cuda.empty_cache()
+
+    q = torch.randn(B, T, H, K, device=device, dtype=dtype) * 0.1
+    k = torch.randn(B, T, H, K, device=device, dtype=dtype) * 0.1
+    w = torch.randn(B, T, H, K, device=device, dtype=dtype) * 0.1
+    do = torch.randn(B, T, H, V, device=device, dtype=dtype) * 0.1
+    dv = torch.randn(B, T, H, V, device=device, dtype=dtype) * 0.1
+
+    g = None
+    if use_g:
+        g = -torch.abs(torch.randn(B, T, H, device=device, dtype=torch.float32) * 0.01).cumsum(dim=1)
+
+    gk = None
+    if use_gk:
+        gk = -torch.abs(torch.randn(B, T, H, K, device=device, dtype=torch.float32) * 0.01).cumsum(dim=1)
+
+    state_shape = (B, H, V, K) if transpose_state else (B, H, K, V)
+    dht = torch.randn(state_shape, device=device, dtype=torch.float32) * 0.01 if use_dht else None
+    dh0 = torch.empty(state_shape, device=device, dtype=torch.float32) if use_dh0 else None
+    return q, k, w, do, dv, g, gk, dht, dh0
+
+
+def generate_seq_lens(num_seqs, total_T, ratio, seed=42):
+    """Draw ``num_seqs`` sequence lengths (each >= 1) that sum to ``total_T``.
+
+    Lengths are proportional to weights sampled log-uniformly from ``[1, ratio]``,
+    then rounded and adjusted so the total is exact.
+
+    Raises:
+        ValueError: If ``num_seqs < 1`` or ``total_T < num_seqs``.
+    """
+    if num_seqs < 1 or total_T < num_seqs:
+        raise ValueError(f"need 1 <= num_seqs <= total_T, got num_seqs={num_seqs}, total_T={total_T}")
+    rng = np.random.RandomState(seed)
+    log_weights = rng.uniform(0, np.log(ratio), num_seqs)
+    weights = np.exp(log_weights)
+    raw_lens = weights / weights.sum() * total_T
+    seq_lens = np.maximum(np.round(raw_lens).astype(int), 1)
+    diff = int(total_T - seq_lens.sum())
+    if diff > 0:
+        # Hand the missing tokens to the shortest sequences first.
+        indices = np.argsort(seq_lens)
+        for i in range(diff):
+            seq_lens[indices[i % num_seqs]] += 1
+    elif diff < 0:
+        # Take the surplus from the longest sequences first, skipping any that
+        # are already one token long: the clamp to 1 above can leave more
+        # surplus than there are long sequences, and a blind decrement would
+        # produce empty sequences.
+        indices = np.argsort(-seq_lens)
+        remaining, i = -diff, 0
+        while remaining > 0:
+            idx = indices[i % num_seqs]
+            if seq_lens[idx] > 1:
+                seq_lens[idx] -= 1
+                remaining -= 1
+            i += 1
+    assert seq_lens.sum() == total_T
+    return seq_lens.tolist()
+
+
+def _segmented_neg_abs_cumsum(raw, bounds):
+    """Return ``-|raw|`` cumulatively summed along dim 1, restarting at every ``[bos, eos)`` in ``bounds``."""
+    out = torch.zeros_like(raw)
+    for bos, eos in bounds:
+        out[:, bos:eos] = -torch.abs(raw[:, bos:eos]).cumsum(dim=1)
+    return out
+
+
+def make_varlen_inputs(num_seqs, total_T, H, ratio, use_g, use_gk, use_dht, use_dh0, seed=42):
+    """Build seeded random inputs for the packed variable-length case.
+
+    Returns:
+        ``(seq_lens, cu_seqlens, cu_seqlens_long, chunk_indices, chunk_offsets,
+        q, k, w, do, dv, g, gk, dht, dh0)``; optional tensors are None when the
+        matching ``use_*`` flag is False.
+    """
+    seq_lens = generate_seq_lens(num_seqs, total_T, ratio, seed=seed)
+    cu_seqlens_list = [0]
+    for seq_len in seq_lens:
+        cu_seqlens_list.append(cu_seqlens_list[-1] + seq_len)
+    cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device)
+    cu_seqlens_long = cu_seqlens.long()
+    # Host-side [bos, eos) pairs; slicing with these avoids one device-to-host
+    # sync (``.item()``) per sequence boundary.
+    bounds = list(zip(cu_seqlens_list[:-1], cu_seqlens_list[1:]))
+
+    chunk_indices = prepare_chunk_indices(cu_seqlens_long, BT)
+    chunk_offsets = prepare_chunk_offsets(cu_seqlens_long, BT).int()
+
+    torch.manual_seed(seed)
+    torch.cuda.empty_cache()
+
+    q = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1
+    k = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1
+    w = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1
+    do = torch.randn(1, total_T, H, V, device=device, dtype=dtype) * 0.1
+    dv = torch.randn(1, total_T, H, V, device=device, dtype=dtype) * 0.1
+
+    g = None
+    if use_g:
+        g_raw = torch.randn(1, total_T, H, device=device, dtype=torch.float32) * 0.01
+        g = _segmented_neg_abs_cumsum(g_raw, bounds)
+
+    gk = None
+    if use_gk:
+        gk_raw = torch.randn(1, total_T, H, K, device=device, dtype=torch.float32) * 0.01
+        gk = _segmented_neg_abs_cumsum(gk_raw, bounds)
+
+    state_shape = (num_seqs, H, K, V)
+    dht = torch.randn(state_shape, device=device, dtype=torch.float32) * 0.01 if use_dht else None
+    dh0 = torch.empty(state_shape, device=device, dtype=torch.float32) if use_dh0 else None
+    return seq_lens, cu_seqlens, cu_seqlens_long, chunk_indices, chunk_offsets, q, k, w, do, dv, g, gk, dht, dh0
+
+
+def run_fla(q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens_long=None):
+    """Call the FLA Triton baseline ``chunk_gated_delta_rule_bwd_dhu`` with the benchmark's fixed scale and chunk size."""
+    return fla_bwd_dhu(
+        q=q,
+        k=k,
+        w=w,
+        do=do,
+        dv=dv,
+        g=g,
+        gk=gk,
+        h0=dh0,
+        dht=dht,
+        scale=K**-0.5,
+        cu_seqlens=cu_seqlens_long,
+        chunk_size=BT,
+        use_exp2=True,
+    )
+
+
+def run_cute(q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens=None, chunk_indices=None, chunk_offsets=None):
+    """Call the SM90 CuTe DSL kernel wrapper with the same scale and chunk size as ``run_fla``."""
+    return chunk_gated_delta_rule_bwd_dhu_sm90(
+        q=q,
+        k=k,
+        w=w,
+        do=do,
+        dv=dv,
+        g=g,
+        gk=gk,
+        h0=dh0,
+        dht=dht,
+        scale=K**-0.5,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
+        chunk_size=BT,
+        use_exp2=True,
+    )
+
+
+def flags_str(use_g, use_gk, use_dht, use_dh0):
+    """Format the enabled feature flags as `` [g,gk,...]``, or an empty string if none are set."""
+    flags = []
+    if use_g:
+        flags.append("g")
+    if use_gk:
+        flags.append("gk")
+    if use_dht:
+        flags.append("dht")
+    if use_dh0:
+        flags.append("dh0")
+    return f" [{','.join(flags)}]" if flags else ""
+
+
+# Named (use_g, use_gk, use_dht, use_dh0) combinations selectable via --feature-mode.
+FOCUSED_FEATURE_MODES = {
+    "A": (False, False, False, False),
+    "B": (False, False, True, True),
+    "C": (True, False, False, False),
+    "D": (False, True, False, False),
+    "E": (True, True, False, False),
+}
+
+
+def focused_feature_label(use_g, use_gk, use_dht, use_dh0):
+    """Return the ``FOCUSED_FEATURE_MODES`` name matching the flags, or ``"-"`` if there is none."""
+    for name, flags in FOCUSED_FEATURE_MODES.items():
+        if flags == (use_g, use_gk, use_dht, use_dh0):
+            return name
+    return "-"
+
+
+def build_focused_non_varlen_configs(feature_mode="all"):
+    """Build the focused non-varlen matrix: B in (1, 2, 4) x H in (16, 32) x T in (2048..16384) x feature modes.
+
+    Returns:
+        List of ``(B, T, H, use_g, use_gk, use_dht, use_dh0)`` tuples.
+    """
+    modes = FOCUSED_FEATURE_MODES.items()
+    if feature_mode != "all":
+        modes = [(feature_mode, FOCUSED_FEATURE_MODES[feature_mode])]
+
+    configs = []
+    for B in (1, 2, 4):
+        for H in (16, 32):
+            for T in (2048, 4096, 8192, 16384):
+                for _, flags in modes:
+                    use_g, use_gk, use_dht, use_dh0 = flags
+                    configs.append((B, T, H, use_g, use_gk, use_dht, use_dh0))
+    return configs
+
+
+def filter_non_varlen_configs(configs, only_b=None, only_h=None, only_t=None):
+    """Keep only configs whose B / H / T equal the given values (None disables that filter)."""
+    if only_b is not None:
+        configs = [cfg for cfg in configs if cfg[0] == only_b]
+    if only_t is not None:
+        configs = [cfg for cfg in configs if cfg[1] == only_t]
+    if only_h is not None:
+        configs = [cfg for cfg in configs if cfg[2] == only_h]
+    return configs
+
+
+def _compile_cache_misses():
+    """Return the miss counter of the ``_compile_bwd_dhu_sm90`` cache (used as a proxy for "a new kernel was compiled")."""
+    return bwd_mod._compile_bwd_dhu_sm90.cache_info().misses
+
+
+def bench_non_varlen(configs):
+    """Check accuracy and time both implementations for each ``(B, T, H, use_g, use_gk, use_dht, use_dh0)`` config.
+
+    Returns:
+        One result dict per config (sizes, flags, timings, speedup, accuracy stats).
+    """
+    print("\n" + "=" * 80)
+    print(" Non-Varlen Benchmark: CuTe DSL (SM90) bwd_dhu vs FLA Triton")
+    print("=" * 80)
+    results = []
+
+    for B, T, H, use_g, use_gk, use_dht, use_dh0 in configs:
+        q, k, w, do, dv, g, gk, dht, dh0 = make_non_varlen_inputs(B, T, H, use_g, use_gk, use_dht, use_dh0)
+
+        ref = run_fla(q, k, w, do, dv, g, gk, dht, dh0)
+        misses_before = _compile_cache_misses()
+        got = run_cute(q, k, w, do, dv, g, gk, dht, dh0)
+        compiled_new = _compile_cache_misses() > misses_before
+        torch.cuda.synchronize()
+        acc = bwd_accuracy_stats(ref, got)
+
+        def run_fla_case(q=q, k=k, w=w, do=do, dv=dv, g=g, gk=gk, dht=dht, dh0=dh0):
+            run_fla(q, k, w, do, dv, g, gk, dht, dh0)
+
+        def run_cute_case(q=q, k=k, w=w, do=do, dv=dv, g=g, gk=gk, dht=dht, dh0=dh0):
+            run_cute(q, k, w, do, dv, g, gk, dht, dh0)
+
+        ms_fla = time_kernel(run_fla_case)
+        ms_cute = time_kernel(run_cute_case)
+        speedup = ms_fla / ms_cute if ms_cute > 0 else float("inf")
+        flag_str = flags_str(use_g, use_gk, use_dht, use_dh0)
+        feature_mode = focused_feature_label(use_g, use_gk, use_dht, use_dh0)
+
+        r = {
+            "B": B,
+            "T": T,
+            "H": H,
+            "feature_mode": feature_mode,
+            "flags": flag_str,
+            "compiled_new": compiled_new,
+            "ms_fla": ms_fla,
+            "ms_cute": ms_cute,
+            "speedup": speedup,
+            **acc,
+        }
+        results.append(r)
+        print(
+            f" B={B:2d} T={T:5d} H={H:3d} mode={feature_mode}{flag_str:<18s} | "
+            f"abs(dh={acc['dh_max']:.6f} dh0={acc['dh0_max']:.6f} dv2={acc['dv2_max']:.6f}) "
+            f"rel(dh={acc['dh_rel']:.3e} dh0={acc['dh0_rel']:.3e} dv2={acc['dv2_rel']:.3e}) | "
+            f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms speedup={speedup:.2f}x | "
+            f"compiled={'yes' if compiled_new else 'no'}"
+        )
+
+    return results
+
+
+def bench_varlen(configs):
+    """Check accuracy and time both implementations for each packed varlen config.
+
+    Each config is ``(num_seqs, total_T, H, ratio, use_g, use_gk, use_dht, use_dh0)``.
+
+    Returns:
+        One result dict per config (tag, sizes, flags, timings, speedup, accuracy stats).
+    """
+    print("\n" + "=" * 80)
+    print(" Varlen Benchmark: CuTe DSL (SM90) bwd_dhu vs FLA Triton")
+    print("=" * 80)
+    results = []
+
+    for num_seqs, total_T, H, ratio, use_g, use_gk, use_dht, use_dh0 in configs:
+        (
+            seq_lens,
+            cu_seqlens,
+            cu_seqlens_long,
+            chunk_indices,
+            chunk_offsets,
+            q,
+            k,
+            w,
+            do,
+            dv,
+            g,
+            gk,
+            dht,
+            dh0,
+        ) = make_varlen_inputs(num_seqs, total_T, H, ratio, use_g, use_gk, use_dht, use_dh0)
+
+        ref = run_fla(q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens_long=cu_seqlens_long)
+        # Record whether this call compiled a new variant, as bench_non_varlen does.
+        misses_before = _compile_cache_misses()
+        got = run_cute(
+            q,
+            k,
+            w,
+            do,
+            dv,
+            g,
+            gk,
+            dht,
+            dh0,
+            cu_seqlens=cu_seqlens,
+            chunk_indices=chunk_indices,
+            chunk_offsets=chunk_offsets,
+        )
+        compiled_new = _compile_cache_misses() > misses_before
+        torch.cuda.synchronize()
+        acc = bwd_accuracy_stats(ref, got)
+
+        def run_fla_case(q=q, k=k, w=w, do=do, dv=dv, g=g, gk=gk, dht=dht, dh0=dh0, cu=cu_seqlens_long):
+            run_fla(q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens_long=cu)
+
+        def run_cute_case(
+            q=q,
+            k=k,
+            w=w,
+            do=do,
+            dv=dv,
+            g=g,
+            gk=gk,
+            dht=dht,
+            dh0=dh0,
+            cu=cu_seqlens,
+            ci=chunk_indices,
+            co=chunk_offsets,
+        ):
+            run_cute(q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens=cu, chunk_indices=ci, chunk_offsets=co)
+
+        ms_fla = time_kernel(run_fla_case)
+        ms_cute = time_kernel(run_cute_case)
+        speedup = ms_fla / ms_cute if ms_cute > 0 else float("inf")
+
+        min_l, max_l = min(seq_lens), max(seq_lens)
+        avg_l = total_T // num_seqs
+        tag = f"{num_seqs}seqs T={total_T} [{min_l}..{max_l}] avg={avg_l}"
+        flag_str = flags_str(use_g, use_gk, use_dht, use_dh0)
+
+        r = {
+            "tag": tag,
+            "T_total": total_T,
+            "H": H,
+            "n_seqs": num_seqs,
+            "feature_mode": focused_feature_label(use_g, use_gk, use_dht, use_dh0),
+            "flags": flag_str,
+            "compiled_new": compiled_new,
+            "ms_fla": ms_fla,
+            "ms_cute": ms_cute,
+            "speedup": speedup,
+            **acc,
+        }
+        results.append(r)
+        print(
+            f" {tag:40s} H={H:3d}{flag_str:<18s} | "
+            f"max={acc['max_diff']:.6f} mean={acc['mean_diff']:.8f} "
+            f"(dh={acc['dh_max']:.6f} dh0={acc['dh0_max']:.6f} dv2={acc['dv2_max']:.6f}) | "
+            f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | speedup={speedup:.2f}x"
+        )
+
+    return results
+
+
+def _geometric_mean(values):
+    """Return the geometric mean of the positive, finite entries of ``values`` (NaN if there are none)."""
+    logs = [math.log(v) for v in values if v > 0 and math.isfinite(v)]
+    return math.exp(sum(logs) / len(logs)) if logs else float("nan")
+
+
+def print_report(nv_results, vl_results):
+    """Print the summary tables for the non-varlen and varlen result lists (empty lists are skipped)."""
+    sep = "=" * 120
+    print(f"\n\n{sep}")
+    print(" BENCHMARK REPORT: chunk_delta_rule_bwd_dhu")
+    print(" CuTe DSL (Hopper SM90) vs FLA Triton")
+    print(f" K={K} V={V} BT={BT} dtype=bf16")
+    wu = 1 if NCU_MODE else WARMUP
+    ni = 1 if NCU_MODE else N_ITERS
+    ncu_tag = " [NCU mode]" if NCU_MODE else ""
+    print(f" Warmup={wu} Iters={ni}{ncu_tag}")
+    print(sep)
+
+    if nv_results:
+        print("\n [Non-Varlen]")
+        print(f" {'-' * 132}")
+        print(
+            f" {'Config':<45s} | {'max_abs':>10s} {'max_rel':>10s} | "
+            f"{'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s} {'Compiled':>8s}"
+        )
+        print(f" {'-' * 132}")
+        for r in nv_results:
+            label = f"B={r['B']:2d} T={r['T']:5d} H={r['H']:3d} mode={r.get('feature_mode', '-')}{r['flags']}"
+            print(
+                f" {label:<45s} | {r['max_diff']:10.6f} {r['max_rel']:10.3e} | "
+                f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x "
+                f"{'yes' if r.get('compiled_new') else 'no':>8s}"
+            )
+        print(f" {'-' * 132}")
+        # Zero or infinite speedups (a 0 ms timing) are left out of the mean
+        # because math.log cannot take them.
+        geo = _geometric_mean([r["speedup"] for r in nv_results])
+        print(f" {'Geometric mean':<45s} | {'':>10s} {'':>10s} | {'':>9s} {'':>9s} {geo:7.2f}x {'':>8s}")
+
+    if vl_results:
+        print("\n [Varlen]")
+        print(f" {'-' * 120}")
+        print(f" {'Config':>60s} | {'max_diff':>10s} {'mean_diff':>12s} | {'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s}")
+        print(f" {'-' * 120}")
+        for r in vl_results:
+            label = f"{r['tag']} H={r['H']:3d}{r['flags']}"
+            print(
+                f" {label:>60s} | {r['max_diff']:10.6f} {r['mean_diff']:12.8f} | "
+                f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x"
+            )
+        print(f" {'-' * 120}")
+        geo = _geometric_mean([r["speedup"] for r in vl_results])
+        print(f" {'Geometric mean':>60s} | {'':>10s} {'':>12s} | {'':>9s} {'':>9s} {geo:7.2f}x")
+
+    print(f"\n{sep}\n")
+
+
+def main():
+    """Parse the command line, pick the config preset, run the selected benchmarks, and print the report."""
+    parser = argparse.ArgumentParser(description="bench_chunk_delta_h_bwd_sm90: CuTe DSL (SM90) vs FLA Triton")
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="both",
+        choices=["non-varlen", "varlen", "both"],
+        help="Which benchmark mode to run (default: both)",
+    )
+    parser.add_argument(
+        "--preset",
+        type=str,
+        default="fwd",
+        choices=["representative", "fwd", "focused"],
+        help="fwd mirrors bench_chunk_delta_h.py; representative runs a short subset; focused runs the long non-varlen matrix",
+    )
+    parser.add_argument(
+        "--feature-mode",
+        type=str,
+        default="all",
+        choices=["all", "A", "B", "C", "D", "E"],
+        help="For --preset focused: A=no gates/state, B=dht+dh0, C=g, D=gk, E=g+gk",
+    )
+    parser.add_argument("--max-configs", type=int, default=None, help="Run only the first N configs from the selected preset")
+    parser.add_argument("--filter-b", type=int, default=None, help="Only run non-varlen configs with this B")
+    parser.add_argument("--filter-h", type=int, default=None, help="Only run non-varlen configs with this H")
+    parser.add_argument("--filter-t", type=int, default=None, help="Only run non-varlen configs with this T")
+    parser.add_argument("--warmup", type=int, default=None, help="Override warmup iterations")
+    parser.add_argument("--iters", type=int, default=None, help="Override timed iterations")
+    parser.add_argument("--ncu", action="store_true", help="NCU profiling mode: warmup=1, iters=1")
+    args = parser.parse_args()
+
+    # Reject values that would otherwise fail later (division by zero in the
+    # timing mean) or silently misbehave (negative slice / range bounds).
+    if args.iters is not None and args.iters < 1:
+        parser.error("--iters must be >= 1")
+    if args.warmup is not None and args.warmup < 0:
+        parser.error("--warmup must be >= 0")
+    if args.max_configs is not None and args.max_configs < 0:
+        parser.error("--max-configs must be >= 0")
+
+    if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] != 9:
+        raise RuntimeError("This benchmark requires an SM90/Hopper GPU.")
+
+    global NCU_MODE, WARMUP, N_ITERS
+    if args.ncu:
+        NCU_MODE = True
+        print("[NCU mode] warmup=1, iters=1")
+    if args.warmup is not None:
+        WARMUP = args.warmup
+    if args.iters is not None:
+        N_ITERS = args.iters
+
+    if args.preset == "focused":
+        # Focused non-varlen long-token matrix requested for SM90 bwd_dhu tuning.
+        # Tuple: (B, T, H, use_g, use_gk, use_dht, use_dh0)
+        non_varlen_configs = build_focused_non_varlen_configs(args.feature_mode)
+        varlen_configs = []
+    elif args.preset == "fwd":
+        # Matches bench_chunk_delta_h.py's default dimensions.
+        # Tuple: (B, T, H, use_g, use_gk, use_dht, use_dh0)
+        non_varlen_configs = [
+            (1, 8192, 64, False, True, True, True),
+            (2, 8192, 64, False, True, True, True),
+            (4, 8192, 64, False, True, True, True),
+            (8, 8192, 64, False, True, True, True),
+        ]
+
+        # Tuple: (num_seqs, total_T, H, ratio, use_g, use_gk, use_dht, use_dh0)
+        varlen_configs = [
+            (20, 8192, 64, 2.0, False, True, True, True),
+            (25, 8192, 64, 3.0, False, True, True, True),
+            (20, 8192, 64, 4.0, False, True, True, True),
+            (20, 32768, 64, 2.0, False, True, True, True),
+            (25, 32768, 64, 3.0, False, True, True, True),
+        ]
+    else:
+        # Short representative subset for day-to-day iteration.
+        # Tuple: (B, T, H, use_g, use_gk, use_dht, use_dh0)
+        non_varlen_configs = [
+            (1, 512, 4, False, True, True, False),
+            (1, 512, 4, True, False, True, False),
+            (2, 1024, 64, False, True, True, True),
+            (1, 2048, 64, False, True, True, False),
+        ]
+
+        # Tuple: (num_seqs, total_T, H, ratio, use_g, use_gk, use_dht, use_dh0)
+        varlen_configs = [
+            (3, 512, 2, 3.0, False, True, True, False),
+            (4, 768, 2, 4.0, True, False, True, True),
+        ]
+
+    non_varlen_configs = filter_non_varlen_configs(
+        non_varlen_configs,
+        only_b=args.filter_b,
+        only_h=args.filter_h,
+        only_t=args.filter_t,
+    )
+
+    if args.max_configs is not None:
+        non_varlen_configs = non_varlen_configs[: args.max_configs]
+        varlen_configs = varlen_configs[: args.max_configs]
+
+    nv_res, vl_res = [], []
+    if args.mode in ("non-varlen", "both"):
+        nv_res = bench_non_varlen(non_varlen_configs)
+    if args.mode in ("varlen", "both"):
+        vl_res = bench_varlen(varlen_configs)
+    print_report(nv_res, vl_res)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py
new file mode 100644
index 0000000..a6cef26
--- /dev/null
+++ b/cula/ops/chunk_delta_h_bwd.py
@@ -0,0 +1,1465 @@
+# Copyright 2025-2026 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +""" +Chunk Gated Delta Rule Backward DHU Kernel (SM90 WGMMA) + +Hopper tensor-core path aligned with cula/ops/chunk_delta_h.py: +- fixed chunk size BT=64 +- K=V=128, BV=64 +- non-varlen tensors [B, T, H, D] and packed varlen tensors +- state layout [B, NT, H, K, V] or [B, NT, H, V, K] + +The recurrence follows FLA's bwd_dhu: + dv2 = dv + K @ dh + dh = decay(dh) + scale * Q^T @ do - W^T @ dv2 + +Each CTA owns one BV tile and one (batch, head). WGMMA computes the three +64x64 GEMMs per chunk while CUDA threads carry dh in registers. +""" + +from __future__ import annotations + +import functools + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.warpgroup as warpgroup +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.hopper_helpers as sm90_utils +import torch +from cutlass.cute.nvgpu import cpasync +from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream +from cutlass.cute.typing import Float32, Int32, Int64 +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets + +from cula.utils import USE_FAST_MATH, assert_hopper + +BT = 64 +BV = 64 +BK = 128 +NUM_THREADS = 224 + + +def make_thread_cooperative_group(size: int): + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) + + +class ChunkDeltaRuleBwdDHUSm90: + def __init__( + self, + num_heads: int, + head_dim_k: int, + head_dim_v: int, + is_varlen: bool, + use_g: bool, + use_gk: bool, + use_dht: bool, + use_dh0: bool, + use_exp2: bool, + transpose_state_layout: bool, + scale: float, + use_fast_math: bool = True, + ): + assert head_dim_k == 128 and head_dim_v == 128, ( + f"SM90 bwd_dhu currently aligns with ChunkDeltaRuleFwdH and requires K=V=128, got K={head_dim_k}, V={head_dim_v}" + ) + self.H = num_heads + self.K = head_dim_k + self.V = head_dim_v + self.is_varlen = is_varlen + self.use_g = use_g + self.use_gk = use_gk + self.use_dht = use_dht + self.use_dh0 = use_dh0 + 
self.use_exp2 = use_exp2 + self.transpose_state_layout = transpose_state_layout + self.scale = scale + self.use_fast_math = use_fast_math + + self.BT = BT + self.BV = BV + self.BK = BK + self.threads_per_warp = 32 + self.num_compute_warps = 4 + self.num_compute_threads = self.threads_per_warp * self.num_compute_warps + self.load_warp_id = 4 + self.load_current_warp_id = 5 + self.store_warp_id = 6 + self.num_threads = NUM_THREADS + self.num_regs_compute = 232 + self.num_regs_other = 40 + self.k_stage = 3 + self.dv_stage = 2 + self.do_stage = 2 + self.q_stage = 3 + self.w_stage = 3 + self.gk_stage = 3 + self.dh_store_stage = 2 + self.dv2_store_stage = 2 + self.io_dtype = cutlass.BFloat16 + self.acc_dtype = cutlass.Float32 + self.buffer_align_bytes = 128 + + # K=BK=128, so the carried dh state is a single BV x BK register tile. + self.kdh_mma_tiler = (self.BV, self.BT, self.BK) + self.update_mma_tiler = (self.BV, self.BK, self.BT) + self.atom_layout_mnk = (1, 1, 1) + self.cluster_shape_mnk = (1, 1, 1) + self.gk_precompute_bar = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.num_compute_threads, + ) + + @cute.jit + def __call__( + self, + q_in: cute.Tensor, + k_in: cute.Tensor, + w_in: cute.Tensor, + g_in: cute.Tensor, + gk_in: cute.Tensor, + dht_in: cute.Tensor, + dh0_in: cute.Tensor, + do_in: cute.Tensor, + dh_in: cute.Tensor, + dv_in: cute.Tensor, + dv2_in: cute.Tensor, + cu_seqlens_in: cute.Tensor, + chunk_offsets_in: cute.Tensor, + problem_size: tuple[Int32, Int32, Int32, Int32], + stream: cuda.CUstream, + ): + q_ptr = q_in.iterator + k_ptr = k_in.iterator + w_ptr = w_in.iterator + g_ptr = g_in.iterator + gk_ptr = gk_in.iterator + dht_ptr = dht_in.iterator + dh0_ptr = dh0_in.iterator + do_ptr = do_in.iterator + dh_ptr = dh_in.iterator + dv_ptr = dv_in.iterator + dv2_ptr = dv2_in.iterator + cu_seqlens_ptr = cu_seqlens_in.iterator + chunk_offsets_ptr = chunk_offsets_in.iterator + + B, T, N, NT_total = problem_size + + # ===================== GMEM layouts 
===================== + g_layout = cute.make_layout( + (B, T, self.H), + stride=(T * self.H, self.H, 1), + ) + g = cute.make_tensor(g_ptr, g_layout) + + cu_seqlens = cute.make_tensor(cu_seqlens_ptr, cute.make_layout((N + 1,))) + chunk_offsets = cute.make_tensor(chunk_offsets_ptr, cute.make_layout((N + 1,))) + + # dh TMA store view: (V, K) tile with layout selected by the requested state layout. + if cutlass.const_expr(self.transpose_state_layout): + dh_tma_layout = cute.make_layout( + (self.V, self.K, (NT_total, self.H, B)), + stride=(self.K, 1, (self.H * self.K * self.V, self.K * self.V, NT_total * self.H * self.K * self.V)), + ) + else: + dh_tma_layout = cute.make_layout( + (self.V, self.K, (NT_total, self.H, B)), + stride=(1, self.V, (self.H * self.K * self.V, self.K * self.V, NT_total * self.H * self.K * self.V)), + ) + dh_tma_tile = (self.BV, self.BK) + dh_smem_layout_enum = ( + utils.LayoutEnum.ROW_MAJOR if cutlass.const_expr(self.transpose_state_layout) else utils.LayoutEnum.COL_MAJOR + ) + dh_tma = cute.make_tensor(dh_ptr, dh_tma_layout) + + if cutlass.const_expr(self.transpose_state_layout): + final_layout = cute.make_layout( + (N, self.H, self.V, self.K), + stride=(self.H * self.K * self.V, self.K * self.V, self.K, 1), + ) + else: + final_layout = cute.make_layout( + (N, self.H, self.K, self.V), + stride=(self.H * self.K * self.V, self.K * self.V, self.V, 1), + ) + dht = cute.make_tensor(dht_ptr, final_layout) + dh0 = cute.make_tensor(dh0_ptr, final_layout) + if cutlass.const_expr(self.transpose_state_layout): + dht_tma_layout = cute.make_layout( + (self.V, self.K, (self.H, N)), + stride=(self.K, 1, (self.K * self.V, self.H * self.K * self.V)), + ) + else: + dht_tma_layout = cute.make_layout( + (self.V, self.K, (self.H, N)), + stride=(1, self.V, (self.K * self.V, self.H * self.K * self.V)), + ) + dht_tma = cute.make_tensor(dht_ptr, dht_tma_layout) + dh0_tma = cute.make_tensor(dh0_ptr, dht_tma_layout) + + # TMA operand views. 
Varlen shifts the T dimension with domain_offset below. + tk_layout = cute.make_layout((T, self.K, (self.H, B)), stride=(self.H * self.K, 1, (self.K, T * self.H * self.K))) + k_tk = cute.make_tensor(k_ptr, tk_layout) + + kt_layout = cute.make_layout((self.K, T, (self.H, B)), stride=(1, self.H * self.K, (self.K, T * self.H * self.K))) + q_kt = cute.make_tensor(q_ptr, kt_layout) + w_kt = cute.make_tensor(w_ptr, kt_layout) + gk_kt = cute.make_tensor(gk_ptr, kt_layout) + + vt_layout = cute.make_layout((self.V, T, (self.H, B)), stride=(1, self.H * self.V, (self.V, T * self.H * self.V))) + do_vt = cute.make_tensor(do_ptr, vt_layout) + dv_vt = cute.make_tensor(dv_ptr, vt_layout) + dv2_vt = cute.make_tensor(dv2_ptr, vt_layout) + dv2_layout = cute.make_layout( + (B, T, self.H, self.V), + stride=(T * self.H * self.V, self.H * self.V, self.V, 1), + ) + dv2 = cute.make_tensor(dv2_ptr, dv2_layout) + + # ===================== MMA setup ===================== + tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.io_dtype, + self.io_dtype, + utils.LayoutEnum.ROW_MAJOR.sm90_mma_major_mode(), + utils.LayoutEnum.ROW_MAJOR.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + self.kdh_mma_tiler[:2], + warpgroup.OperandSource.RMEM, + ) + + update_tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.io_dtype, + self.io_dtype, + utils.LayoutEnum.COL_MAJOR.sm90_mma_major_mode(), + utils.LayoutEnum.COL_MAJOR.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + self.update_mma_tiler[:2], + ) + qdo_tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.io_dtype, + self.io_dtype, + utils.LayoutEnum.COL_MAJOR.sm90_mma_major_mode(), + utils.LayoutEnum.COL_MAJOR.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + self.update_mma_tiler[:2], + warpgroup.OperandSource.RMEM, + ) + + # ===================== SMEM layouts ===================== + k_smem_layout_staged = sm90_utils.make_smem_layout_b( + utils.LayoutEnum.ROW_MAJOR, + self.kdh_mma_tiler, + 
self.io_dtype, + self.k_stage, + ) + dv_smem_layout_staged = sm90_utils.make_smem_layout_a( + utils.LayoutEnum.COL_MAJOR, + self.update_mma_tiler, + self.io_dtype, + self.dv_stage, + ) + do_smem_layout_staged = sm90_utils.make_smem_layout_a( + utils.LayoutEnum.COL_MAJOR, + self.update_mma_tiler, + self.io_dtype, + self.do_stage, + ) + q_smem_layout_staged = sm90_utils.make_smem_layout_b( + utils.LayoutEnum.COL_MAJOR, + self.update_mma_tiler, + self.io_dtype, + self.q_stage, + ) + w_smem_layout_staged = sm90_utils.make_smem_layout_b( + utils.LayoutEnum.COL_MAJOR, + self.update_mma_tiler, + self.io_dtype, + self.w_stage, + ) + dv2_smem_layout_staged = sm90_utils.make_smem_layout_a( + utils.LayoutEnum.COL_MAJOR, + self.update_mma_tiler, + self.io_dtype, + self.dv2_store_stage, + ) + gk_smem_layout_staged = cute.make_layout( + (self.BK, 1, self.gk_stage), + stride=(1, self.BK, self.BK), + ) + dht_smem_layout = ( + cute.make_layout((self.BV, self.BK), stride=(self.BK, 1)) + if cutlass.const_expr(self.transpose_state_layout) + else cute.make_layout((self.BV, self.BK), stride=(1, self.BV)) + ) + dh_smem_layout_staged = sm90_utils.make_smem_layout_epi( + self.io_dtype, + dh_smem_layout_enum, + dh_tma_tile, + self.dh_store_stage, + ) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp() + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + + # ===================== TMA descriptors ===================== + tma_atom_k, tma_tensor_k = cpasync.make_tiled_tma_atom( + tma_load_op, + k_tk, + cute.slice_(k_smem_layout_staged, (None, None, 0)), + (self.BT, self.BK), + ) + tma_atom_dv, tma_tensor_dv = cpasync.make_tiled_tma_atom( + tma_load_op, + dv_vt, + cute.slice_(dv_smem_layout_staged, (None, None, 0)), + (self.BV, self.BT), + ) + tma_atom_do, tma_tensor_do = cpasync.make_tiled_tma_atom( + tma_load_op, + do_vt, + cute.slice_(do_smem_layout_staged, (None, None, 0)), + (self.BV, self.BT), + ) + tma_atom_q, tma_tensor_q = cpasync.make_tiled_tma_atom( + tma_load_op, + q_kt, + 
cute.slice_(q_smem_layout_staged, (None, None, 0)), + (self.BK, self.BT), + ) + tma_atom_w, tma_tensor_w = cpasync.make_tiled_tma_atom( + tma_load_op, + w_kt, + cute.slice_(w_smem_layout_staged, (None, None, 0)), + (self.BK, self.BT), + ) + tma_atom_gk, tma_tensor_gk = cpasync.make_tiled_tma_atom( + tma_load_op, + gk_kt, + cute.slice_(gk_smem_layout_staged, (None, None, 0)), + (self.BK, 1), + ) + tma_atom_dht, tma_tensor_dht = cpasync.make_tiled_tma_atom( + tma_load_op, + dht_tma, + dht_smem_layout, + (self.BV, self.BK), + ) + tma_atom_dh0, tma_tensor_dh0 = cpasync.make_tiled_tma_atom( + tma_store_op, + dh0_tma, + dht_smem_layout, + (self.BV, self.BK), + ) + tma_atom_dh, tma_tensor_dh = cpasync.make_tiled_tma_atom( + tma_store_op, + dh_tma, + cute.slice_(dh_smem_layout_staged, (None, None, 0)), + dh_tma_tile, + ) + tma_atom_dv2, tma_tensor_dv2 = cpasync.make_tiled_tma_atom( + tma_store_op, + dv2_vt, + cute.slice_(dv2_smem_layout_staged, (None, None, 0)), + (self.BV, self.BT), + ) + self.tma_k_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(k_smem_layout_staged, (None, None, 0))) + self.tma_dv_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(dv_smem_layout_staged, (None, None, 0))) + self.tma_do_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(do_smem_layout_staged, (None, None, 0))) + self.tma_q_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(q_smem_layout_staged, (None, None, 0))) + self.tma_w_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(w_smem_layout_staged, (None, None, 0))) + self.tma_gk_bytes = cute.size_in_bytes(cutlass.Float32, cute.slice_(gk_smem_layout_staged, (None, None, 0))) + self.tma_dht_bytes = cute.size_in_bytes(cutlass.Float32, dht_smem_layout) + + # ===================== SharedStorage ===================== + @cute.struct + class SharedStorage: + load_k_mbar: cute.struct.MemRange[Int64, self.k_stage * 2] + load_dv_mbar: cute.struct.MemRange[Int64, self.dv_stage * 2] + load_do_mbar: cute.struct.MemRange[Int64, 
self.do_stage * 2] + load_q_mbar: cute.struct.MemRange[Int64, self.q_stage * 2] + load_w_mbar: cute.struct.MemRange[Int64, self.w_stage * 2] + load_gk_mbar: cute.struct.MemRange[Int64, self.gk_stage * 2] + load_dht_mbar: cute.struct.MemRange[Int64, 2] + store_dh_mbar: cute.struct.MemRange[Int64, self.dh_store_stage * 2] + store_dh0_mbar: cute.struct.MemRange[Int64, 2] + store_dv2_mbar: cute.struct.MemRange[Int64, self.dv2_store_stage * 2] + sK: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(k_smem_layout_staged)], + self.buffer_align_bytes, + ] + sDv: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(dv_smem_layout_staged)], + self.buffer_align_bytes, + ] + sDo: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(do_smem_layout_staged)], + self.buffer_align_bytes, + ] + sGK: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.BK * self.gk_stage], + 128, + ] + sG: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.BT * 2], + 128, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(q_smem_layout_staged)], + self.buffer_align_bytes, + ] + sW: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(w_smem_layout_staged)], + self.buffer_align_bytes, + ] + sDv2: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(dv2_smem_layout_staged)], + self.buffer_align_bytes, + ] + sDh: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(dh_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + self.kernel( + g, + dht, + dh0, + dv2, + cu_seqlens, + chunk_offsets, + problem_size, + tiled_mma, + update_tiled_mma, + qdo_tiled_mma, + k_smem_layout_staged, + dv_smem_layout_staged, + do_smem_layout_staged, + dv2_smem_layout_staged, + q_smem_layout_staged, + w_smem_layout_staged, + tma_atom_k, + tma_tensor_k, + tma_atom_dv, + tma_tensor_dv, + tma_atom_do, + tma_tensor_do, + tma_atom_q, + 
tma_tensor_q, + tma_atom_w, + tma_tensor_w, + tma_atom_gk, + tma_tensor_gk, + tma_atom_dht, + tma_tensor_dht, + tma_atom_dh0, + tma_tensor_dh0, + dht_smem_layout, + tma_atom_dh, + tma_tensor_dh, + dh_smem_layout_staged, + tma_atom_dv2, + tma_tensor_dv2, + ).launch( + grid=[cute.ceil_div(self.V, self.BV), N * self.H, 1], + block=[self.num_threads, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + g: cute.Tensor, + dht: cute.Tensor, + dh0: cute.Tensor, + dv2: cute.Tensor, + cu_seqlens: cute.Tensor, + chunk_offsets: cute.Tensor, + problem_size: tuple[Int32, Int32, Int32, Int32], + tiled_mma: cute.TiledMma, + update_tiled_mma: cute.TiledMma, + qdo_tiled_mma: cute.TiledMma, + k_smem_layout_staged: cute.ComposedLayout, + dv_smem_layout_staged: cute.ComposedLayout, + do_smem_layout_staged: cute.ComposedLayout, + dv2_smem_layout_staged: cute.ComposedLayout, + q_smem_layout_staged: cute.ComposedLayout, + w_smem_layout_staged: cute.ComposedLayout, + tma_atom_k: cute.CopyAtom, + tma_tensor_k: cute.Tensor, + tma_atom_dv: cute.CopyAtom, + tma_tensor_dv: cute.Tensor, + tma_atom_do: cute.CopyAtom, + tma_tensor_do: cute.Tensor, + tma_atom_q: cute.CopyAtom, + tma_tensor_q: cute.Tensor, + tma_atom_w: cute.CopyAtom, + tma_tensor_w: cute.Tensor, + tma_atom_gk: cute.CopyAtom, + tma_tensor_gk: cute.Tensor, + tma_atom_dht: cute.CopyAtom, + tma_tensor_dht: cute.Tensor, + tma_atom_dh0: cute.CopyAtom, + tma_tensor_dh0: cute.Tensor, + dht_smem_layout: cute.Layout, + tma_atom_dh: cute.CopyAtom, + tma_tensor_dh: cute.Tensor, + dh_smem_layout_staged: cute.ComposedLayout, + tma_atom_dv2: cute.CopyAtom, + tma_tensor_dv2: cute.Tensor, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + B, T, N, NT_total = problem_size + + # ===================== Block indices ===================== + v_tile_idx, bh_idx, _ = cute.arch.block_idx() + bidx = bh_idx // self.H + hidx = 
bh_idx - bidx * self.H + data_bidx = bidx + tok_offset = Int32(0) + seq_len = T + NT = (T + self.BT - 1) // self.BT + chunk_off = Int32(0) + if cutlass.const_expr(self.is_varlen): + data_bidx = Int32(0) + tok_offset = cu_seqlens[bidx] + seq_len = cu_seqlens[bidx + 1] - tok_offset + NT = (seq_len + self.BT - 1) // self.BT + chunk_off = chunk_offsets[bidx] + v_tile_base = v_tile_idx * self.BV + + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # ===================== SMEM views ===================== + sK = storage.sK.get_tensor(k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner) + sDv = storage.sDv.get_tensor(dv_smem_layout_staged.outer, swizzle=dv_smem_layout_staged.inner) + sDo = storage.sDo.get_tensor(do_smem_layout_staged.outer, swizzle=do_smem_layout_staged.inner) + sGK = storage.sGK.get_tensor(cute.make_layout((self.BK, 1, self.gk_stage), stride=(1, self.BK, self.BK))) + sG = storage.sG.get_tensor(cute.make_layout((self.BT, 2), stride=(1, self.BT))) + sQ = storage.sQ.get_tensor(q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner) + sW = storage.sW.get_tensor(w_smem_layout_staged.outer, swizzle=w_smem_layout_staged.inner) + sDv2 = storage.sDv2.get_tensor(dv2_smem_layout_staged.outer, swizzle=dv2_smem_layout_staged.inner) + sDh = storage.sDh.get_tensor(dh_smem_layout_staged.outer, swizzle=dh_smem_layout_staged.inner) + sDht = cute.make_tensor(cute.recast_ptr(storage.sDh.data_ptr(), dtype=cutlass.Float32), dht_smem_layout) + sDh0 = cute.make_tensor(cute.recast_ptr(storage.sK.data_ptr(), dtype=cutlass.Float32), dht_smem_layout) + + # ===================== Pipelines ===================== + load_k_P, load_k_C = pipeline.PipelineTmaAsync.create( + num_stages=self.k_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), + tx_count=self.tma_k_bytes, + barrier_storage=storage.load_k_mbar.data_ptr(), + ).make_participants() + load_dv_P, 
load_dv_C = pipeline.PipelineTmaAsync.create( + num_stages=self.dv_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), + tx_count=self.tma_dv_bytes, + barrier_storage=storage.load_dv_mbar.data_ptr(), + ).make_participants() + load_do_P, load_do_C = pipeline.PipelineTmaAsync.create( + num_stages=self.do_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), + tx_count=self.tma_do_bytes, + barrier_storage=storage.load_do_mbar.data_ptr(), + ).make_participants() + load_q_P, load_q_C = pipeline.PipelineTmaAsync.create( + num_stages=self.q_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), + tx_count=self.tma_q_bytes, + barrier_storage=storage.load_q_mbar.data_ptr(), + ).make_participants() + load_w_P, load_w_C = pipeline.PipelineTmaAsync.create( + num_stages=self.w_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), + tx_count=self.tma_w_bytes, + barrier_storage=storage.load_w_mbar.data_ptr(), + ).make_participants() + if cutlass.const_expr(self.use_gk): + load_gk_P, load_gk_C = pipeline.PipelineTmaAsync.create( + num_stages=self.gk_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), + tx_count=self.tma_gk_bytes, + barrier_storage=storage.load_gk_mbar.data_ptr(), + ).make_participants() + if cutlass.const_expr(self.use_dht): + load_dht_P, load_dht_C = pipeline.PipelineTmaAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), + tx_count=self.tma_dht_bytes, + barrier_storage=storage.load_dht_mbar.data_ptr(), + ).make_participants() + if cutlass.const_expr(self.use_dh0): + store_dh0_P, 
store_dh0_C = pipeline.PipelineAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group(self.num_compute_threads), + consumer_group=make_thread_cooperative_group(self.threads_per_warp), + barrier_storage=storage.store_dh0_mbar.data_ptr(), + ).make_participants() + store_dh_P, store_dh_C = pipeline.PipelineAsync.create( + num_stages=self.dh_store_stage, + producer_group=make_thread_cooperative_group(self.num_compute_threads), + consumer_group=make_thread_cooperative_group(self.threads_per_warp), + barrier_storage=storage.store_dh_mbar.data_ptr(), + ).make_participants() + store_dv2_P, store_dv2_C = pipeline.PipelineAsync.create( + num_stages=self.dv2_store_stage, + producer_group=make_thread_cooperative_group(self.num_compute_threads), + consumer_group=make_thread_cooperative_group(self.threads_per_warp), + barrier_storage=storage.store_dv2_mbar.data_ptr(), + ).make_participants() + + # ===================== TMA partitions ===================== + # Varlen shifts token-indexed tensors by tok_offset; dh uses chunk_off + # because state storage is compact across sequences. 
+ if cutlass.const_expr(self.is_varlen): + tma_tensor_k_use = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_k) + tma_tensor_dv_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_dv) + tma_tensor_do_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_do) + tma_tensor_q_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_q) + tma_tensor_w_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_w) + tma_tensor_dh_use = cute.domain_offset((0, 0, (chunk_off, 0, 0)), tma_tensor_dh) + tma_tensor_dv2_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_dv2) + if cutlass.const_expr(self.use_gk): + tma_tensor_gk_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_gk) + tma_tensor_dht_use = tma_tensor_dht + else: + tma_tensor_k_use = tma_tensor_k + tma_tensor_dv_use = tma_tensor_dv + tma_tensor_do_use = tma_tensor_do + tma_tensor_q_use = tma_tensor_q + tma_tensor_w_use = tma_tensor_w + tma_tensor_dh_use = tma_tensor_dh + tma_tensor_dv2_use = tma_tensor_dv2 + if cutlass.const_expr(self.use_gk): + tma_tensor_gk_use = tma_tensor_gk + tma_tensor_dht_use = tma_tensor_dht + + _, bSG_sK, bSG_gK = self._epilog_partition( + tma_atom_k, tma_tensor_k_use[None, None, (hidx, data_bidx)], (self.BT, self.BK), sK + ) + _, bSG_sDv, bSG_gDv = self._epilog_partition( + tma_atom_dv, tma_tensor_dv_use[None, None, (hidx, data_bidx)], (self.BV, self.BT), sDv + ) + _, bSG_sDo, bSG_gDo = self._epilog_partition( + tma_atom_do, tma_tensor_do_use[None, None, (hidx, data_bidx)], (self.BV, self.BT), sDo + ) + _, bSG_sQ, bSG_gQ = self._epilog_partition( + tma_atom_q, tma_tensor_q_use[None, None, (hidx, data_bidx)], (self.BK, self.BT), sQ + ) + _, bSG_sW, bSG_gW = self._epilog_partition( + tma_atom_w, tma_tensor_w_use[None, None, (hidx, data_bidx)], (self.BK, self.BT), sW + ) + if cutlass.const_expr(self.use_gk): + _, bSG_sGK, bSG_gGK = self._epilog_partition( + tma_atom_gk, tma_tensor_gk_use[None, None, (hidx, data_bidx)], (self.BK, 1), sGK + ) + 
if cutlass.const_expr(self.use_dht): + _, bSG_sDht, bSG_gDht = self._epilog_partition( + tma_atom_dht, tma_tensor_dht_use[None, None, (hidx, bidx)], (self.BV, self.BK), sDht + ) + if cutlass.const_expr(self.use_dh0): + _, bSG_sDh0, bSG_gDh0 = self._epilog_partition( + tma_atom_dh0, tma_tensor_dh0[None, None, (hidx, bidx)], (self.BV, self.BK), sDh0 + ) + _, bSG_sDh, bSG_gDh = self._epilog_partition( + tma_atom_dh, tma_tensor_dh_use[None, None, (None, hidx, data_bidx)], (self.BV, self.BK), sDh + ) + _, bSG_sDv2, bSG_gDv2 = self._epilog_partition( + tma_atom_dv2, tma_tensor_dv2_use[None, None, (hidx, data_bidx)], (self.BV, self.BT), sDv2 + ) + + is_compute_warp = warp_idx < self.num_compute_warps + local_tidx = tidx % self.num_compute_threads + if is_compute_warp: + cute.arch.setmaxregister_increase(self.num_regs_compute) + else: + cute.arch.setmaxregister_decrease(self.num_regs_other) + + # ===================== MMA fragments ===================== + thr_mma = tiled_mma.get_slice(local_tidx) + update_thr_mma = update_tiled_mma.get_slice(local_tidx) + + tKsB = thr_mma.partition_B(sK) + tKrB = thr_mma.make_fragment_B(tKsB) + tUsA = update_thr_mma.partition_A(sDv) + tUsB = update_thr_mma.partition_B(sQ) + tWsB = update_thr_mma.partition_B(sW) + tUrA = update_thr_mma.make_fragment_A(tUsA) + tDv2sA = update_thr_mma.partition_A(sDv2) + tDv2rA = update_thr_mma.make_fragment_A(tDv2sA) + tUrB = update_thr_mma.make_fragment_B(tUsB) + tWrB = update_thr_mma.make_fragment_B(tWsB) + if cutlass.const_expr(self.use_g): + qdo_thr_mma = qdo_tiled_mma.get_slice(local_tidx) + qdo_tUsB = qdo_thr_mma.partition_B(sQ) + qdo_tUrB = qdo_thr_mma.make_fragment_B(qdo_tUsB) + else: + tUsDo = update_thr_mma.partition_A(sDo) + tUrDo = update_thr_mma.make_fragment_A(tUsDo) + + cDV = cute.make_identity_tensor((self.BV, self.BT)) + tCcDV = thr_mma.partition_C(cDV) + acc_dv = thr_mma.make_fragment_C(thr_mma.partition_shape_C((self.BV, self.BT))) + + cState = cute.make_identity_tensor((self.BV, self.BK)) 
+ tUcState = update_thr_mma.partition_C(cState) + state_shape = update_thr_mma.partition_shape_C((self.BV, self.BK)) + rState = update_thr_mma.make_fragment_C(state_shape) + acc_qdo = update_thr_mma.make_fragment_C(state_shape) + acc_wdv = update_thr_mma.make_fragment_C(state_shape) + dh_smem_layout_enum = ( + utils.LayoutEnum.ROW_MAJOR if cutlass.const_expr(self.transpose_state_layout) else utils.LayoutEnum.COL_MAJOR + ) + dh_copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + dh_smem_layout_enum, + elem_ty_d=self.io_dtype, + elem_ty_acc=self.acc_dtype, + ) + dh_copy_atom = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp( + dh_smem_layout_enum.is_m_major_c(), + 4, + ), + self.io_dtype, + ) + tiled_copy_dh_atom = cute.make_tiled_copy_C_atom(dh_copy_atom, update_tiled_mma) + tiled_copy_dh_r2s = cute.make_tiled_copy_S(dh_copy_atom_r2s, tiled_copy_dh_atom) + thr_copy_dh_r2s = tiled_copy_dh_r2s.get_slice(local_tidx) + tRS_sDh = thr_copy_dh_r2s.partition_D(sDh) + rDh_shape = cute.shape(thr_copy_dh_r2s.partition_S(sDh)) + tRS_rDh_layout = cute.make_layout(rDh_shape[:3]) + + # ========================================================================= + # WARP SPECIALIZATION + # load_warp_id : preloads K, dv, and optional gk for the next reverse chunk + # load_current_warp_id : loads do, q, and w for the current reverse chunk + # compute warps : carry dh in registers and run WGMMA + # store_warp_id : stores dh and dv2 after compute warps publish SMEM tiles + # ========================================================================= + # ===== Reverse chunk loop ===== + # Pipeline: preload(prev chunk) -> Phase 1 (publish dh + K@dh) + # -> Phase 2 (dv2) -> Phase 3 (QDO + decay) -> Phase 4 (WDV + dh update). 
+ if warp_idx == self.load_warp_id: + cpasync.prefetch_descriptor(tma_atom_k) + cpasync.prefetch_descriptor(tma_atom_dv) + if cutlass.const_expr(self.use_gk): + cpasync.prefetch_descriptor(tma_atom_gk) + if cutlass.const_expr(self.use_dht): + cpasync.prefetch_descriptor(tma_atom_dht) + dht_h = load_dht_P.acquire_and_advance() + cute.copy( + tma_atom_dht, + bSG_gDht[(None, v_tile_idx, 0)], + bSG_sDht[None], + tma_bar_ptr=dht_h.barrier, + ) + + if NT > 0: + first_chunk = NT - 1 + k_h = load_k_P.acquire_and_advance() + cute.copy(tma_atom_k, bSG_gK[(None, first_chunk, 0)], bSG_sK[None, k_h.index], tma_bar_ptr=k_h.barrier) + dv_h = load_dv_P.acquire_and_advance() + cute.copy( + tma_atom_dv, + bSG_gDv[(None, v_tile_idx, first_chunk)], + bSG_sDv[None, dv_h.index], + tma_bar_ptr=dv_h.barrier, + ) + if cutlass.const_expr(self.use_gk): + gk_h = load_gk_P.acquire_and_advance() + cute.copy( + tma_atom_gk, + bSG_gGK[(None, 0, seq_len - 1)], + bSG_sGK[None, gk_h.index], + tma_bar_ptr=gk_h.barrier, + ) + + for chunk_rev in cutlass.range(0, NT, unroll=0): + chunk_idx = NT - 1 - chunk_rev + next_chunk_idx = chunk_idx - 1 + if next_chunk_idx >= 0: + k_h = load_k_P.acquire_and_advance() + cute.copy(tma_atom_k, bSG_gK[(None, next_chunk_idx, 0)], bSG_sK[None, k_h.index], tma_bar_ptr=k_h.barrier) + dv_h = load_dv_P.acquire_and_advance() + cute.copy( + tma_atom_dv, + bSG_gDv[(None, v_tile_idx, next_chunk_idx)], + bSG_sDv[None, dv_h.index], + tma_bar_ptr=dv_h.barrier, + ) + if cutlass.const_expr(self.use_gk): + next_gk_idx = cutlass.min(next_chunk_idx * self.BT + self.BT, seq_len) - 1 + gk_h = load_gk_P.acquire_and_advance() + cute.copy( + tma_atom_gk, + bSG_gGK[(None, 0, next_gk_idx)], + bSG_sGK[None, gk_h.index], + tma_bar_ptr=gk_h.barrier, + ) + + elif warp_idx == self.load_current_warp_id: + cpasync.prefetch_descriptor(tma_atom_do) + cpasync.prefetch_descriptor(tma_atom_q) + cpasync.prefetch_descriptor(tma_atom_w) + + for chunk_rev in cutlass.range(0, NT, unroll=0): + chunk_idx = NT - 
1 - chunk_rev + do_h = load_do_P.acquire_and_advance() + cute.copy( + tma_atom_do, bSG_gDo[(None, v_tile_idx, chunk_idx)], bSG_sDo[None, do_h.index], tma_bar_ptr=do_h.barrier + ) + q_h = load_q_P.acquire_and_advance() + cute.copy(tma_atom_q, bSG_gQ[(None, 0, chunk_idx)], bSG_sQ[None, q_h.index], tma_bar_ptr=q_h.barrier) + w_h = load_w_P.acquire_and_advance() + cute.copy(tma_atom_w, bSG_gW[(None, 0, chunk_idx)], bSG_sW[None, w_h.index], tma_bar_ptr=w_h.barrier) + + elif is_compute_warp: + # Initialize carried dh state in registers. dht is loaded by the + # load warp into the sDh backing buffer before sDh is used for + # per-chunk output stores. + if cutlass.const_expr(self.use_dht): + dht_h = load_dht_C.wait_and_advance() + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + init = Float32(0.0) + if cutlass.const_expr(self.use_dht): + init = sDht[v_rel, k_rel].to(self.acc_dtype) + rState[ei] = init + if cutlass.const_expr(self.use_dht): + dht_h.release() + + for chunk_rev in cutlass.range(0, NT, unroll=0): + chunk_idx = NT - 1 - chunk_rev + chunk_start = chunk_idx * self.BT + chunk_end = cutlass.min(chunk_start + self.BT, seq_len) + remaining = chunk_end - chunk_start + last_idx = chunk_end - 1 + g_last = Float32(0.0) + g_last_exp = Float32(1.0) + if cutlass.const_expr(self.use_g): + g_last = g[data_bidx, tok_offset + last_idx, hidx].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + g_last_exp = cute.exp2(g_last, fastmath=self.use_fast_math) + else: + g_last_exp = cute.exp(g_last, fastmath=self.use_fast_math) + + # ======================================== + # Phase 1: Publish dh + start K @ dh + # ======================================== + # Publish carried dh to the store pipeline before the GEMM chain, + # matching chunk_delta_h.py's h_out overlap pattern. 
+ rState_bf16 = cute.make_rmem_tensor(rState.shape, self.io_dtype) + rState_bf16.store(rState.load().to(self.io_dtype)) + dh_h = store_dh_P.acquire_and_advance() + tRS_rState = tiled_copy_dh_r2s.retile(rState_bf16) + tRS_rDh_out = cute.make_rmem_tensor_like(tRS_rDh_layout, self.io_dtype) + tRS_rDh_out.store(tRS_rState.load()) + cute.copy( + tiled_copy_dh_r2s, + tRS_rDh_out, + tRS_sDh[(None, None, None, dh_h.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + dh_h.commit() + + # dv2 = dv + K @ dh. Compute the equivalent dh @ K^T tile so + # the register-carried state can feed WGMMA as an RMEM A operand. + acc_dv.fill(0.0) + k_wait = load_k_C.wait_and_advance() + rState_op = self.make_acc_into_op(rState_bf16, tiled_mma.tv_layout_A, self.io_dtype) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tKrB, mode=[2]), unroll_full=True): + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + tiled_mma, + acc_dv, + rState_op[None, None, kp], + tKrB[None, None, kp, k_wait.index], + acc_dv, + ) + cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(self.use_g): + if local_tidx < self.BT: + t_idx = chunk_start + local_tidx + g_decay = Float32(0.0) + g_exp = Float32(0.0) + if t_idx < seq_len: + g_cur = g[data_bidx, tok_offset + t_idx, hidx].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + g_decay = cute.exp2(g_last - g_cur, fastmath=self.use_fast_math) + g_exp = cute.exp2(g_cur, fastmath=self.use_fast_math) + else: + g_decay = cute.exp(g_last - g_cur, fastmath=self.use_fast_math) + g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) + sG[local_tidx, 0] = g_decay + sG[local_tidx, 1] = g_exp + if cutlass.const_expr(not self.use_g): + # Phase 3 is independent of K@dh, so overlap QDO and optional gk decay + # with the first GEMM in the no-scalar-g fast path. For + # varlen tails, zero padded do positions before QDO so TMA + # overfetch into the next sequence cannot contribute. 
+ do_wait_early = load_do_C.wait_and_advance() + q_wait_early = load_q_C.wait_and_advance() + if cutlass.const_expr(self.use_gk): + gk_wait_early = load_gk_C.wait_and_advance() + if cutlass.const_expr(self.is_varlen): + if remaining < self.BT: + linear_do = local_tidx + while linear_do < self.BV * self.BT: + v_rel = linear_do // self.BT + t_rel = linear_do - v_rel * self.BT + t_idx = chunk_start + t_rel + do_scaled = Float32(0.0) + if t_idx < seq_len: + do_scaled = sDo[v_rel, t_rel, do_wait_early.index].to(self.acc_dtype) + sDo[v_rel, t_rel, do_wait_early.index] = do_scaled.to(self.io_dtype) + linear_do += self.num_compute_threads + cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) + acc_qdo.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_qdo, + tUrDo[None, None, kp, do_wait_early.index], + tUrB[None, None, kp, q_wait_early.index], + acc_qdo, + ) + cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(self.use_gk): + gk_last = sGK[local_tidx, 0, gk_wait_early.index].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) + else: + k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) + sGK[local_tidx, 0, gk_wait_early.index] = k_decay + self.gk_precompute_bar.arrive_and_wait() + cute.nvgpu.warpgroup.wait_group(1) + else: + cute.nvgpu.warpgroup.wait_group(0) + k_wait.release() + + # ======================================== + # Phase 2: dv2 = dv + K @ dh + # ======================================== + dv_wait = load_dv_C.wait_and_advance() + dv_stage = dv_wait.index + dv2_store_h = store_dv2_P.acquire_and_advance() + dv2_stage = dv2_store_h.index + if cutlass.const_expr(self.use_g): + cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) + for ei in 
cutlass.range(cute.size(acc_dv), unroll_full=True): + v_rel, t_rel = tCcDV[ei] + t_idx = chunk_start + t_rel + out = Float32(0.0) + if t_idx < seq_len: + out = acc_dv[ei] + if cutlass.const_expr(self.use_g): + out = out * sG[t_rel, 0] + out = out + sDv[v_rel, t_rel, dv_stage].to(self.acc_dtype) + out_bf16 = out.to(self.io_dtype) + sDv2[v_rel, t_rel, dv2_stage] = out_bf16 + if remaining < self.BT and t_idx < seq_len: + dv2[data_bidx, tok_offset + chunk_start + t_rel, hidx, v_tile_base + v_rel] = out_bf16 + cute.arch.fence_proxy("async.shared", space="cta") + dv2_store_h.commit() + dv_wait.release() + + # ======================================== + # Phase 3/4: dh += scale * do^T @ q - dv2^T @ w + # ======================================== + if cutlass.const_expr(not self.use_g): + w_wait = load_w_C.wait_and_advance() + acc_wdv.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_wdv, + tDv2rA[None, None, kp, dv2_stage], + tWrB[None, None, kp, w_wait.index], + acc_wdv, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + q_wait_early.release() + + for ei in cutlass.range(cute.size(rState), unroll_full=True): + update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] + if cutlass.const_expr(self.use_gk): + v_rel, k_rel = tUcState[ei] + rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait_early.index] + update + else: + rState[ei] = rState[ei] + update + w_wait.release() + do_wait_early.release() + if cutlass.const_expr(self.use_gk): + gk_wait_early.release() + else: + do_wait = load_do_C.wait_and_advance() + if cutlass.const_expr(self.use_g): + # Phase 3a: materialize gated do in registers for QDO. 
+ for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): + v_rel, t_rel = tCcDV[ei] + t_idx = chunk_start + t_rel + do_scaled = Float32(0.0) + if t_idx < seq_len: + do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) * sG[t_rel, 1] + acc_dv[ei] = do_scaled + rDo_op = self.make_acc_into_op(acc_dv, qdo_tiled_mma.tv_layout_A, self.io_dtype) + do_wait.release() + if cutlass.const_expr((not self.use_g) and self.is_varlen): + # Phase 3a: zero padded do positions in SMEM for varlen tails. + if remaining < self.BT: + linear_do = local_tidx + while linear_do < self.BV * self.BT: + v_rel = linear_do // self.BT + t_rel = linear_do - v_rel * self.BT + t_idx = chunk_start + t_rel + do_scaled = Float32(0.0) + if t_idx < seq_len: + do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) + sDo[v_rel, t_rel, do_wait.index] = do_scaled.to(self.io_dtype) + linear_do += self.num_compute_threads + cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) + + # Phase 3b: QDO plus scalar/key decay while QDO is in flight. + q_wait = load_q_C.wait_and_advance() + if cutlass.const_expr(self.use_gk): + gk_wait = load_gk_C.wait_and_advance() + acc_qdo.fill(0.0) + cute.nvgpu.warpgroup.fence() + if cutlass.const_expr(self.use_g): + for kp in cutlass.range(cute.size(qdo_tUrB, mode=[2]), unroll_full=True): + qdo_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + qdo_tiled_mma, + acc_qdo, + rDo_op[None, None, kp], + qdo_tUrB[None, None, kp, q_wait.index], + acc_qdo, + ) + else: + for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_qdo, + tUrDo[None, None, kp, do_wait.index], + tUrB[None, None, kp, q_wait.index], + acc_qdo, + ) + cute.nvgpu.warpgroup.commit_group() + + # QDO does not consume rState, so hide g/gk state decay under its WGMMA latency. 
+ if cutlass.const_expr(self.use_g): + for ei in cutlass.range(cute.size(rState), unroll_full=True): + rState[ei] = rState[ei] * g_last_exp + if cutlass.const_expr(self.use_gk): + gk_last = sGK[local_tidx, 0, gk_wait.index].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) + else: + k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) + sGK[local_tidx, 0, gk_wait.index] = k_decay + self.gk_precompute_bar.arrive_and_wait() + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait.index] + + # Phase 4: WDV and final dh update. + w_wait = load_w_C.wait_and_advance() + acc_wdv.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_wdv, + tDv2rA[None, None, kp, dv2_stage], + tWrB[None, None, kp, w_wait.index], + acc_wdv, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + q_wait.release() + if cutlass.const_expr(self.use_gk): + gk_wait.release() + + for ei in cutlass.range(cute.size(rState), unroll_full=True): + update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] + rState[ei] = rState[ei] + update + w_wait.release() + if cutlass.const_expr(not self.use_g): + do_wait.release() + + if cutlass.const_expr(self.use_dh0): + dh0_h = store_dh0_P.acquire_and_advance() + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + sDh0[v_rel, k_rel] = rState[ei] + cute.arch.fence_proxy("async.shared", space="cta") + dh0_h.commit() + + elif warp_idx == self.store_warp_id: + cpasync.prefetch_descriptor(tma_atom_dh) + cpasync.prefetch_descriptor(tma_atom_dv2) + if cutlass.const_expr(self.use_dh0): + cpasync.prefetch_descriptor(tma_atom_dh0) + + for chunk_rev in cutlass.range(0, 
@cute.jit
def _epilog_partition(self, atom, gC_mnl, epi_tile, sC):
    """Partition one shared-memory / global-memory tensor pair for a TMA copy.

    The global tensor is first split into ``epi_tile``-shaped tiles with
    ``cute.flat_divide``. The first two modes of both the shared-memory
    tensor and the tiled global tensor are then grouped into a single mode,
    and the pair is handed to ``cpasync.tma_partition`` together with ``0``
    and a size-1 layout (the CTA coordinate / CTA layout arguments in the
    CuTe DSL API -- NOTE(review): confirm against the installed version).

    Args:
        atom: TMA copy atom the partition is built for.
        gC_mnl: Global-memory tensor view to be tiled.
        epi_tile: Tile shape passed to ``cute.flat_divide``.
        sC: Shared-memory tensor paired with ``gC_mnl``.

    Returns:
        ``(atom, smem_partition, gmem_partition)``; ``atom`` is returned
        unchanged so call sites can unpack all three in one statement.
    """
    tiled_gmem = cute.flat_divide(gC_mnl, epi_tile)
    # Collapse the two leading modes so each tile is addressed as one unit.
    smem_grouped = cute.group_modes(sC, 0, 2)
    gmem_grouped = cute.group_modes(tiled_gmem, 0, 2)
    smem_part, gmem_part = cpasync.tma_partition(
        atom,
        0,
        cute.make_layout(1),
        smem_grouped,
        gmem_grouped,
    )
    return atom, smem_part, gmem_part


@staticmethod
def _convert_c_layout_to_a_layout(c, a):
    """Derive an A-operand register layout from an accumulator layout.

    ``make_acc_into_op`` calls this with ``c = acc.layout`` and
    ``a = operand_layout_tv.shape[1]``. The result keeps ``c``'s first two
    strides and second shape mode, substitutes ``a`` as the first shape mode,
    and appends an extra sub-mode to the third mode whose extent is
    ``size(c, mode=[0]) // size(a)`` and whose stride is
    ``size(a, mode=[2]) * c.stride[0][2]``.

    Args:
        c: Accumulator (C fragment) layout.
        a: Shape used as the leading mode of the returned layout.

    Returns:
        The layout produced by ``cute.make_layout`` from the shape and
        stride described above.
    """
    # How many copies of ``a`` fit in the accumulator's first mode.
    folded_extent = cute.size(c, mode=[0]) // cute.size(a)
    folded_stride = cute.size(a, mode=[2]) * c.stride[0][2]
    shape = (a, c.shape[1], (c.shape[2], folded_extent))
    stride = (c.stride[0], c.stride[1], (c.stride[2], folded_stride))
    return cute.make_layout(shape, stride=stride)
@functools.lru_cache(maxsize=64)
def _compile_bwd_dhu_sm90(
    H: int,
    K: int,
    V: int,
    is_varlen: bool,
    use_g: bool,
    use_gk: bool,
    use_dht: bool,
    use_dh0: bool,
    use_exp2: bool,
    transpose_state_layout: bool,
    scale: float,
):
    """Compile (and memoize) one variant of the SM90 bwd_dhu kernel.

    The eleven arguments form the ``lru_cache`` key, so each distinct
    combination is compiled once and at most 64 variants are retained.
    ``USE_FAST_MATH`` is read from module scope when the kernel object is
    built and is not part of the cache key.

    Batch size, sequence length, sequence count and chunk count are left
    symbolic (``cute.sym_int()``) in the placeholder tensors; the concrete
    values are supplied at call time through the ``problem_size`` tuple, for
    which a dummy ``(1, 1, 1, 1)`` is passed here.

    Returns:
        The object returned by ``cute.compile`` with ``--enable-tvm-ffi``.
    """

    def fake_tensor(elem_type, shape):
        # Every tensor below uses a descending stride_order (n-1, ..., 0),
        # i.e. the same tuples the call sites would otherwise spell out by hand.
        return make_fake_compact_tensor(
            elem_type,
            shape,
            stride_order=tuple(reversed(range(len(shape)))),
            assumed_align=128,
        )

    kernel = ChunkDeltaRuleBwdDHUSm90(
        num_heads=H,
        head_dim_k=K,
        head_dim_v=V,
        is_varlen=is_varlen,
        use_g=use_g,
        use_gk=use_gk,
        use_dht=use_dht,
        use_dh0=use_dh0,
        use_exp2=use_exp2,
        transpose_state_layout=transpose_state_layout,
        scale=scale,
        use_fast_math=USE_FAST_MATH,
    )

    sym_b = cute.sym_int()
    sym_t = cute.sym_int()
    sym_n = cute.sym_int()
    sym_nt = cute.sym_int()
    sym_meta = cute.sym_int()

    key_shape = (sym_b, sym_t, H, K)
    value_shape = (sym_b, sym_t, H, V)

    # Token-indexed bf16 operands; each gets its own placeholder object.
    q_fake, k_fake, w_fake = [fake_tensor(cutlass.BFloat16, key_shape) for _ in range(3)]
    do_fake, dv_fake, dv2_fake = [fake_tensor(cutlass.BFloat16, value_shape) for _ in range(3)]

    # fp32 gates: per-head scalar gate and per-key gate.
    g_fake = fake_tensor(cutlass.Float32, (sym_b, sym_t, H))
    gk_fake = fake_tensor(cutlass.Float32, key_shape)

    # State tensors end in (V, K) when the state layout is transposed,
    # otherwise in (K, V).
    state_dims = (V, K) if transpose_state_layout else (K, V)
    dht_fake = fake_tensor(cutlass.Float32, (sym_n, H, *state_dims))
    dh0_fake = fake_tensor(cutlass.Float32, (sym_n, H, *state_dims))
    dh_fake = fake_tensor(cutlass.BFloat16, (sym_b, sym_nt, H, *state_dims))

    cu_fake = make_fake_compact_tensor(cutlass.Int32, (sym_meta,), assumed_align=128)
    offsets_fake = make_fake_compact_tensor(cutlass.Int32, (sym_meta,), assumed_align=128)
    stream_fake = make_fake_stream(use_tvm_ffi_env_stream=True)

    placeholder_problem_size = (Int32(1), Int32(1), Int32(1), Int32(1))

    compiled = cute.compile(
        kernel,
        q_fake,
        k_fake,
        w_fake,
        g_fake,
        gk_fake,
        dht_fake,
        dh0_fake,
        do_fake,
        dh_fake,
        dv_fake,
        dv2_fake,
        cu_fake,
        offsets_fake,
        placeholder_problem_size,
        stream_fake,
        options="--enable-tvm-ffi",
    )
    return compiled
torch.Tensor]: + """FLA-compatible wrapper for the SM90 WGMMA bwd_dhu path.""" + assert_hopper(q.device) + if chunk_size != BT: + raise NotImplementedError(f"SM90 bwd_dhu only supports chunk_size={BT}.") + + B, T, H, K = q.shape + V = do.shape[-1] + is_varlen = cu_seqlens is not None + if is_varlen and B != 1: + raise ValueError("varlen mode expects packed inputs with shape [1, total_T, H, D].") + if K != 128 or V != 128: + raise NotImplementedError(f"SM90 bwd_dhu currently aligns with fwd and only supports K=V=128, got K={K}, V={V}.") + if q.dtype != torch.bfloat16 or k.dtype != torch.bfloat16 or w.dtype != torch.bfloat16: + raise TypeError("q, k, and w must be bfloat16 for the SM90 bwd_dhu path.") + if do.dtype != torch.bfloat16 or dv.dtype != torch.bfloat16: + raise TypeError("do and dv must be bfloat16 for the SM90 bwd_dhu path.") + if not q.is_contiguous() or not k.is_contiguous() or not w.is_contiguous(): + raise ValueError("q, k, and w must be contiguous.") + if not do.is_contiguous() or not dv.is_contiguous(): + raise ValueError("do and dv must be contiguous.") + if h0 is not None and (h0.dtype != torch.float32 or not h0.is_contiguous()): + raise ValueError("h0 must be contiguous float32.") + if cu_seqlens is not None and (cu_seqlens.device != q.device or not cu_seqlens.is_contiguous()): + raise ValueError("cu_seqlens must be contiguous and on the same CUDA device as q.") + if chunk_indices is not None and (chunk_indices.device != q.device or not chunk_indices.is_contiguous()): + raise ValueError("chunk_indices must be contiguous and on the same CUDA device as q.") + if chunk_offsets is not None and (chunk_offsets.device != q.device or not chunk_offsets.is_contiguous()): + raise ValueError("chunk_offsets must be contiguous and on the same CUDA device as q.") + + if is_varlen: + if chunk_indices is None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + N = len(cu_seqlens) - 1 + NT = len(chunk_indices) + if chunk_offsets is None: + 
chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT).int() + elif chunk_offsets.dtype != torch.int32: + chunk_offsets = chunk_offsets.int() + cu_seqlens_arg = cu_seqlens.int() if cu_seqlens.dtype != torch.int32 else cu_seqlens + else: + N = B + NT = (T + BT - 1) // BT + cu_seqlens_arg = torch.empty(B + 1, device=q.device, dtype=torch.int32) + chunk_offsets = torch.empty(B + 1, device=q.device, dtype=torch.int32) + scale_value = 1.0 if scale is None else float(scale) + + state_shape = (N, H, V, K) if transpose_state_layout else (N, H, K, V) + dh = q.new_empty(B, NT, H, V, K) if transpose_state_layout else q.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dv2 = torch.empty_like(dv) + + g_arg = g if g is not None else torch.empty(B, T, H, device=q.device, dtype=torch.float32) + gk_arg = gk if gk is not None else torch.empty(B, T, H, K, device=q.device, dtype=torch.float32) + dht_arg = dht if dht is not None else torch.empty(state_shape, device=q.device, dtype=torch.float32) + dh0_arg = dh0 if dh0 is not None else torch.empty(state_shape, device=q.device, dtype=torch.float32) + if g is not None and (g.dtype != torch.float32 or not g.is_contiguous()): + raise ValueError("g must be contiguous float32.") + if g is not None and tuple(g.shape) != (B, T, H): + raise ValueError(f"g must have shape {(B, T, H)}, got {tuple(g.shape)}.") + if gk is not None and (gk.dtype != torch.float32 or not gk.is_contiguous()): + raise ValueError("gk must be contiguous float32.") + if gk is not None and tuple(gk.shape) != (B, T, H, K): + raise ValueError(f"gk must have shape {(B, T, H, K)}, got {tuple(gk.shape)}.") + if dht is not None and (dht.dtype != torch.float32 or not dht.is_contiguous()): + raise ValueError("dht must be contiguous float32.") + if dht is not None and tuple(dht.shape) != state_shape: + raise ValueError(f"dht must have shape {state_shape} for this state layout, got {tuple(dht.shape)}.") + if h0 is not None and 
tuple(h0.shape) != state_shape: + raise ValueError(f"h0 must have shape {state_shape} for this state layout, got {tuple(h0.shape)}.") + + compiled = _compile_bwd_dhu_sm90( + H, + K, + V, + is_varlen, + g is not None, + gk is not None, + dht is not None, + h0 is not None, + use_exp2, + transpose_state_layout, + scale_value, + ) + problem_size = (Int32(B), Int32(T), Int32(N), Int32(NT)) + compiled(q, k, w, g_arg, gk_arg, dht_arg, dh0_arg, do, dh, dv, dv2, cu_seqlens_arg, chunk_offsets, problem_size) + return dh, dh0, dv2 + + +chunk_gated_delta_rule_bwd_dhu = chunk_gated_delta_rule_bwd_dhu_sm90 diff --git a/tests/test_chunk_delta_h_bwd_sm90.py b/tests/test_chunk_delta_h_bwd_sm90.py new file mode 100644 index 0000000..7324ef2 --- /dev/null +++ b/tests/test_chunk_delta_h_bwd_sm90.py @@ -0,0 +1,405 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Correctness tests for the SM90 CuTe DSL WGMMA bwd_dhu path. + +These cases follow tests/test_chunk_delta_h.py where the backward API permits. +For bwd_dhu, fwd's initial_state/output_final_state pair maps to dht/dh0. 
"""Correctness tests for the SM90 CuTe DSL WGMMA bwd_dhu path.

These cases follow tests/test_chunk_delta_h.py where the backward API permits.
For bwd_dhu, fwd's initial_state/output_final_state pair maps to dht/dh0.
"""

import os
import sys

import pytest
import torch

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu as fla_bwd_dhu

from cula.ops.chunk_delta_h_bwd import chunk_gated_delta_rule_bwd_dhu_sm90

BT = 64
ATOL = 1e-2
RTOL = 1e-2
device = "cuda"


def _is_sm90() -> bool:
    """Return True when CUDA is available and the device's capability major is 9."""
    if not torch.cuda.is_available():
        return False
    major = torch.cuda.get_device_capability()[0]
    return major == 9


pytestmark = [
    pytest.mark.sm90_only,
    pytest.mark.skipif(not _is_sm90(), reason="SM90/Hopper GPU is required"),
]


def run_fla_ref(
    q,
    k,
    w,
    do,
    dv,
    g=None,
    gk=None,
    dht=None,
    dh0=None,
    cu_seqlens=None,
    use_exp2=True,
    transpose_state_layout=False,
):
    """Call the FLA reference bwd_dhu.

    ``dh0`` is forwarded as the reference's ``h0`` argument, ``cu_seqlens`` is
    converted with ``.long()`` when present, and the scale is ``K ** -0.5``.
    """
    ref_cu_seqlens = None if cu_seqlens is None else cu_seqlens.long()
    scale = q.shape[-1] ** -0.5
    return fla_bwd_dhu(
        q=q,
        k=k,
        w=w,
        do=do,
        dv=dv,
        g=g,
        gk=gk,
        h0=dh0,
        dht=dht,
        scale=scale,
        cu_seqlens=ref_cu_seqlens,
        chunk_size=BT,
        use_exp2=use_exp2,
        transpose_state_layout=transpose_state_layout,
    )


def run_cute_dsl(
    q,
    k,
    w,
    do,
    dv,
    g=None,
    gk=None,
    dht=None,
    dh0=None,
    cu_seqlens=None,
    use_exp2=True,
    transpose_state_layout=False,
):
    """Call the SM90 CuTe DSL bwd_dhu with the same argument mapping as the reference.

    Unlike ``run_fla_ref``, ``cu_seqlens`` is passed through without a dtype change.
    """
    scale = q.shape[-1] ** -0.5
    return chunk_gated_delta_rule_bwd_dhu_sm90(
        q=q,
        k=k,
        w=w,
        do=do,
        dv=dv,
        g=g,
        gk=gk,
        h0=dh0,
        dht=dht,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=BT,
        use_exp2=use_exp2,
        transpose_state_layout=transpose_state_layout,
    )


def _assert_bwd_close(got, ref, expect_dh0, msg):
    """Compare ``(dh, dh0, dv2)`` triples in float32 under ATOL/RTOL.

    ``dh`` and ``dv2`` are always compared. ``dh0`` is compared only when
    ``expect_dh0`` is true; otherwise the tested result must carry ``None``.
    """
    got_dh, got_dh0, got_dv2 = got
    ref_dh, ref_dh0, ref_dv2 = ref

    def _check(actual, expected, label):
        torch.testing.assert_close(actual.float(), expected.float(), atol=ATOL, rtol=RTOL, msg=f"{msg}: {label}")

    _check(got_dh, ref_dh, "dh")
    _check(got_dv2, ref_dv2, "dv2")
    if not expect_dh0:
        assert got_dh0 is None
        return
    assert got_dh0 is not None
    _check(got_dh0, ref_dh0, "dh0")


def _make_inputs(
    B,
    T,
    H,
    K,
    V,
    use_g=False,
    use_gk=False,
    use_state=False,
    seed=42,
    transpose_state_layout=False,
):
    """Build seeded random fixed-length inputs.

    Returns ``(q, k, w, do, dv, g, gk, dht, dh0)``; the optional entries are
    ``None`` unless the matching ``use_*`` flag is set. ``dh0`` is an
    uninitialised buffer of the state shape.
    """
    torch.manual_seed(seed)

    def _bf16(*shape):
        return torch.randn(*shape, dtype=torch.bfloat16, device=device) * 0.1

    def _neg_cumsum(*shape):
        # Non-positive values that decrease monotonically along dim 1.
        noise = torch.randn(*shape, dtype=torch.float32, device=device) * 0.01
        return -torch.abs(noise).cumsum(dim=1)

    # Draw order is kept fixed so a given seed always yields the same tensors.
    q = _bf16(B, T, H, K)
    k = _bf16(B, T, H, K)
    w = _bf16(B, T, H, K)
    do = _bf16(B, T, H, V)
    dv = _bf16(B, T, H, V)
    g = _neg_cumsum(B, T, H) if use_g else None
    gk = _neg_cumsum(B, T, H, K) if use_gk else None

    state_shape = (B, H, V, K) if transpose_state_layout else (B, H, K, V)
    dht, dh0 = None, None
    if use_state:
        dht = torch.randn(state_shape, dtype=torch.float32, device=device) * 0.01
        dh0 = torch.empty(state_shape, dtype=torch.float32, device=device)
    return q, k, w, do, dv, g, gk, dht, dh0


def _make_varlen_inputs(
    seq_lens,
    H,
    K,
    V,
    use_g=False,
    use_gk=False,
    use_state=False,
    seed=42,
    transpose_state_layout=False,
):
    """Build seeded random packed inputs for the given sequence lengths.

    Returns ``(q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens)`` with a leading
    batch dimension of 1. Gate cumulative sums restart at every sequence start.
    """
    T_total = sum(seq_lens)
    num_seqs = len(seq_lens)
    boundaries = [0]
    for length in seq_lens:
        boundaries.append(boundaries[-1] + length)
    spans = list(zip(boundaries[:-1], boundaries[1:]))

    torch.manual_seed(seed)

    def _bf16(last_dim):
        return torch.randn(1, T_total, H, last_dim, dtype=torch.bfloat16, device=device) * 0.1

    q = _bf16(K)
    k = _bf16(K)
    w = _bf16(K)
    do = _bf16(V)
    dv = _bf16(V)

    g = None
    if use_g:
        g = torch.empty(1, T_total, H, dtype=torch.float32, device=device)
        for bos, eos in spans:
            noise = torch.randn(1, eos - bos, H, dtype=torch.float32, device=device) * 0.01
            g[:, bos:eos] = -torch.abs(noise).cumsum(dim=1)

    gk = None
    if use_gk:
        gk = torch.empty(1, T_total, H, K, dtype=torch.float32, device=device)
        for bos, eos in spans:
            noise = torch.randn(1, eos - bos, H, K, dtype=torch.float32, device=device) * 0.01
            gk[:, bos:eos] = -torch.abs(noise).cumsum(dim=1)

    state_shape = (num_seqs, H, V, K) if transpose_state_layout else (num_seqs, H, K, V)
    dht, dh0 = None, None
    if use_state:
        dht = torch.randn(state_shape, dtype=torch.float32, device=device) * 0.01
        dh0 = torch.empty(state_shape, dtype=torch.float32, device=device)
    cu_seqlens = torch.tensor(boundaries, dtype=torch.int32, device=device)
    return q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens


@pytest.mark.parametrize("B", [1, 2])
@pytest.mark.parametrize("H", [1, 4])
@pytest.mark.parametrize("T", [64, 128, 256])
@pytest.mark.parametrize("K", [128])
@pytest.mark.parametrize("V", [128])
@pytest.mark.parametrize("use_gk", [False, True])
@pytest.mark.parametrize("use_state", [False, True])
def test_dhu_against_fla(B, H, T, K, V, use_gk, use_state):
    """Fixed-length sweep over batch, heads, length, gk and state against FLA."""
    q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs(B, T, H, K, V, use_gk=use_gk, use_state=use_state)
    optional = dict(g=g, gk=gk, dht=dht, dh0=dh0)
    expected = run_fla_ref(q, k, w, do, dv, **optional)
    actual = run_cute_dsl(q, k, w, do, dv, **optional)
    _assert_bwd_close(actual, expected, use_state, f"B={B} H={H} T={T} gk={use_gk} state={use_state}")


@pytest.mark.parametrize(
    "B,T,H,K,V",
    [
        (1, 64, 1, 128, 128),
        (2, 128, 4, 128, 128),
        (4, 512, 4, 128, 128),
    ],
)
def test_dv2_no_gating(B, T, H, K, V):
    """Ungated, stateless inputs: results must match FLA and dh0 must be None."""
    q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs(B, T, H, K, V)
    optional = dict(g=g, gk=gk, dht=dht, dh0=dh0)
    expected = run_fla_ref(q, k, w, do, dv, **optional)
    actual = run_cute_dsl(q, k, w, do, dv, **optional)
    _assert_bwd_close(actual, expected, False, f"dv2 no-gating B={B} T={T} H={H}")


@pytest.mark.parametrize(
    "seq_lens",
    [
        [128, 128],
        [50, 192, 100],
        [33, 128, 200, 95],
    ],
)
@pytest.mark.parametrize("H", [1, 4])
@pytest.mark.parametrize("use_gk", [False, True])
@pytest.mark.parametrize("use_state", [False, True])
def test_varlen_against_fla(seq_lens, H, use_gk, use_state):
    """Packed variable-length inputs against FLA."""
    K, V = 128, 128
    packed = _make_varlen_inputs(seq_lens, H, K, V, use_gk=use_gk, use_state=use_state)
    q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens = packed
    optional = dict(g=g, gk=gk, dht=dht, dh0=dh0, cu_seqlens=cu_seqlens)
    expected = run_fla_ref(q, k, w, do, dv, **optional)
    actual = run_cute_dsl(q, k, w, do, dv, **optional)
    _assert_bwd_close(actual, expected, use_state, f"varlen seqs={seq_lens} H={H} gk={use_gk} state={use_state}")


def test_varlen_vs_nonvarlen():
    """A single sequence must give the same results in varlen and fixed-length mode."""
    H, K, V = 2, 128, 128
    T = 256
    q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs(1, T, H, K, V, use_gk=True, use_state=True)
    optional = dict(g=g, gk=gk, dht=dht, dh0=dh0)
    dh_nv, dh0_nv, dv2_nv = run_cute_dsl(q, k, w, do, dv, **optional)

    cu_seqlens = torch.tensor([0, T], dtype=torch.int32, device=device)
    dh_vl, dh0_vl, dv2_vl = run_cute_dsl(q, k, w, do, dv, cu_seqlens=cu_seqlens, **optional)

    for dense, packed in ((dh_nv, dh_vl), (dv2_nv, dv2_vl), (dh0_nv, dh0_vl)):
        torch.testing.assert_close(dense.float(), packed.float(), atol=1e-6, rtol=1e-6)


@pytest.mark.parametrize(
    "use_g,use_gk",
    [
        (True, False),
        (True, True),
    ],
)
def test_scalar_g_features(use_g, use_gk):
    """Per-head gate ``g``, alone and combined with ``gk``, against FLA."""
    q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs(
        B=1,
        T=128,
        H=2,
        K=128,
        V=128,
        use_g=use_g,
        use_gk=use_gk,
        use_state=True,
        seed=123,
    )
    optional = dict(g=g, gk=gk, dht=dht, dh0=dh0)
    expected = run_fla_ref(q, k, w, do, dv, **optional)
    actual = run_cute_dsl(q, k, w, do, dv, **optional)
    _assert_bwd_close(actual, expected, True, f"scalar-g g={use_g} gk={use_gk}")


@pytest.mark.parametrize(
    "T,use_g,use_gk,transpose_state_layout",
    [
        (65, False, False, False),
        (127, True, False, False),
        (129, False, True, True),
        (191, True, True, True),
    ],
    ids=["t65-plain", "t127-g", "t129-gk-trans", "t191-g-gk-trans"],
)
def test_tail_chunk_sizes(T, use_g, use_gk, transpose_state_layout):
    """Lengths that are not a multiple of BT, across gating and layout choices."""
    q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs(
        B=1,
        T=T,
        H=2,
        K=128,
        V=128,
        use_g=use_g,
        use_gk=use_gk,
        use_state=True,
        seed=1000 + T,
        transpose_state_layout=transpose_state_layout,
    )
    optional = dict(g=g, gk=gk, dht=dht, dh0=dh0, transpose_state_layout=transpose_state_layout)
    expected = run_fla_ref(q, k, w, do, dv, **optional)
    actual = run_cute_dsl(q, k, w, do, dv, **optional)
    _assert_bwd_close(actual, expected, True, f"T={T} g={use_g} gk={use_gk} trans={transpose_state_layout}")


@pytest.mark.parametrize(
    "use_g,use_gk,transpose_state_layout",
    # All eight flag combinations, with use_g varying fastest and the layout slowest.
    [
        (with_g, with_gk, transposed)
        for transposed in (False, True)
        for with_gk in (False, True)
        for with_g in (False, True)
    ],
)
def test_varlen_tail_chunk_sizes(use_g, use_gk, transpose_state_layout):
    """Packed sequences whose lengths straddle chunk boundaries, all flag combinations."""
    seq_lens = [1, 63, 64, 65, 127, 128, 129]
    seed = 2000 + int(use_g) * 10 + int(use_gk) * 20 + int(transpose_state_layout) * 40
    q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens = _make_varlen_inputs(
        seq_lens,
        H=1,
        K=128,
        V=128,
        use_g=use_g,
        use_gk=use_gk,
        use_state=True,
        seed=seed,
        transpose_state_layout=transpose_state_layout,
    )
    optional = dict(
        g=g,
        gk=gk,
        dht=dht,
        dh0=dh0,
        cu_seqlens=cu_seqlens,
        transpose_state_layout=transpose_state_layout,
    )
    expected = run_fla_ref(q, k, w, do, dv, **optional)
    actual = run_cute_dsl(q, k, w, do, dv, **optional)
    _assert_bwd_close(
        actual,
        expected,
        True,
        f"varlen tails g={use_g} gk={use_gk} trans={transpose_state_layout}",
    )


def test_transpose_state_layout():
    """Transposed state layout with gk and state enabled, against FLA."""
    q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs(
        B=1,
        T=128,
        H=2,
        K=128,
        V=128,
        use_gk=True,
        use_state=True,
        seed=456,
        transpose_state_layout=True,
    )
    optional = dict(g=g, gk=gk, dht=dht, dh0=dh0, transpose_state_layout=True)
    expected = run_fla_ref(q, k, w, do, dv, **optional)
    actual = run_cute_dsl(q, k, w, do, dv, **optional)
    _assert_bwd_close(actual, expected, True, "transpose state layout")