From 47b2c8d6e4b9ef24a32f1c668ea8f7098d6d1de8 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Tue, 5 May 2026 09:17:03 +0800 Subject: [PATCH 01/26] init sm100 bwd wy dqkg --- benchmarks/bench_kda_bwd_wy_dqkg_sm100.py | 439 +++ benchmarks/utils.py | 93 + cula/ops/chunk_wy_dqkg_sm100.py | 3257 +++++++++++++++++++++ cula/ops/intrinsics_sm100.py | 359 +++ cula/ops/ptx_umma_ext.py | 958 ++++++ 5 files changed, 5106 insertions(+) create mode 100644 benchmarks/bench_kda_bwd_wy_dqkg_sm100.py create mode 100644 cula/ops/chunk_wy_dqkg_sm100.py create mode 100644 cula/ops/intrinsics_sm100.py create mode 100644 cula/ops/ptx_umma_ext.py diff --git a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py new file mode 100644 index 0000000..e00c4ed --- /dev/null +++ b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +bench_bwd_wy_dqkg_fused.py — Benchmark: cuLA CuTe DSL vs FLA Triton baseline + for chunk_kda_bwd_wy_dqkg_fused kernel + +Compares: + - Accuracy: err_ratio, relative max diff, mean diff between cuLA and FLA outputs + - Performance: kernel execution time (ms) with CUDA events + +Modes: + - Fixed-length: B=1,2 with various T + - Varlen: variable-length sequences with different distributions + +Usage: + python bench_bwd_wy_dqkg_fused.py [--mode fixed|varlen|both] [--ncu] [--heads 32 64] + +With --ncu, warmup=1 and iters=1 for ncu profiling: + ncu --set full -o report python bench_bwd_wy_dqkg_fused.py --mode fixed --ncu +""" + +import argparse +import pathlib +import sys + +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) + +import torch +from fla.ops.kda.chunk_bwd import chunk_kda_bwd_wy_dqkg_fused as fla_chunk_kda_bwd_wy_dqkg_fused + +from benchmarks.utils import ( + SEED, + build_varlen_configs, + exclusive_cumsum, + prepare_bwd_wy_dqkg_fused_inputs, + set_seed, +) +from cula.ops.chunk_wy_dqkg_sm100 import chunk_kda_bwd_wy_dqkg_fused as cutedsl_chunk_kda_bwd_wy_dqkg_fused + +torch.backends.cuda.matmul.allow_tf32 = True + +# ============================================================ +# Constants +# ============================================================ +H_DEFAULT = 32 +K = 128 +V = 128 +BT = 64 +DTYPE = torch.bfloat16 +DEVICE = torch.device("cuda") +WARMUP = 25 +N_ITERS = 100 +NCU_MODE = False + +def generate_balanced_seqlens(total_tokens, num_seqs): + base = total_tokens // num_seqs + remainder = total_tokens % num_seqs + return [base] * (num_seqs - 1) + [base + remainder] + +# ============================================================ +# Helpers +# ============================================================ +def time_kernel(fn, warmup=None, n_iters=None): + if warmup is None: + warmup = 1 if NCU_MODE else WARMUP + if n_iters is None: + n_iters = 1 if NCU_MODE else N_ITERS + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() + for _ in range(n_iters): + fn() + end_evt.record() + torch.cuda.synchronize() + return start_evt.elapsed_time(end_evt) / n_iters + + +def accuracy_stats(ref, out): + """Compute err_ratio, relative max diff, and mean absolute difference.""" + ref_f = ref.float() + out_f = out.float() + diff = (ref_f - out_f).abs() + err = diff.flatten().pow(2).mean().sqrt().item() + base = ref_f.flatten().pow(2).mean().sqrt().item() + err_ratio = err / (base + 1e-8) + max_diff = diff.max().item() + denom = ref_f.abs().max().item() + rel_max = max_diff / denom if denom > 0 else 0.0 + mean_diff = diff.mean().item() + return err_ratio, rel_max, mean_diff + + +# ============================================================ +# Runners +# ============================================================ +def run_fla_triton(inputs: dict): + """Run the FLA Triton baseline.""" + return fla_chunk_kda_bwd_wy_dqkg_fused( + q=inputs["q"], + k=inputs["k"], + v=inputs["v"], + v_new=inputs["v_new"], + g=inputs["g"], + beta=inputs["beta"], + A=inputs["A"], + h=inputs["h"], + do=inputs["do"], + dh=inputs["dh"], + dv=inputs["dv"], + scale=inputs["scale"], + cu_seqlens=inputs["cu_seqlens"], + chunk_size=BT, + chunk_indices=inputs["chunk_indices"], + transpose_state_layout=False, + ) + + +def run_cutedsl(inputs: dict): + """Run the CuTe DSL Blackwell kernel.""" + return cutedsl_chunk_kda_bwd_wy_dqkg_fused( + q=inputs["q"], + k=inputs["k"], + v=inputs["v"], + v_new=inputs["v_new"], + g=inputs["g"], + beta=inputs["beta"], + A=inputs["A"], + h=inputs["h"], + do=inputs["do"], + dh=inputs["dh"], + dv=inputs["dv"], + scale=inputs["scale"], + cu_seqlens=inputs["cu_seqlens"], + chunk_size=BT, + chunk_indices=inputs["chunk_indices"], + ) + +def check_determinism(H=4, total_T=2001, num_seqs=4, iters=1000, beta_dtype=DTYPE): + """Verify deterministic outputs across repeated runs.""" + torch.manual_seed(42) + seq_lens = generate_balanced_seqlens(total_T, num_seqs) + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=DEVICE) + inputs = prepare_bwd_wy_dqkg_fused_inputs( + B=1, + T=total_T, + H=H, + K=K, + V=V, + chunk_size=BT, + device=DEVICE, + seed=SEED, + cu_seqlens=cu_seqlens, + ) + + ref_dq, ref_dk, ref_dv, ref_db, ref_dg, ref_dA = run_cutedsl(inputs) + for i in range(iters): + dq_out, dk_out, dv_out, db_out, dg_out, dA_out = run_cutedsl(inputs) + assert torch.equal(dq_out, ref_dq), f"dq mismatch at iter {i}" + assert torch.equal(dk_out, ref_dk), f"dk mismatch at iter {i}" + assert torch.equal(dv_out, ref_dv), f"dv mismatch at iter {i}" + assert torch.equal(dg_out, ref_dg), f"dg mismatch at iter {i}" + assert torch.equal(dA_out, ref_dA), f"dA mismatch at iter {i}" + # NOTE: for db, kernel uses atomic add which can cause non-determinism, so we use a looser check here + torch.testing.assert_close(db_out, ref_db, rtol=1e-5, atol=1e-5), f"db mismatch at iter {i}" + return True + + +# ============================================================ +# Fixed-length benchmark +# ============================================================ +def bench_fixed(configs, H: int): + print("\n" + "=" * 120) + print(f" Fixed-Length Benchmark: cuLA CuTe DSL vs FLA Triton (H={H}, K={K}, V={V}, BT={BT})") + print("=" * 120) + results = [] + + for B, T in configs: + set_seed(SEED) + torch.cuda.empty_cache() + + seq_lens = [T] * B + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=DEVICE) + + inputs = prepare_bwd_wy_dqkg_fused_inputs( + B=B, + T=T, + H=H, + K=K, + V=V, + chunk_size=BT, + device=DEVICE, + seed=SEED, + cu_seqlens=cu_seqlens, + ) + + # Accuracy + ref = run_fla_triton(inputs) # (dq, dk, dv, db, dg, dA) + out = run_cutedsl(inputs) # (dq, dk, dv2, db, dg, dA) + torch.cuda.synchronize() + + acc = {} + names = ["dq", "dk", "dv", "db", "dg", "dA"] + for name, r, o in zip(names, ref, out): + err_ratio, rel_max, mean_diff = accuracy_stats(r, o) + acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} + + # Timing + ms_fla = time_kernel(lambda: run_fla_triton(inputs)) + ms_dsl = time_kernel(lambda: run_cutedsl(inputs)) + speedup = ms_fla / ms_dsl if ms_dsl > 0 else float("inf") + + r = { + "B": B, + "T": T, + "accuracy": acc, + "ms_fla": ms_fla, + "ms_dsl": ms_dsl, + "speedup": speedup, + } + results.append(r) + + del inputs + torch.cuda.empty_cache() + + return results + + +# ============================================================ +# Varlen benchmark +# ============================================================ +def bench_varlen(configs, H: int): + print("\n" + "=" * 120) + print(f" Varlen Benchmark: cuLA CuTe DSL vs FLA Triton (H={H}, K={K}, V={V}, BT={BT})") + print("=" * 120) + results = [] + + for seq_lens, total_len, dist in configs: + set_seed(SEED) + torch.cuda.empty_cache() + + T = total_len + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=DEVICE) + + inputs = prepare_bwd_wy_dqkg_fused_inputs( + B=1, + T=T, + H=H, + K=K, + V=V, + chunk_size=BT, + device=DEVICE, + seed=SEED, + cu_seqlens=cu_seqlens, + ) + + # Accuracy + ref = run_fla_triton(inputs) + out = run_cutedsl(inputs) + torch.cuda.synchronize() + + acc = {} + names = ["dq", "dk", "dv", "db", "dg", "dA"] + for name, r, o in zip(names, ref, out): + err_ratio, rel_max, mean_diff = accuracy_stats(r, o) + acc[name] = {"err_ratio": err_ratio, "rel_max": rel_max, "mean_diff": mean_diff} + + # Timing + ms_fla = time_kernel(lambda: run_fla_triton(inputs)) + ms_dsl = time_kernel(lambda: run_cutedsl(inputs)) + speedup = ms_fla / ms_dsl if ms_dsl > 0 else float("inf") + + n_seqs = len(seq_lens) + min_l, max_l = min(seq_lens), max(seq_lens) + avg_l = T // n_seqs + tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min_l}..{max_l}] avg={avg_l}" + + r = { + "tag": tag, + "dist": dist, + "T_total": T, + "n_seqs": n_seqs, + "accuracy": acc, + "ms_fla": ms_fla, + "ms_dsl": ms_dsl, + "speedup": speedup, + } + results.append(r) + + del inputs + torch.cuda.empty_cache() + + return results + + +# ============================================================ +# Report +# ============================================================ +def print_report(fixed_results, varlen_results, H: int): + sep = "=" * 130 + print(f"\n\n{sep}") + print(" BENCHMARK REPORT: chunk_kda_bwd_wy_dqkg_fused") + print(" cuLA CuTe DSL vs FLA Triton") + wu = 1 if NCU_MODE else WARMUP + ni = 1 if NCU_MODE else N_ITERS + mode_tag = " [NCU mode]" if NCU_MODE else "" + print(f" H={H} K={K} V={V} BT={BT} dtype=bf16{mode_tag}") + print(f" Warmup={wu} Iters={ni}") + print(sep) + + acc_keys = ["dq", "dk", "dv", "db", "dg", "dA"] + acc_header = " ".join(f"{k:>10s}" for k in acc_keys) + + if fixed_results: + print("\n [Fixed-Length]") + print(f" {'─' * 125}") + print(f" {'B':>3s} {'T':>5s} │ {'FLA(ms)':>9s} {'DSL(ms)':>9s} {'Speedup':>8s} │ {'':>10s}{acc_header}") + print(f" {'─' * 125}") + + for r in fixed_results: + rel_max_vals = " ".join(f"{r['accuracy'].get(k, {}).get('rel_max', 0.0):10.6f}" for k in acc_keys) + err_ratio_vals = " ".join(f"{r['accuracy'].get(k, {}).get('err_ratio', 0.0):10.6f}" for k in acc_keys) + print( + f" {r['B']:3d} {r['T']:5d} │ " + f"{r['ms_fla']:9.4f} {r['ms_dsl']:9.4f} {r['speedup']:7.2f}x │ " + f"{'rel_max:':>10s}{rel_max_vals}" + ) + print(f" {'':3s} {'':5s} │ {'':9s} {'':9s} {'':8s} │ {'err_ratio:':>10s}{err_ratio_vals}") + print(f" {'─' * 125}") + + if varlen_results: + print("\n [Varlen]") + print(f" {'─' * 140}") + print(f" {'Config':>45s} │ {'FLA(ms)':>9s} {'DSL(ms)':>9s} {'Speedup':>8s} │ {'':>10s}{acc_header}") + print(f" {'─' * 140}") + + for r in varlen_results: + rel_max_vals = " ".join(f"{r['accuracy'].get(k, {}).get('rel_max', 0.0):10.6f}" for k in acc_keys) + err_ratio_vals = " ".join(f"{r['accuracy'].get(k, {}).get('err_ratio', 0.0):10.6f}" for k in acc_keys) + print( + f" {r['tag']:>45s} │ " + f"{r['ms_fla']:9.4f} {r['ms_dsl']:9.4f} {r['speedup']:7.2f}x │ " + f"{'rel_max:':>10s}{rel_max_vals}" + ) + print(f" {'':>45s} │ {'':9s} {'':9s} {'':8s} │ {'err_ratio:':>10s}{err_ratio_vals}") + print(f" {'─' * 140}") + + print(f"\n{sep}\n") + + +# ============================================================ +# Main +# ============================================================ +def main(): + global NCU_MODE + + parser = argparse.ArgumentParser(description="Benchmark chunk_kda_bwd_wy_dqkg_fused: cuLA CuTe DSL vs FLA Triton") + parser.add_argument( + "--mode", + type=str, + default="both", + choices=["fixed", "varlen", "both"], + help="Which benchmark mode to run (default: both)", + ) + parser.add_argument( + "--heads", + nargs="+", + type=int, + default=[H_DEFAULT], + help=f"Head counts to benchmark (default: [{H_DEFAULT}])", + ) + parser.add_argument("--ncu", action="store_true", help="NCU profiling mode: warmup=1, iters=1") + args = parser.parse_args() + + if args.ncu: + NCU_MODE = True + print("[NCU mode] warmup=1, iters=1") + + gpu_name = torch.cuda.get_device_name(0) + print(f"GPU: {gpu_name}") + wu = 1 if NCU_MODE else WARMUP + ni = 1 if NCU_MODE else N_ITERS + print(f"K={K}, V={V}, BT={BT}, dtype={DTYPE}, warmup={wu}, rep={ni}") + + fixed_configs = [ + (1, 256), + (1, 512), + (1, 1024), + (1, 2048), + (1, 4096), + (1, 8192), + (2, 512), + (2, 1024), + (2, 2048), + (2, 4096), + (2, 8192), + ] + + varlen_configs = build_varlen_configs( + num_seqs_list=(10, 20), + total_lens=(4096, 8192, 16384), + dists=("uniform", "random", "skewed"), + ) + + for H in args.heads: + check_determinism(H=H, iters=10000) + + fixed_res, varlen_res = [], [] + + if args.mode in ("fixed", "both"): + fixed_res = bench_fixed(fixed_configs, H) + + if args.mode in ("varlen", "both"): + varlen_res = bench_varlen(varlen_configs, H) + + print_report(fixed_res, varlen_res, H) + + print(f"\n{'=' * 130}") + print(" All benchmarks done.") + print(f"{'=' * 130}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 29bab04..b726364 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -345,3 +345,96 @@ def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_siz ) return q, k, v, g, beta, scale, cu_seqlens, chunk_indices + +def prepare_bwd_wy_dqkg_fused_inputs( + B: int, + T: int, + H: int, + K: int, + V: int, + chunk_size: int = CHUNK_SIZE, + device: torch.device | str = "cuda", + seed: int = SEED, + cu_seqlens: torch.Tensor | None = None, + dtype: torch.dtype = torch.bfloat16, +) -> dict: + """Prepare all inputs needed by the bwd_wy_dqkg_fused benchmark runners. + + Generates the full set of tensors consumed by both the FLA Triton and CuTe DSL + chunk_kda_bwd_wy_dqkg_fused kernels. Follows the same flattening convention + used in other prepare_* helpers (B=1 with cu_seqlens for varlen mode). + + Returns a dict with keys used directly by ``run_fla_triton`` and ``run_cutedsl`` + in ``bench_bwd_wy_dqkg_fused.py``. + """ + BT = chunk_size + scale = K**-0.5 + + set_seed(seed) + + # ---- primary token-indexed tensors ---- + q = torch.randn(B, T, H, K, dtype=dtype, device=device) + k = torch.randn(B, T, H, K, dtype=dtype, device=device) + v = torch.randn(B, T, H, V, dtype=dtype, device=device) + g_raw = torch.randn(B, T, H, K, dtype=dtype, device=device) + beta = torch.randn(B, T, H, dtype=torch.float, device=device).sigmoid() + + # l2norm q, k + q, _ = l2norm_fwd(q) + k, _ = l2norm_fwd(k) + + # gate preprocessing + A_log = torch.randn(H, dtype=torch.float, device=device) + dt_bias = torch.randn(H * K, dtype=torch.float, device=device) + + v_new = torch.randn(B, T, H, V, dtype=dtype, device=device) + do = torch.randn(B, T, H, V, dtype=dtype, device=device) + dv = torch.randn(B, T, H, V, dtype=dtype, device=device) + A = torch.randn(B, T, H, BT, dtype=dtype, device=device) * 0.1 + + # ---- chunk-indexed state tensors ---- + if cu_seqlens is not None: + cu_seqlens = cu_seqlens.int() + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = chunk_indices.shape[0] + else: + NT = (B * T + BT - 1) // BT + chunk_indices = None + + # h/dh: both FLA Triton and CuTe DSL use bf16 [B, NT, H, K, V] + h = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 + dh = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 + + # flatten to batch_size=1 for cu_seqlens compatibility + if B != 1: + q, k, v, g_raw, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g_raw, beta)) + v_new, do, dv, A = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (v_new, do, dv, A)) + h, dh = map(lambda x: rearrange(x, "b nt ... -> 1 (b nt) ..."), (h, dh)) + + g = kda_gate_chunk_cumsum( + g=g_raw, + A_log=A_log, + dt_bias=dt_bias, + scale=RCP_LN2, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=-5.0, + ) + + return dict( + q=q, + k=k, + v=v, + v_new=v_new, + g=g, + beta=beta, + A=A, + h=h, + dh=dh, + do=do, + dv=dv, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py new file mode 100644 index 0000000..f24c185 --- /dev/null +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -0,0 +1,3257 @@ +import argparse + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import torch +from cutlass._mlir.dialects import llvm, arith as _arith, nvvm as _nvvm +from cutlass._mlir import ir +from cutlass.cutlass_dsl import dsl_user_op +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream +from cutlass.cute.arch import ( + elect_one, + mbarrier_init, + mbarrier_init_fence, + mbarrier_wait, + mbarrier_arrive, + sync_threads, +) +from cutlass.cute.nvgpu.tcgen05 import ( + make_umma_smem_desc, + smem_descriptor_to_int, +) +from cutlass.cute.tensor import TensorSSA +from cutlass.cute.typing import Float32, Int32, Int64, BFloat16 +from fla.ops.utils import prepare_chunk_indices + +from cula.utils import USE_FAST_MATH, assert_blackwell + +from cula.ops.intrinsics_sm100 import ( + tcgen05_fence_before, + tcgen05_fence_after, + tcgen05_ld_32x32b, + tcgen05_st_32x32b, + reinterpret_cast, + subvec, + store_256b, +) +from cula.ops.ptx_umma_ext import ( + Tcgen05SmemDescriptor, + tcgen05mma_ws_ts_f16, + tcgen05mma_ws_ss_f16, +) + +PRINT_DEBUG = False + +LN2 = 0.6931471805599453 +RCP_LN2 = 1.4426950408889634 + +COMPILE_OPTIONS = "--enable-tvm-ffi --generate-line-info --ptxas-options '--verbose'" + +# Mapping from torch dtype to cutlass dtype (for beta_dtype conversion) +_torch_to_cutlass_dtype = { + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, +} + +def make_thread_cooperative_group(size: int): + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) + +def _exclusive_cumsum(a: list[int]): + r = [0] + for v in a: + r.append(r[-1] + v) + return r + + +# ── TMEM column offset constants (cta_group::1, M=64, .ws Layout E) ── +# See docs/chunk_wy_dqkg_design_doc.md §6.4 +TMEM_DA_ACC_OFF = 0 # [0,32) 32 cols dA fp32 acc; Phase 3: [0,16) overwritten by dA_bf16 +TMEM_DQ_ACC_OFF = 32 # [32,96) 64 cols dq fp32 acc; Phase 3: step2/step3 result [32,64) +TMEM_DK_ACC_OFF = 96 # [96,160) 64 cols dk fp32 acc +TMEM_DW_ACC_OFF = 160 # [160,224] 64 cols dw fp32 acc +TMEM_FLEX_OFF = 224 # [224,256) 32 cols dvb time-shared +TMEM_A_BF16_OFF = 256 # [256,272) 16 cols A_bf16 TS opA (persistent) +TMEM_DKGB_ACC_OFF = 272 # [272,336) 64 cols, dkgb fp32 acc +TMEM_DA2_ACC_OFF = 336 # [336,368) 32 cols dA fp32 acc, used for dA=dA@A and dA=A@dA +TMEM_TOTAL = 512 + +# Instruction descriptor for M=64, N=64, BF16, dense, TransposeB=1 +# Bits: M>>4=4 at [24:28], N>>3=8 at [17:22], TransposeB at [16], +# btype=bf16(1) at [10:12], atype=bf16(1) at [7:9], dtype=f32(1) at [4:5] +IDESC_F16_M64_N64_K_MN = (4 << 24) | (8 << 17) | (1 << 16) | (1 << 10) | (1 << 7) | (1 << 4) +assert IDESC_F16_M64_N64_K_MN == 0x4110490 + +# Instruction descriptor for M=64, N=128, BF16, dense, TransposeB=1 +# Bits: M>>4=4 at [24:28], N>>3=16 at [17:22], TransposeB at [16], +# btype=bf16(1) at [10:12], atype=bf16(1) at [7:9], dtype=f32(1) at [4:5] +IDESC_F16_M64_N128_K_MN = (4 << 24) | (16 << 17) | (1 << 16) | (1 << 10) | (1 << 7) | (1 << 4) +assert IDESC_F16_M64_N128_K_MN == 0x4210490 + +# Instruction descriptor for M=64, N=128, BF16, dense +# Bits: M>>4=4 at [24:28], N>>3=16 at [17:22], +# btype=bf16(1) at [10:12], atype=bf16(1) at [7:9], dtype=f32(1) at [4:5] +IDESC_F16_M64_N128_K_K = (4 << 24) | (16 << 17) | (1 << 10) | (1 << 7) | (1 << 4) + +# Instruction descriptor for M=64, N=128, BF16, dense, TransposeA=1, TransposeB=1 +# Bits: M>>4=4 at [24:28], N>>3=16 at [17:22], +# TransposeB at [16], TransposeA at [15], +# btype=bf16(1) at [10:12], atype=bf16(1) at [7:9], dtype=f32(1) at [4:5] +IDESC_F16_M64_N128_MN_MN = ( + (4 << 24) | (16 << 17) | (1 << 16) | (1 << 15) | (1 << 10) | (1 << 7) | (1 << 4) +) +assert IDESC_F16_M64_N128_MN_MN == 0x4218490 + +# Instruction descriptor for M=64, N=64, BF16, dense, TransposeA=1, TransposeB=1 +# Bits: M>>4=4 at [24:28], N>>3=8 at [17:22], +# TransposeB at [16], TransposeA at [15], +# btype=bf16(1) at [10:12], atype=bf16(1) at [7:9], dtype=f32(1) at [4:5] +IDESC_F16_M64_N64_MN_MN = ( + (4 << 24) | (8 << 17) | (1 << 16) | (1 << 15) | (1 << 10) | (1 << 7) | (1 << 4) +) + +# Instruction descriptor for M=64, N=64, BF16, dense +# Bits: M>>4=4 at [24:28], N>>3=8 at [17:22], +# TransposeB at [16], TransposeA at [15], +# btype=bf16(1) at [10:12], atype=bf16(1) at [7:9], dtype=f32(1) at [4:5] +IDESC_F16_M64_N64_K_K = ( + (4 << 24) | (8 << 17) | (1 << 10) | (1 << 7) | (1 << 4) +) + +ELEM_BYTES_BF16 = BFloat16.width // 8 + +def make_thread_cooperative_group(size: int): + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) + +# ============================================================ +# Helpers: _ir, Float32 conversion +# ============================================================ + +def _ir(val, loc=None, ip=None): + return val.ir_value(loc=loc, ip=ip) if hasattr(val, "ir_value") else val + +@dsl_user_op +def atomicAdd(dst_ptr: cute.Pointer, val: Int32 | Float32, *, loc=None, ip=None) -> Int32 | Float32: + return cute.arch.atomic_add( + ptr=dst_ptr.llvm_ptr, + val=val, + sem="relaxed", + scope="sys", + loc=loc, + ip=ip, + ) + +@dsl_user_op +def bf16_to_f32(val, *, loc=None, ip=None): + """Convert a BFloat16 value to Float32 using arith.extf (no inline asm).""" + bf16_ir = BFloat16(val).ir_value(loc=loc, ip=ip) + f32_ir = _arith.extf(ir.F32Type.get(), bf16_ir, loc=loc, ip=ip) + return Float32(f32_ir) + + +@dsl_user_op +def f32_to_bf16(val, *, loc=None, ip=None): + """Convert a Float32 value to BFloat16 using native arith.truncf.""" + f32_ir = Float32(val).ir_value(loc=loc, ip=ip) + bf16_ir = _arith.truncf(BFloat16.mlir_type, f32_ir, loc=loc, ip=ip) + return BFloat16(bf16_ir) + +@cute.jit +def smem_load_bf16x8_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32): + """ + Load 8 consecutive bfloat16 from SMEM with Swizzle<3,4,3> layout. + raw_ptr: BFloat16 SMEM base pointer (NOT recast_ptr — raw buffer start) + row: row index in [0, T_TILE=64) + col_base: 8-aligned column index in [0, K_TILE=128) + Logical layout: [BT=64, BV=128] K-major, with the BV=128 dim split into + two halves of 64 elements (high half offset by 4096 elements). + Swizzle<3,4,3> on bf16: phys_elem = elem ^ ((row & 7) << 3) within a half. + Returns an 8-element rmem fragment (bf16). + """ + half = col_base >> Int32(6) + k_inner = col_base & Int32(63) + swizzled = k_inner ^ ((row & Int32(7)) << Int32(3)) + elem_off = half * Int32(4096) + row * Int32(64) + swizzled + aligned_ptr = cute.make_ptr( + BFloat16, (raw_ptr + elem_off).toint(), + cute.AddressSpace.smem, assumed_align=16, + ) + smem_t = cute.make_tensor(aligned_ptr, cute.make_layout((8,), stride=(1,))) + rmem_t = cute.make_fragment_like(smem_t) + cute.autovec_copy(smem_t, rmem_t) + return rmem_t + +# TODO: is this bug for K_TILE != 128? +@cute.jit +def smem_store_bf16x8_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32, data: cute.Tensor): + """ + Store 8 consecutive bfloat16 to SMEM with Swizzle<3,4,3> layout. + raw_ptr: BFloat16 SMEM base pointer (NOT recast_ptr — raw buffer start) + row: row index in [0, T_TILE=64) + col_base: 8-aligned column index in [0, K_TILE=128) + data: 8-element rmem fragment (bf16) to store. + + NOTE: For the K-major→MN-major dv re-swizzle, source layout + `(BT,BV) K-major Swizzle<3,4,3>` and destination layout + `(BV,BT) MN-major Swizzle<3,4,3>` produce **identical** physical + addresses for the same (row=t, col=v). So this helper uses the same + address formula as the load helper, and the caller passes (row=t, col=v) + for both load (src K-maj) and store (dst MN-maj), implicitly transposing. + """ + half = col_base >> Int32(6) + k_inner = col_base & Int32(63) + swizzled = k_inner ^ ((row & Int32(7)) << Int32(3)) + elem_off = half * Int32(4096) + row * Int32(64) + swizzled + smem_ptr = cute.make_ptr( + BFloat16, (raw_ptr + elem_off).toint(), + cute.AddressSpace.smem, assumed_align=16, + ) + smem_t = cute.make_tensor(smem_ptr, cute.make_layout((8,), stride=(1,))) + cute.autovec_copy(data, smem_t) + +@cute.jit +def smem_load_f32x4_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32): + """ + Load 4 consecutive float32 from SMEM with K_SW128 layout. + Logical layout: [BT=64, BK=128] ROW_MAJOR, tiled over a Float32 K_SW128 atom. + The atom provides a 32-element row stride. The 128-element column is broken + into 4 blocks of 32 elements. + PyCutlass tiles this such that outer blocks stride by 2048 elements: + elem_idx = row * 32 + (col_base % 32) + (col_base / 32) * 2048 + + The TMA hardware performs a 128B Swizzle on physical byte addresses: + byte_idx = elem_idx * 4 + swizzled_byte = byte_idx ^ (((byte_idx >> 7) & 7) << 4) + Dividing by 4 yields the element XOR offset: + elem_xor = ((elem_idx >> 5) & 7) << 2 + Because (elem_idx >> 5) simplifies to 'row + (col_outer * 64)', + the XOR offset simplifies exactly to ((row & 7) << 2). + This only affects the inner 32-element column block. + """ + c_inner = col_base & Int32(31) + c_outer = col_base >> Int32(5) + swizzled_inner = c_inner ^ ((row & Int32(7)) << Int32(2)) + + elem_offset = row * Int32(32) + swizzled_inner + c_outer * Int32(2048) + + aligned_ptr = cute.make_ptr( + Float32, (raw_ptr + elem_offset).toint(), + cute.AddressSpace.smem, assumed_align=16, + ) + t = cute.make_tensor(aligned_ptr, cute.make_layout((4,), stride=(1,))) + vals = t.load() + return (vals[0], vals[1], vals[2], vals[3]) + +@cute.jit +def smem_store_f32x4_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32, data: cute.Tensor): + """ + Store 4 consecutive float32 to SMEM with K_SW128 layout. + Inverse of smem_load_f32x4_sw128 — same address formula, write path. + raw_ptr: Float32 SMEM base pointer (raw buffer start) + row: row index in [0, BT) + col_base: 4-aligned column index (multiples of 4) + data: 4-element rmem fragment (f32) to store. + """ + c_inner = col_base & Int32(31) + c_outer = col_base >> Int32(5) + swizzled_inner = c_inner ^ ((row & Int32(7)) << Int32(2)) + elem_offset = row * Int32(32) + swizzled_inner + c_outer * Int32(2048) + smem_ptr = cute.make_ptr( + Float32, (raw_ptr + elem_offset).toint(), + cute.AddressSpace.smem, assumed_align=16, + ) + smem_t = cute.make_tensor(smem_ptr, cute.make_layout((4,), stride=(1,))) + cute.autovec_copy(data, smem_t) + +# SMEM B: MN-major +@cute.jit +def mma_ws_ts_m64n128_call( + tmem_a_base: Int32, + b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, K: Int32 +): + with elect_one(): + b_outer = b_smem_layout.outer + for ks in cutlass.range_constexpr(K // 16): + scale = 0 if ks == 0 else 1 + b_off = cute.crd2idx(((0, 0), 0, ks, 0), b_outer) * ELEM_BYTES_BF16 + desc_b = desc_b_base + b_off + tmem_a = tmem_a_base + Int32(ks * 4) + tcgen05mma_ws_ts_f16(tmem_a, desc_b, tmem_c, IDESC_F16_M64_N128_K_MN, scale) + +@cute.jit +def mma_ws_ss_m64n128_call( + a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, + b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, K: Int32, is_accum: bool = False, +): + with elect_one(): + a_outer = a_smem_layout.outer + b_outer = b_smem_layout.outer + scale = 0 if not is_accum else 1 + for ks in cutlass.range_constexpr(K // 16): + a_off = cute.crd2idx(((0, 0), 0, ks, 0), a_outer) * ELEM_BYTES_BF16 + b_off = cute.crd2idx(((0, 0), 0, ks, 0), b_outer) * ELEM_BYTES_BF16 + desc_a = desc_a_base + a_off + desc_b = desc_b_base + b_off + tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, IDESC_F16_M64_N128_K_MN, scale) + scale = 1 + +@cute.jit +def mma_ws_ss_m64n128_k_k_call( + a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, + b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, K: Int32, is_accum: bool = False, +): + with elect_one(): + a_outer = a_smem_layout.outer + b_outer = b_smem_layout.outer + scale = 0 if not is_accum else 1 + for ks in cutlass.range_constexpr(K // 16): + a_off = cute.crd2idx(((0, 0), 0, ks, 0), a_outer) * ELEM_BYTES_BF16 + b_off = cute.crd2idx(((0, 0), 0, ks, 0), b_outer) * ELEM_BYTES_BF16 + desc_a = desc_a_base + a_off + desc_b = desc_b_base + b_off + tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, IDESC_F16_M64_N128_K_K, scale) + scale = 1 + +@cute.jit +def mma_ws_ss_m64n128_mn_mn_call( + a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, + b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, K: Int32, is_accum: bool = False, +): + with elect_one(): + a_outer = a_smem_layout.outer + b_outer = b_smem_layout.outer + scale = 0 if not is_accum else 1 + for ks in cutlass.range_constexpr(K // 16): + a_off = cute.crd2idx(((0, 0), 0, ks, 0), a_outer) * ELEM_BYTES_BF16 + b_off = cute.crd2idx(((0, 0), 0, ks, 0), b_outer) * ELEM_BYTES_BF16 + desc_a = desc_a_base + a_off + desc_b = desc_b_base + b_off + tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, IDESC_F16_M64_N128_MN_MN, scale) + scale = 1 + +@cute.jit +def mma_ws_ts_m64n64_call( + tmem_a_base: Int32, + b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, K: Int32 +): + pass + +@cute.jit +def mma_ws_ss_m64n64_call( + a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, + b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, K: Int32 +): + pass + +@cute.jit +def mma_ws_ss_m64n64_k_k_call( + a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, + b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, K: Int32, is_accum: bool = False, +): + with elect_one(): + a_outer = a_smem_layout.outer + b_outer = b_smem_layout.outer + scale = 0 if not is_accum else 1 + for ks in cutlass.range_constexpr(K // 16): + a_off = cute.crd2idx(((0, 0), 0, ks, 0), a_outer) * ELEM_BYTES_BF16 + b_off = cute.crd2idx(((0, 0), 0, ks, 0), b_outer) * ELEM_BYTES_BF16 + desc_a = desc_a_base + a_off + desc_b = desc_b_base + b_off + tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, IDESC_F16_M64_N64_K_K, scale) + scale = 1 + +@cute.jit +def mma_ws_ss_m64n64_mn_mn_call( + a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, + b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, K: Int32, is_accum: bool = False, +): + with elect_one(): + a_outer = a_smem_layout.outer + b_outer = b_smem_layout.outer + scale = 0 if not is_accum else 1 + for ks in cutlass.range_constexpr(K // 16): + a_off = cute.crd2idx(((0, 0), 0, ks, 0), a_outer) * ELEM_BYTES_BF16 + b_off = cute.crd2idx(((0, 0), 0, ks, 0), b_outer) * ELEM_BYTES_BF16 + desc_a = desc_a_base + a_off + desc_b = desc_b_base + b_off + tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, IDESC_F16_M64_N64_MN_MN, scale) + scale = 1 + +@cute.jit +def umma_arrive(mbar_ptr: cute.Pointer): + """tcgen05.commit.cta_group::1.mbarrier::arrive::one — signal MMA done.""" + with elect_one(): + tcgen05.commit(mbar_ptr, cta_group=tcgen05.CtaGroup.ONE) + +class ChunkKdaBwdWyDqkgFused: + """ + CuTe DSL kernel for chunk_kda_bwd_kernel_wy_dqkg_fused. + + Computes backward gradients dq, dk, dv2, dg, db, dA for the KDA + chunkwise delta-rule backward pass. + + Architecture: 1 CudaCore WG + 1 MMA warp + TMA/Aux warps. + See docs/chunk_kda_bwd_kernel_wy_dqkg_fused_blackwell_warpgroup_recommendation.md + """ + + def __init__( + self, + chunk_size: int = 64, + head_dim_k: int = 128, + head_dim_v: int = 128, + acc_dtype: type[cutlass.Numeric] = cutlass.Float32, + io_dtype: type[cutlass.Numeric] = cutlass.BFloat16, + g_dtype: type[cutlass.Numeric] = cutlass.Float32, + beta_dtype: type[cutlass.Numeric] = cutlass.Float32, + scale: float = 1.0, + min_occupancy: int = 1, # FIXME: change to 2, bug exists for accuracy + use_fast_math: bool = True, + ): + assert chunk_size == 64, "chunk_size must be 64" + assert head_dim_k == 128 and head_dim_v == 128, ( + f"head_dim_k and head_dim_v must both be 128, got head_dim_k={head_dim_k}, head_dim_v={head_dim_v}" + ) + assert_blackwell() + + self.use_fast_math = use_fast_math + self.chunk_size = chunk_size + self.head_dim_k = head_dim_k + self.head_dim_v = head_dim_v + self.acc_dtype = acc_dtype + self.io_dtype = io_dtype + self.g_dtype = g_dtype + self.beta_dtype = beta_dtype + self.scale = scale + + # Tile sizes + self.BT = chunk_size # 64 + self.BK = 128 # K tiling for V-loop GEMM (single K tile) + self.BV = 64 # V tiling for V-loop GEMM (single V tile) + + # Warp layout: WG0 (4 warps CudaCore+Store) + WG1 (1 MMA + 1 Load + 2 Aux) + self.threads_per_warp = 32 + self.cuda_warp_ids = (0, 1, 2, 3) # WG0: CudaCore + Store + self.cuda2_warp_ids = (4, 5, 6, 7) # WG1: CudaCore + Store + self.mma_warp_id = 8 # WG2: MMA dispatch + self.load_warp_id = 9 # WG2: TMA Load + self.aux_warp_ids = (10, 11) # WG2: Aux/Load Aux + self.threads_per_cta = self.threads_per_warp * 12 # 384 threads (3 WGs) + + self.num_regs_cuda = 208 + self.num_regs_others = 88 + self.min_occupancy = min_occupancy + + self.cluster_shape_mnk = (1, 1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + + # Number of K/V tiles + self.num_k_tiles = (head_dim_k + self.BK - 1) // self.BK # 128/128 = 1 + self.num_v_tiles = (head_dim_v + self.BV - 1) // self.BV # 128/64 = 2 + + # ── Pipeline stages ── + # V-loop TMA: 2-stage double buffer + self.vloop_stage = 2 + self.kloop_stage = 1 + self.a_stage = 1 # TODO: increase to 2 + self.mma_stage = 1 + + # ── MMA tiler shapes ── + # V-loop GEMMs: [BT, BV] × [BV, BK] → [BT, BK] + # dq = do @ h : (BT, BK, BV) — M=BT, N=BK, K=BV + # dk = v_new @ dh : (BT, BK, BV) + # dw = dv @ h : (BT, BK, BV) + self.vloop_gemm_tiler = (self.BT, self.BK, self.BV) + + # V-loop i_k==0 GEMMs: [BT, BV] × [BV, BT] → [BT, BT] + # dA = dv @ v^T : (BT, BT, BV) + self.dA_vloop_tiler = (self.BT, self.BT, self.BV) + + # V-loop i_k==0: A @ dv : [BT, BT] × [BT, BV] → [BT, BV] + self.dvb_tiler = (self.BT, self.BV, self.BT) + + # K-loop GEMMs: + # dA += dw @ kg^T : [BT, BK] × [BK, BT] → [BT, BT] → (BT, BT, BK) + self.kloop_dA_tiler = (self.BT, self.BT, self.BK) + # dkgb = A @ dw : [BT, BT] × [BT, BK] → [BT, BK] → (BT, BK, BT) + self.kloop_dkgb_tiler = (self.BT, self.BK, self.BT) + + # dA-post GEMMs: + # dA @ A : [BT, BT] × [BT, BT] → [BT, BT] → (BT, BT, BT) + # A @ dA : same + self.dApost_tiler = (self.BT, self.BT, self.BT) + + # Named barriers + self.tmem_dealloc_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=self.threads_per_cta, + ) + self.mma_warp_sync_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32, + ) + self.cuda_wg_sync_barrier = pipeline.NamedBarrier( + barrier_id=4, + num_threads=32 * 8, + ) + self.buffer_align_bytes = 1024 + + # Persistent scheduling + self.persistent = True + hardware_info = cutlass.utils.HardwareInfo() + self.num_sm = hardware_info.get_device_multiprocessor_count() + + def _compute_grid(self, B, T, H, total_nt=None): + """Compute grid dimensions for persistent kernel launch. + + Grid: (min(num_sm * min_occupancy, total_tiles), 1, 1) + Each CTA handles multiple tiles via stride-by-gridDim.x loop. + """ + assert total_nt is not None + total_tiles = total_nt * H + grid_x = cutlass.min(Int32(self.num_sm * self.min_occupancy), total_tiles) + return (grid_x, Int32(1), Int32(1)) + + @cute.jit + def __call__( + self, + # ── Inputs ── + q_in: cute.Tensor, # [B, T, H, K] bf16 + k_in: cute.Tensor, # [B, T, H, K] bf16 + v_in: cute.Tensor, # [B, T, H, V] bf16 + v_new_in: cute.Tensor, # [B, T, H, V] bf16 + g_in: cute.Tensor, # [B, T, H, K] fp32 + beta_in: cute.Tensor, # [B, T, H] fp32 + A_in: cute.Tensor, # [B, T, H, BT] bf16 + h_in: cute.Tensor, # [B, NT, H, K, V] bf16 + do_in: cute.Tensor, # [B, T, H, V] bf16 + dh_in: cute.Tensor, # [B, NT, H, K, V] bf16 + dv_in: cute.Tensor, # [B, T, H, V] bf16 + # ── Outputs ── + dq_in: cute.Tensor, # [B, T, H, K] fp32 + dk_in: cute.Tensor, # [B, T, H, K] fp32 + dv2_in: cute.Tensor, # [B, T, H, V] bf16 + dg_in: cute.Tensor, # [B, T, H, K] fp32 + db_in: cute.Tensor, # [B, T, H] fp32 + dA_in: cute.Tensor, # [B, T, H, BT] fp32 + # ── Metadata ── + cu_seqlens_in: cute.Tensor, # [N+1] int32 + chunk_indices_in: cute.Tensor, # [NT, 2] int32 + problem_size: tuple[Int32, Int32, Int32, Int32, Int32], + total_nt: Int32, + stream, + ): + # ── Extract pointers ── + q_ptr = q_in.iterator + k_ptr = k_in.iterator + v_ptr = v_in.iterator + v_new_ptr = v_new_in.iterator + g_ptr = g_in.iterator + beta_ptr = beta_in.iterator + A_ptr = A_in.iterator + h_ptr = h_in.iterator + do_ptr = do_in.iterator + dh_ptr = dh_in.iterator + dv_ptr = dv_in.iterator + dq_ptr = dq_in.iterator + dk_ptr = dk_in.iterator + dv2_ptr = dv2_in.iterator + dg_ptr = dg_in.iterator + db_ptr = db_in.iterator + dA_ptr = dA_in.iterator + cu_seqlens_ptr = cu_seqlens_in.iterator + chunk_indices_ptr = chunk_indices_in.iterator + + B, T, H, K, V = problem_size + BT = self.BT + BK = self.BK + BV = self.BV + + data_B = Int32(1) + NT = total_nt + + # ===================== GMEM layouts ===================== + # Token-indexed tensors: (T, dim, (H, data_B)) + # q, k: (T, K, (H, data_B)) bf16 + qk_layout = cute.make_layout( + (T, K, (H, data_B)), + stride=(H * K, 1, (K, T * H * K)), + ) + q = cute.make_tensor(q_ptr, qk_layout) + k = cute.make_tensor(k_ptr, qk_layout) + + # v, v_new, do, dv, dv2: (T, V, (H, data_B)) bf16 + tv_layout = cute.make_layout( + (T, V, (H, data_B)), + stride=(H * V, 1, (V, T * H * V)), + ) + v = cute.make_tensor(v_ptr, tv_layout) + v_new = cute.make_tensor(v_new_ptr, tv_layout) + do = cute.make_tensor(do_ptr, tv_layout) + dv = cute.make_tensor(dv_ptr, tv_layout) + dv2 = cute.make_tensor(dv2_ptr, tv_layout) + + # g: (T, K, (H, data_B)) fp32 + g_layout = cute.make_layout( + (T, K, (H, data_B)), + stride=(H * K, 1, (K, T * H * K)), + ) + g = cute.make_tensor(g_ptr, g_layout) + + # beta: (T, (H, data_B)) fp32 + beta_layout = cute.make_layout( + (T, (H, data_B)), + stride=(H, (1, T * H)), + ) + beta = cute.make_tensor(beta_ptr, beta_layout) + + # A: (T, BT, (H, data_B)) bf16 + a_layout = cute.make_layout( + (T, BT, (H, data_B)), + stride=(H * BT, 1, (BT, T * H * BT)), + ) + A = cute.make_tensor(A_ptr, a_layout) + # NOTE: for A as operand A, A is loaded as transposed view to do MMA + a_t_layout = cute.make_layout( + (BT, T, (H, data_B)), + stride=(1, H * BT, (BT, T * H * BT)), + ) + A_T = cute.make_tensor(A_ptr, a_t_layout) + + # dq, dk: (T, K, (H, data_B)) fp32 + dqk_layout = cute.make_layout( + (T, K, (H, data_B)), + stride=(H * K, 1, (K, T * H * K)), + ) + dq = cute.make_tensor(dq_ptr, dqk_layout) + dk = cute.make_tensor(dk_ptr, dqk_layout) + + # dg: (T, K, (H, data_B)) fp32 + dg = cute.make_tensor(dg_ptr, dqk_layout) + + # db: (T, (H, data_B)) fp32 + db = cute.make_tensor(db_ptr, beta_layout) + + # dA: (T, BT, (H, data_B)) fp32 + dA_layout = cute.make_layout( + (T, BT, (H, data_B)), + stride=(H * BT, 1, (BT, T * H * BT)), + ) + dA_out = cute.make_tensor(dA_ptr, dA_layout) + + h_nt_total = NT + + # h row-major: (K, V, (h_nt_total, H)) as operand B + h_layout = cute.make_layout( + (K, V, (h_nt_total, H)), + stride=(V, 1, (H * K * V, K * V)), + ) + h = cute.make_tensor(h_ptr, h_layout) + dh = cute.make_tensor(dh_ptr, h_layout) + + # Transposed views for V-loop TMA (data loaded as MMA B-operands): + vt_layout = cute.make_layout( + (V, T, (data_B, H)), + stride=(1, H * V, (T * H * V, V)), + ) + v_T = cute.make_tensor(v_ptr, vt_layout) + + # ===================== MMA setup (4 objects) ===================== + # All use tcgen05.mma.ws (Layout E, M=64, cta_group::1). + # 1. vloop_tiled_mma: SS K,K (64,128) — dq, dk, dw + # dq += do @ h, dk += vnew @ dh, dw += dv @ h + vloop_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, + tcgen05.OperandMajorMode.K, # A: K-major + tcgen05.OperandMajorMode.K, # B: K-major + self.acc_dtype, + self.cta_group, + self.vloop_gemm_tiler[:2], # (64, 128) + # default a_source=OperandSource.SMEM → SS mode + ) + + # 2. dA_vloop_tiled_mma: SS K,K (64,64) — dA vloop + kpost_dA + # dA += dv @ v^T, dA += dw @ kg^T + dA_vloop_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + self.cta_group, + self.dA_vloop_tiler[:2], # (64, 64) + # default a_source=OperandSource.SMEM → SS mode + ) + + # 3. dvb_tiled_mma: SS MN,MN (64,64) — dvb + dkgb + # dvb = A @ dv, dkgb = A @ dw + dvb_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, + tcgen05.OperandMajorMode.MN, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + self.cta_group, + self.dvb_tiler[:2], # (64, 64) + ) + + # dkgb_tiled_mma: SS MN,MN (64,128) - dkgb + dkgb_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, + tcgen05.OperandMajorMode.MN, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + self.cta_group, + self.kloop_dkgb_tiler[:2], # (64, 128) + ) + + # dA_kloop_tiled_mma: SS K,K (64, 64) + # dA += dw @ kg^T + dA_kloop_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + self.cta_group, + self.kloop_dA_tiler[:2] # (64, 64) + ) + + # dA2post_tiled_mma: SS K,K (64,64) + # dA = dA @ A + dA2post_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + self.cta_group, + self.dApost_tiler[:2], # (64, 64) + # tcgen05.OperandSource.SMEM, # SS mode + ) + + # dA3post_tiled_mma: SS MN,MN (64,64) + # dA = A @ dA + dA3post_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, + tcgen05.OperandMajorMode.MN, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + self.cta_group, + self.dApost_tiler[:2], # (64, 64) + # tcgen05.OperandSource.SMEM, # SS mode + ) + + # ===================== SMEM layouts ===================== + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group) + + # SS opA layout: do/vnew/dv [BT,BV]=[64,64] K-major + vloop_opA_smem = sm100_utils.make_smem_layout_a( + vloop_tiled_mma, + self.vloop_gemm_tiler, + self.io_dtype, + self.vloop_stage, + ) + + # SS opB layout: h/dh [BK,BV]=[128,64] K-major + vloop_opB_smem = sm100_utils.make_smem_layout_b( + vloop_tiled_mma, + self.vloop_gemm_tiler, + self.io_dtype, + self.vloop_stage, + ) + + # SS opB layout: v [BV,BT]=[128,64] K-major (dA vloop) + v_opB_smem = sm100_utils.make_smem_layout_b( + dA_vloop_tiled_mma, + self.dA_vloop_tiler, + self.io_dtype, + self.vloop_stage, + ) + + # SS opA layout: A MN-major [BT,BT]=[64,64] + A_mn_opA_smem = sm100_utils.make_smem_layout_a( + dvb_tiled_mma, + self.dvb_tiler, + self.io_dtype, + self.a_stage, + ) + + # opB: dv MN-major [BV,BT]=[64,64] + dv_mn_opB_smem = sm100_utils.make_smem_layout_b( + dvb_tiled_mma, + self.dvb_tiler, + self.io_dtype, + self.vloop_stage, + ) + + # opA: dw K-major [BT,BK]=[64,128] + dw_k_opA_smem = sm100_utils.make_smem_layout_a( + dA_vloop_tiled_mma, + self.kloop_dA_tiler, + self.io_dtype, + self.kloop_stage, + ) + + # opB: dw MN-major [BK,BT] + dw_mn_opB_smem = sm100_utils.make_smem_layout_b( + dkgb_tiled_mma, + self.kloop_dkgb_tiler, + self.io_dtype, + self.kloop_stage, + ) + + # opB: kg^T K-major [BT, BK] + kg_k_opB_smem = sm100_utils.make_smem_layout_b( + dA_kloop_tiled_mma, + self.kloop_dA_tiler, + self.io_dtype, + self.kloop_stage, + ) + + # opA: dA K-major [BT,BT] + dA_k_opA_smem = sm100_utils.make_smem_layout_a( + dA2post_tiled_mma, + self.dApost_tiler, + self.io_dtype, + self.a_stage, + ) + + # opB: A K-major [BT,BT] + A_k_opB_smem = sm100_utils.make_smem_layout_b( + dA2post_tiled_mma, + self.dApost_tiler, + self.io_dtype, + self.a_stage, + ) + + # opB: dA MN-major [BT,BT] + dA_mn_opB_smem = sm100_utils.make_smem_layout_b( + dA3post_tiled_mma, + self.dApost_tiler, + self.io_dtype, + self.a_stage, + ) + + # --- Epilogue (non-MMA) layouts --- + g_epi_smem_layout = sm100_utils.make_smem_layout_epi( + self.g_dtype, + utils.LayoutEnum.ROW_MAJOR, + (self.BT, self.BK), + self.kloop_stage, + ) + + k_epi_smem_layout = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.ROW_MAJOR, + (self.BT, self.BK), + self.kloop_stage, + ) + + q_epi_smem_layout = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.ROW_MAJOR, + (self.BT, self.BK), + 1, + ) + + # ===================== Cluster layout ===================== + cluster_layout = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (vloop_tiled_mma.thr_id.shape,), + ) + + # ===================== TMA descriptors ===================== + # Strip stage dimension for TMA atom creation (expects 3 modes, not 4) + vloop_opA_smem_no_stage = cute.select(vloop_opA_smem, mode=[0, 1, 2]) + vloop_opB_smem_no_stage = cute.select(vloop_opB_smem, mode=[0, 1, 2]) + v_opB_smem_no_stage = cute.select(v_opB_smem, mode=[0, 1, 2]) + A_mn_opA_smem_no_stage = cute.select(A_mn_opA_smem, mode=[0, 1, 2]) + + tma_atom_dv, tma_tensor_dv = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + dv, + vloop_opA_smem_no_stage, + self.vloop_gemm_tiler, + vloop_tiled_mma, + cluster_layout.shape, + ) + + tma_atom_A, tma_tensor_A = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + A_T, + A_mn_opA_smem_no_stage, + self.dvb_tiler, + dvb_tiled_mma, + cluster_layout.shape, + ) + + tma_atom_h, tma_tensor_h = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + h, + vloop_opB_smem_no_stage, + self.vloop_gemm_tiler, + vloop_tiled_mma, + cluster_layout.shape, + ) + + tma_atom_dh, tma_tensor_dh = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + dh, + vloop_opB_smem_no_stage, + self.vloop_gemm_tiler, + vloop_tiled_mma, + cluster_layout.shape, + ) + + tma_atom_do, tma_tensor_do = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + do, + vloop_opA_smem_no_stage, + self.vloop_gemm_tiler, + vloop_tiled_mma, + cluster_layout.shape, + ) + + tma_atom_vnew, tma_tensor_vnew = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + v_new, + vloop_opA_smem_no_stage, + self.vloop_gemm_tiler, + vloop_tiled_mma, + cluster_layout.shape, + ) + + tma_atom_v, tma_tensor_v = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + v, + v_opB_smem_no_stage, + self.dA_vloop_tiler, + dA_vloop_tiled_mma, + cluster_layout.shape, + ) + + g_epi_smem_no_stage = cute.select(g_epi_smem_layout, mode=[0, 1]) + tma_atom_g, tma_tensor_g = cpasync.make_tiled_tma_atom( + tma_load_op, + g, + g_epi_smem_no_stage, + (self.BT, self.BK), + ) + + k_epi_smem_no_stage = cute.select(k_epi_smem_layout, mode=[0, 1]) + tma_atom_k, tma_tensor_k = cpasync.make_tiled_tma_atom( + tma_load_op, + k, + k_epi_smem_no_stage, + (self.BT, self.BK), + ) + + q_epi_smem_no_stage = cute.select(q_epi_smem_layout, mode=[0, 1]) + tma_atom_q, tma_tensor_q = cpasync.make_tiled_tma_atom( + tma_load_op, + q, + q_epi_smem_no_stage, + (self.BT, self.BK), + ) + + # ===================== TMA byte counts ===================== + self.tma_bytes_A = cute.size_in_bytes(self.io_dtype, A_mn_opA_smem_no_stage) + self.tma_bytes_dv = cute.size_in_bytes(self.io_dtype, vloop_opA_smem_no_stage) + self.tma_bytes_h = cute.size_in_bytes(self.io_dtype, vloop_opB_smem_no_stage) + self.tma_bytes_dh = cute.size_in_bytes(self.io_dtype, vloop_opB_smem_no_stage) + self.tma_bytes_do = cute.size_in_bytes(self.io_dtype, vloop_opA_smem_no_stage) + self.tma_bytes_vnew = cute.size_in_bytes(self.io_dtype, vloop_opA_smem_no_stage) + self.tma_bytes_g = cute.size_in_bytes(self.g_dtype, g_epi_smem_no_stage) + self.tma_bytes_v = cute.size_in_bytes(self.io_dtype, v_opB_smem_no_stage) + self.tma_bytes_k = cute.size_in_bytes(self.io_dtype, k_epi_smem_no_stage) + self.tma_bytes_q = cute.size_in_bytes(self.io_dtype, q_epi_smem_no_stage) + + # ===================== SharedStorage ===================== + @cute.struct + class SharedStorage: + # ======= mbarrier ======= + bar_load_A: cute.struct.MemRange[Int64, self.a_stage * 2] + bar_load_dv: cute.struct.MemRange[Int64, self.vloop_stage * 2] + bar_mma_dvb: cute.struct.MemRange[Int64, self.mma_stage * 2] + bar_load_beta: cute.struct.MemRange[Int64, 1 * 2] + bar_load_h: cute.struct.MemRange[Int64, self.vloop_stage * 2] + bar_load_dh: cute.struct.MemRange[Int64, self.vloop_stage * 2] + bar_load_do: cute.struct.MemRange[Int64, self.vloop_stage * 2] + bar_load_g: cute.struct.MemRange[Int64, self.kloop_stage * 2] + bar_load_v: cute.struct.MemRange[Int64, self.vloop_stage * 2] + bar_load_vnew: cute.struct.MemRange[Int64, self.vloop_stage * 2] + bar_load_q: cute.struct.MemRange[Int64, self.kloop_stage * 2] + bar_load_k: cute.struct.MemRange[Int64, self.kloop_stage * 2] + bar_mma_dq: cute.struct.MemRange[Int64, self.mma_stage * 2] + bar_mma_dw: cute.struct.MemRange[Int64, self.mma_stage * 2] + bar_mma_dk: cute.struct.MemRange[Int64, self.mma_stage * 2] + bar_mma_dkgb: cute.struct.MemRange[Int64, self.mma_stage * 2] + bar_mma_dA: cute.struct.MemRange[Int64, self.mma_stage * 2] + bar_mma_dA2: cute.struct.MemRange[Int64, self.mma_stage * 2] + bar_mma_dA3: cute.struct.MemRange[Int64, self.mma_stage * 2] + bar_mma_done_dA: cute.struct.MemRange[Int64, self.mma_stage] + bar_mma_done_vloop: cute.struct.MemRange[Int64, self.mma_stage] + bar_prologue_dw: cute.struct.MemRange[Int64, self.kloop_stage * 2] + bar_prologue_kg: cute.struct.MemRange[Int64, self.kloop_stage * 2] + bar_prologue_dA2: cute.struct.MemRange[Int64, self.mma_stage * 2] + bar_prologue_dA3: cute.struct.MemRange[Int64, self.mma_stage * 2] + # TMEM holding buffer + tmem_holding_buf: Int32 + # A, stage=1, [BT,BT], 8KB + buf_A: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(A_mn_opA_smem)], + self.buffer_align_bytes, + ] + # k, stage=1, [BT,BK], 16KB + buf_k: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(k_epi_smem_layout)], + self.buffer_align_bytes, + ] + # g, stage=1, [BT,BK], 32KB + buf_g: cute.struct.Align[ + cute.struct.MemRange[self.g_dtype, cute.cosize(g_epi_smem_layout)], + self.buffer_align_bytes, + ] + # q, stage=1, [BT,BK], 16KB + buf_q: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(q_epi_smem_layout)], + self.buffer_align_bytes, + ] + # V-loop buffers, stage=2 + # h, dh, [BK,BV] 32KB*2 + buf_h: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(vloop_opB_smem)], + self.buffer_align_bytes, + ] + buf_dh: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(vloop_opB_smem)], + self.buffer_align_bytes, + ] + # do, dv, v_new, v, [BT,BV] 16KB*4 + buf_do: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(vloop_opA_smem)], + self.buffer_align_bytes, + ] + buf_dv: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(vloop_opA_smem)], + self.buffer_align_bytes, + ] + buf_vnew: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(vloop_opA_smem)], + self.buffer_align_bytes, + ] + buf_v: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(v_opB_smem)], + self.buffer_align_bytes, + ] + + # dw, stage=1, [BT,BK] 16KB + buf_dw: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(dw_k_opA_smem)], + self.buffer_align_bytes, + ] + # Scalars + s_beta: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.BT], + 128, + ] + s_db: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.BT], + 128, + ] + s_gn: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.BK], + 128, + ] + s_dgk: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.BK], + 128, + ] + + self.shared_storage = SharedStorage + + # ===================== cu_seqlens / chunk_indices tensors ===================== + cu_seqlens = cute.make_tensor(cu_seqlens_ptr, cute.make_layout((B + 1,))) + chunk_indices = cute.make_tensor(chunk_indices_ptr, cute.make_layout((total_nt, 2), stride=(2, 1))) + + # ===================== Grid ===================== + grid = self._compute_grid(B, T, H, total_nt=total_nt) + + # ===================== Launch kernel ===================== + self.kernel( + # MMA objects (4) + vloop_tiled_mma, + dA_vloop_tiled_mma, + dvb_tiled_mma, + dA_kloop_tiled_mma, + dA2post_tiled_mma, + dA3post_tiled_mma, + # TMA atoms + tma_atom_dv, + tma_tensor_dv, + tma_atom_A, + tma_tensor_A, + tma_atom_h, + tma_tensor_h, + tma_atom_dh, + tma_tensor_dh, + tma_atom_do, + tma_tensor_do, + tma_atom_g, + tma_tensor_g, + tma_atom_v, + tma_tensor_v, + tma_atom_k, + tma_tensor_k, + tma_atom_vnew, + tma_tensor_vnew, + tma_atom_q, + tma_tensor_q, + # SMEM layouts + vloop_opA_smem, + vloop_opB_smem, + v_opB_smem, + A_mn_opA_smem, + dv_mn_opB_smem, + dw_k_opA_smem, + dw_mn_opB_smem, + kg_k_opB_smem, + A_k_opB_smem, + dA_k_opA_smem, + dA_mn_opB_smem, + g_epi_smem_layout, + k_epi_smem_layout, + q_epi_smem_layout, + # GMEM tensors + q, k, g, beta, dq, dk, dv2, dg, db, dA_out, + # Metadata + cu_seqlens, + chunk_indices, + problem_size, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=self.min_occupancy, + ) + + @cute.kernel + def kernel( + self, + # MMA objects (4) + vloop_tiled_mma: cute.TiledMma, + dA_vloop_tiled_mma: cute.TiledMma, + dvb_tiled_mma: cute.TiledMma, + dA_kloop_tiled_mma: cute.TiledMma, + dA2post_tiled_mma: cute.TiledMma, + dA3post_tiled_mma: cute.TiledMma, + # TMA atoms + tensors + tma_atom_dv: cute.CopyAtom, + tma_tensor_dv: cute.Tensor, + tma_atom_A: cute.CopyAtom, + tma_tensor_A: cute.Tensor, + tma_atom_h: cute.CopyAtom, + tma_tensor_h: cute.Tensor, + tma_atom_dh: cute.CopyAtom, + tma_tensor_dh: cute.Tensor, + tma_atom_do: cute.CopyAtom, + tma_tensor_do: cute.Tensor, + tma_atom_g: cute.CopyAtom, + tma_tensor_g: cute.Tensor, + tma_atom_v: cute.CopyAtom, + tma_tensor_v: cute.Tensor, + tma_atom_k: cute.CopyAtom, + tma_tensor_k: cute.Tensor, + tma_atom_vnew: cute.CopyAtom, + tma_tensor_vnew: cute.Tensor, + tma_atom_q: cute.CopyAtom, + tma_tensor_q: cute.Tensor, + # SMEM layouts + vloop_opA_smem: cute.ComposedLayout, + vloop_opB_smem: cute.ComposedLayout, + v_opB_smem: cute.ComposedLayout, + A_mn_opA_smem: cute.ComposedLayout, + dv_mn_opB_smem: cute.ComposedLayout, + dw_k_opA_smem: cute.ComposedLayout, + dw_mn_opB_smem: cute.ComposedLayout, + kg_k_opB_smem: cute.ComposedLayout, + A_k_opB_smem: cute.ComposedLayout, + dA_k_opA_smem: cute.ComposedLayout, + dA_mn_opB_smem: cute.ComposedLayout, + g_epi_smem_layout: cute.ComposedLayout, + k_epi_smem_layout: cute.ComposedLayout, + q_epi_smem_layout: cute.ComposedLayout, + # GMEM tensors + q_gmem: cute.Tensor, + k_gmem: cute.Tensor, + g_gmem: cute.Tensor, + beta_gmem: cute.Tensor, + dq_gmem: cute.Tensor, + dk_gmem: cute.Tensor, + dv2_gmem: cute.Tensor, + dg_gmem: cute.Tensor, + db_gmem: cute.Tensor, + dA_gmem: cute.Tensor, + # Metadata + cu_seqlens: cute.Tensor, + chunk_indices: cute.Tensor, + problem_size: tuple[Int32, Int32, Int32, Int32, Int32], + ): + B, T, H, K, V = problem_size + BT = self.BT + BK, BV = self.BK, self.BV + + # ===================== Persistent work decode ===================== + # Grid: (min(num_sm * occ, total_tiles), 1, 1) — persistent + block_idx_x = cute.arch.block_idx()[0] + grid_dim_x = cute.arch.grid_dim()[0] + thread_idx = cute.arch.thread_idx()[0] + lane_idx = thread_idx % 32 + + total_work_units = chunk_indices.layout.shape[0] * H + num_iters = (total_work_units - block_idx_x + grid_dim_x - 1) // grid_dim_x + + num_cuda_warps = len(self.cuda_warp_ids) + num_cuda_warps_total = len(self.cuda_warp_ids) + len(self.cuda2_warp_ids) + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + + if warp_idx == self.load_warp_id: + cpasync.prefetch_descriptor(tma_atom_A) + cpasync.prefetch_descriptor(tma_atom_dv) + cpasync.prefetch_descriptor(tma_atom_h) + cpasync.prefetch_descriptor(tma_atom_dh) + cpasync.prefetch_descriptor(tma_atom_do) + cpasync.prefetch_descriptor(tma_atom_g) + cpasync.prefetch_descriptor(tma_atom_v) + cpasync.prefetch_descriptor(tma_atom_vnew) + cpasync.prefetch_descriptor(tma_atom_k) + cpasync.prefetch_descriptor(tma_atom_q) + + # ===================== SMEM allocation ===================== + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # Barrier Initialization + bar_mma_done_vloop_ptr = storage.bar_mma_done_vloop.data_ptr() + if warp_idx == 0: + with elect_one(): + for i in cutlass.range(self.mma_stage): + mbarrier_init(bar_mma_done_vloop_ptr + i, 1) + mbarrier_init_fence() + + # ====== Pipeline Definition ====== + pipeline_load_A = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.bar_load_A.data_ptr(), + num_stages=self.a_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_bytes_A, + ) + pipeline_load_dv = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.bar_load_dv.data_ptr(), + num_stages=self.vloop_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_bytes_dv, + ) + pipeline_load_h = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.bar_load_h.data_ptr(), + num_stages=self.vloop_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) + num_cuda_warps_total), + tx_count=self.tma_bytes_h, + ) + pipeline_load_dh = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.bar_load_dh.data_ptr(), + num_stages=self.vloop_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) + num_cuda_warps_total), + tx_count=self.tma_bytes_dh, + ) + # NOTE: UMMA as consumer to call tcgen05.commit + pipeline_load_do = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.bar_load_do.data_ptr(), + num_stages=self.vloop_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_bytes_do, + ) + pipeline_load_vnew = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.bar_load_vnew.data_ptr(), + num_stages=self.vloop_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_bytes_vnew, + ) + pipeline_load_g = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.bar_load_g.data_ptr(), + num_stages=self.kloop_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total), + tx_count=self.tma_bytes_g, + ) + pipeline_load_v = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.bar_load_v.data_ptr(), + num_stages=self.vloop_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) + num_cuda_warps_total), + tx_count=self.tma_bytes_v, + ) + pipeline_load_k = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.bar_load_k.data_ptr(), + num_stages=self.kloop_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total), + tx_count=self.tma_bytes_k, + ) + pipeline_load_q = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.bar_load_q.data_ptr(), + num_stages=self.kloop_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total), + tx_count=self.tma_bytes_q, + ) + pipeline_mma_dvb = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.bar_mma_dvb.data_ptr(), + num_stages=self.mma_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + ) + pipeline_mma_dq = pipeline.PipelineAsync.create( + barrier_storage=storage.bar_mma_dq.data_ptr(), + num_stages=self.mma_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + ) + pipeline_mma_dk = pipeline.PipelineAsync.create( + barrier_storage=storage.bar_mma_dk.data_ptr(), + num_stages=self.mma_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + ) + pipeline_mma_dw = pipeline.PipelineAsync.create( + barrier_storage=storage.bar_mma_dw.data_ptr(), + num_stages=self.mma_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + ) + pipeline_mma_dA = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.bar_mma_dA.data_ptr(), + num_stages=self.mma_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + ) + pipeline_mma_dA2 = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.bar_mma_dA2.data_ptr(), + num_stages=self.mma_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + ) + pipeline_mma_dA3 = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.bar_mma_dA3.data_ptr(), + num_stages=self.mma_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + ) + pipeline_prologue_dw = pipeline.PipelineAsync.create( + barrier_storage=storage.bar_prologue_dw.data_ptr(), + num_stages=self.kloop_stage, + producer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + ) + pipeline_prologue_kg = pipeline.PipelineAsync.create( + barrier_storage=storage.bar_prologue_kg.data_ptr(), + num_stages=self.kloop_stage, + producer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + ) + pipeline_prologue_dA2 = pipeline.PipelineAsync.create( + barrier_storage=storage.bar_prologue_dA2.data_ptr(), + num_stages=self.mma_stage, + producer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + ) + pipeline_prologue_dA3 = pipeline.PipelineAsync.create( + barrier_storage=storage.bar_prologue_dA3.data_ptr(), + num_stages=self.mma_stage, + producer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + ) + pipeline_mma_dkgb = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.bar_mma_dkgb.data_ptr(), + num_stages=self.mma_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + ) + pipeline_load_beta = pipeline.PipelineAsync.create( + barrier_storage=storage.bar_load_beta.data_ptr(), + num_stages=1, + producer_group=make_thread_cooperative_group(len(self.aux_warp_ids) * 32), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + ) + + # ===================== TMEM allocation ===================== + tmem_alloc_bar = pipeline.NamedBarrier(barrier_id=1, num_threads=self.threads_per_cta) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_bar, + allocator_warp_id=self.load_warp_id, + ) + # Cluster arrive after barrier init + pipeline.pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk, is_relaxed=True) + + vloop_opA_smem_no_stage = cute.select(vloop_opA_smem, mode=[0, 1, 2]) + vloop_opB_smem_no_stage = cute.select(vloop_opB_smem, mode=[0, 1, 2]) + A_mn_opA_smem_no_stage = cute.select(A_mn_opA_smem, mode=[0, 1, 2]) + dv_mn_opB_smem_no_stage = cute.select(dv_mn_opB_smem, mode=[0, 1, 2]) + v_opB_smem_no_stage = cute.select(v_opB_smem, mode=[0, 1, 2]) + + sA = storage.buf_A.get_tensor(A_mn_opA_smem.outer, swizzle=A_mn_opA_smem.inner) + sDv = storage.buf_dv.get_tensor(vloop_opA_smem.outer, swizzle=vloop_opA_smem.inner) + sH = storage.buf_h.get_tensor(vloop_opB_smem.outer, swizzle=vloop_opB_smem.inner) + sDh = storage.buf_dh.get_tensor(vloop_opB_smem.outer, swizzle=vloop_opB_smem.inner) + sDo = storage.buf_do.get_tensor(vloop_opA_smem.outer, swizzle=vloop_opA_smem.inner) + sVnew = storage.buf_vnew.get_tensor(vloop_opA_smem.outer, swizzle=vloop_opA_smem.inner) + sV = storage.buf_v.get_tensor(v_opB_smem.outer, swizzle=v_opB_smem.inner) + + sA_raw = cute.make_ptr( + self.io_dtype, storage.buf_A.data_ptr().toint(), cute.AddressSpace.smem, + ) + sDv_ptr_base = storage.buf_dv.data_ptr().toint() + vloop_opA_bytes_per_stage = cute.size_in_bytes(self.io_dtype, vloop_opA_smem_no_stage) + sDo_ptr_base = storage.buf_do.data_ptr().toint() + sVnew_ptr_base = storage.buf_vnew.data_ptr().toint() + sV_ptr_base = storage.buf_v.data_ptr().toint() + v_opB_bytes_per_stage = cute.size_in_bytes(self.io_dtype, v_opB_smem_no_stage) + sH_ptr_base = storage.buf_h.data_ptr().toint() + sDh_ptr_base = storage.buf_dh.data_ptr().toint() + vloop_opB_bytes_per_stage = cute.size_in_bytes(self.io_dtype, vloop_opB_smem_no_stage) + + # NOTE: make_umma_smem_desc requires the iterator to carry the swizzle + # (and ≥16B alignment). When constructing a tensor over a ComposedLayout + # via make_ptr+make_tensor, the swizzle ends up composed on the layout + # rather than the iterator, which breaks make_umma_smem_desc. Use + # recast_ptr to move the swizzle onto the iterator and pair it with the + # underlying (non-swizzle) outer layout. + sDv_mn = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr( + self.io_dtype, + storage.buf_dv.data_ptr().toint(), + cute.AddressSpace.smem, + assumed_align=128, + ), + swizzle_=dv_mn_opB_smem.inner, + dtype=self.io_dtype, + ), + dv_mn_opB_smem.outer, + ) + sDw_mn = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr( + self.io_dtype, + storage.buf_dw.data_ptr().toint(), + cute.AddressSpace.smem, + assumed_align=128, + ), + swizzle_=dw_mn_opB_smem.inner, + dtype=self.io_dtype, + ), + dw_mn_opB_smem.outer, + ) + sDw_k = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr( + self.io_dtype, + storage.buf_dw.data_ptr().toint(), + cute.AddressSpace.smem, + assumed_align=128, + ), + swizzle_=dw_k_opA_smem.inner, + dtype=self.io_dtype, + ), + dw_k_opA_smem.outer, + ) + sDv_k = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr( + self.io_dtype, + storage.buf_dv.data_ptr().toint(), + cute.AddressSpace.smem, + assumed_align=128, + ), + swizzle_=vloop_opA_smem.inner, + dtype=self.io_dtype, + ), + vloop_opA_smem.outer, + ) + sV_k = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr( + self.io_dtype, + storage.buf_v.data_ptr().toint(), + cute.AddressSpace.smem, + assumed_align=128, + ), + swizzle_=v_opB_smem.inner, + dtype=self.io_dtype, + ), + v_opB_smem.outer, + ) + sA_mn = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr( + self.io_dtype, + storage.buf_A.data_ptr().toint(), + cute.AddressSpace.smem, + assumed_align=128, + ), + swizzle_=A_mn_opA_smem.inner, + dtype=self.io_dtype, + ), + A_mn_opA_smem.outer, + ) + sDo_k = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr(self.io_dtype, storage.buf_do.data_ptr().toint(), cute.AddressSpace.smem, assumed_align=128), + swizzle_=vloop_opA_smem.inner, + dtype=self.io_dtype, + ), + vloop_opA_smem.outer, + ) + sVnew_k = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr(self.io_dtype, storage.buf_vnew.data_ptr().toint(), cute.AddressSpace.smem, assumed_align=128), + swizzle_=vloop_opA_smem.inner, + dtype=self.io_dtype, + ), + vloop_opA_smem.outer, + ) + sH_k = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr(self.io_dtype, storage.buf_h.data_ptr().toint(), cute.AddressSpace.smem, assumed_align=128), + swizzle_=vloop_opB_smem.inner, + dtype=self.io_dtype, + ), + vloop_opB_smem.outer, + ) + sDh_k = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr(self.io_dtype, storage.buf_dh.data_ptr().toint(), cute.AddressSpace.smem, assumed_align=128), + swizzle_=vloop_opB_smem.inner, + dtype=self.io_dtype, + ), + vloop_opB_smem.outer, + ) + sKG_k = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr(self.io_dtype, storage.buf_k.data_ptr().toint(), cute.AddressSpace.smem, assumed_align=128), + swizzle_=kg_k_opB_smem.inner, + dtype=self.io_dtype, + ), + kg_k_opB_smem.outer, + ) + sA_k = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr( + self.io_dtype, + storage.buf_A.data_ptr().toint(), + cute.AddressSpace.smem, + assumed_align=128, + ), + swizzle_=A_k_opB_smem.inner, + dtype=self.io_dtype, + ), + A_k_opB_smem.outer, + ) + sDA_mn = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr(self.io_dtype, storage.buf_q.data_ptr().toint(), cute.AddressSpace.smem, assumed_align=128), + swizzle_=dA_mn_opB_smem.inner, + dtype=self.io_dtype, + ), + dA_mn_opB_smem.outer, + ) + sDA_k = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr( + self.io_dtype, + storage.buf_q.data_ptr().toint(), + cute.AddressSpace.smem, + assumed_align=128, + ), + swizzle_=dA_k_opA_smem.inner, + dtype=self.io_dtype, + ), + dA_k_opA_smem.outer, + ) + sG_raw = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr( + self.g_dtype, + storage.buf_g.data_ptr().toint(), + cute.AddressSpace.smem, + assumed_align=128, + ), + swizzle_=g_epi_smem_layout.inner, + dtype=self.g_dtype, + ), + g_epi_smem_layout.outer, + ) + sG_raw_ptr = cute.make_ptr( + self.g_dtype, storage.buf_g.data_ptr().toint(), cute.AddressSpace.smem + ) + sV_raw_ptr = cute.make_ptr( + self.io_dtype, storage.buf_v.data_ptr().toint(), cute.AddressSpace.smem + ) + sK_raw = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr( + self.io_dtype, + storage.buf_k.data_ptr().toint(), + cute.AddressSpace.smem, + assumed_align=128, + ), + swizzle_=k_epi_smem_layout.inner, + dtype=self.io_dtype, + ), + k_epi_smem_layout.outer, + ) + sK_raw_ptr = cute.make_ptr( + self.io_dtype, storage.buf_k.data_ptr().toint(), cute.AddressSpace.smem + ) + sDw_raw_ptr = cute.make_ptr( + self.io_dtype, storage.buf_dw.data_ptr().toint(), cute.AddressSpace.smem + ) + sQ_raw = cute.make_tensor( + cute.recast_ptr( + cute.make_ptr( + self.io_dtype, + storage.buf_q.data_ptr().toint(), + cute.AddressSpace.smem, + assumed_align=128, + ), + swizzle_=q_epi_smem_layout.inner, + dtype=self.io_dtype, + ), + q_epi_smem_layout.outer, + ) + sQ_raw_ptr = cute.make_ptr( + self.io_dtype, storage.buf_q.data_ptr().toint(), cute.AddressSpace.smem + ) + + # Scalar SMEM buffers (plain layouts, no swizzle) + sBeta = cute.make_tensor( + cute.make_ptr(Float32, storage.s_beta.data_ptr().toint(), cute.AddressSpace.smem), + cute.make_layout((self.BT, ), stride=(1, )), + ) + sDb = cute.make_tensor( + cute.make_ptr(Float32, storage.s_db.data_ptr().toint(), cute.AddressSpace.smem), + cute.make_layout((self.BT, ), stride=(1, )), + ) + sDgk = cute.make_tensor( + cute.make_ptr(Float32, storage.s_dgk.data_ptr().toint(), cute.AddressSpace.smem), + cute.make_layout((self.BK, ), stride=(1, )), + ) + sGn = cute.make_tensor( + cute.make_ptr(Float32, storage.s_gn.data_ptr().toint(), cute.AddressSpace.smem), + cute.make_layout((self.BK, ), stride=(1, )), + ) + + # + # Cluster wait before tensor memory alloc + # + pipeline.pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk) + + tmem.allocate(TMEM_TOTAL) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + # ===================== Warp dispatch ===================== + # CUDA Core loop body + if warp_idx in self.cuda_warp_ids or warp_idx in self.cuda2_warp_ids: + cute.arch.setmaxregister_increase(self.num_regs_cuda) + + load_beta_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, 1 + ) + load_h_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.vloop_stage + ) + load_dh_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.vloop_stage + ) + load_g_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kloop_stage + ) + load_v_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.vloop_stage + ) + mma_dvb_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_stage + ) + mma_dq_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_stage + ) + mma_dw_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_stage + ) + mma_dk_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_stage + ) + load_k_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kloop_stage + ) + prologue_dw_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kloop_stage + ) + prologue_kg_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kloop_stage + ) + mma_dgkb_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_stage + ) + load_q_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kloop_stage + ) + mma_dA_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_stage + ) + mma_dA2_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_stage + ) + mma_dA3_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_stage + ) + prologue_dA2_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_stage + ) + prologue_dA3_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_stage + ) + + wg_idx = tidx // 128 + local_tidx = tidx % 128 + sub_wg_idx = tidx // 64 + warp_id = local_tidx // 32 + warp_row_tile = warp_id % 2 + warp_col_tile = warp_id // 2 + row = warp_row_tile * 32 + lane_idx # BT1 + bk_num_cols = self.BK // 2 + bv_num_cols = self.BV // 2 + bk_num_cols_per_wg = bk_num_cols // 2 + bv_num_cols_per_wg = bv_num_cols // 2 + bt_num_cols_per_wg = self.BT // 4 + # ref: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e + bv_col_base = warp_col_tile * (self.BV // 2) + wg_idx * bv_num_cols_per_wg + bk_col_base = warp_col_tile * (self.BK // 2) + wg_idx * bk_num_cols_per_wg + bt_col_base = warp_col_tile * (self.BT // 2) + wg_idx * bt_num_cols_per_wg + # 8 fp32 store each time for store_256b + num_stores_f32 = bk_num_cols_per_wg // 8 + + vloop_stage_idx = 0 + for wu_iter in cutlass.range(0, num_iters, unroll=0): + work_idx = block_idx_x + wu_iter * grid_dim_x + i_t = work_idx // H # chunk index (global) + head_idx = work_idx % H # head index + # Decode chunk_indices + batch_idx = chunk_indices[(i_t, 0)] + tile_idx = chunk_indices[(i_t, 1)] + tok_offset = cu_seqlens[(batch_idx,)] + seq_len = cu_seqlens[(batch_idx + 1,)] - tok_offset + sub_seq_len = min(self.BT, seq_len - tile_idx * self.BT) + + # fill db, dgk to 0 + if local_tidx < self.BT: + sDb[local_tidx] = Float32(0.0) + if local_tidx < self.BK: + sDgk[local_tidx] = Float32(0.0) + self.cuda_wg_sync_barrier.arrive_and_wait() + + pipeline_load_beta.consumer_wait(load_beta_consumer_state) + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + + beta_val = sBeta[(row,)] + db_val = Float32(0.0) + for v_iter in cutlass.range(self.num_v_tiles): + # dgk += sum(h * dh, axis=0) + pipeline_load_h.consumer_wait(load_h_consumer_state) + pipeline_load_dh.consumer_wait(load_dh_consumer_state) + + sH_raw_ptr = cute.make_ptr( + self.io_dtype, sH_ptr_base + vloop_stage_idx * vloop_opB_bytes_per_stage, cute.AddressSpace.smem + ) + sDh_raw_ptr = cute.make_ptr( + self.io_dtype, sDh_ptr_base + vloop_stage_idx * vloop_opB_bytes_per_stage, cute.AddressSpace.smem + ) + # each thread in one WG processes one row + self.cuda_wg_sync_barrier.arrive_and_wait() + if wg_idx == 0: + for i in cutlass.range_constexpr(self.BV // 8): + col = i * 8 + h_vals = smem_load_bf16x8_sw128(sH_raw_ptr, local_tidx, col) + dh_vals = smem_load_bf16x8_sw128(sDh_raw_ptr, local_tidx, col) + h_dh_vals = cute.make_rmem_tensor((8,), Float32) + h_dh_vals.store(h_vals.load().to(Float32) * dh_vals.load().to(Float32)) + for j in cutlass.range_constexpr(8): + sDgk[(local_tidx,)] += h_dh_vals[j] + + pipeline_load_dh.consumer_release(load_dh_consumer_state) + load_dh_consumer_state.advance() + pipeline_load_h.consumer_release(load_h_consumer_state) + load_h_consumer_state.advance() + + pipeline_mma_dvb.consumer_wait(mma_dvb_consumer_state) + tcgen05_fence_after() + dvb_i32 = tcgen05_ld_32x32b(bv_num_cols_per_wg, TMEM_FLEX_OFF + wg_idx * bv_num_cols_per_wg) + tcgen05_fence_before() + cute.arch.fence_view_async_tmem_load() + + pipeline_mma_dvb.consumer_release(mma_dvb_consumer_state) + mma_dvb_consumer_state.advance() + + dvb_f32 = reinterpret_cast(dvb_i32, Int32, bv_num_cols_per_wg, Float32) + dvb_f32_val = TensorSSA(dvb_f32, (bv_num_cols_per_wg,), Float32) + + # db += sum(dvb * v, axis=1) + pipeline_load_v.consumer_wait(load_v_consumer_state) + rV_bf16 = cute.make_rmem_tensor((bv_num_cols_per_wg,), self.io_dtype) + sV_raw_ptr_cur = cute.make_ptr( + self.io_dtype, sV_ptr_base + vloop_stage_idx * v_opB_bytes_per_stage, cute.AddressSpace.smem + ) + if row < sub_seq_len: + for i in cutlass.range_constexpr(bv_num_cols_per_wg // 8): + col_base = bv_col_base + i * 8 + vals = smem_load_bf16x8_sw128(sV_raw_ptr_cur, row, col_base) + rV_bf16[i * 8 + 0] = vals[0] + rV_bf16[i * 8 + 1] = vals[1] + rV_bf16[i * 8 + 2] = vals[2] + rV_bf16[i * 8 + 3] = vals[3] + rV_bf16[i * 8 + 4] = vals[4] + rV_bf16[i * 8 + 5] = vals[5] + rV_bf16[i * 8 + 6] = vals[6] + rV_bf16[i * 8 + 7] = vals[7] + else: + rV_bf16.fill(BFloat16(0.0)) + rV_fp32 = cute.make_rmem_tensor((bv_num_cols_per_wg,), Float32) + rV_fp32.store(rV_bf16.load().to(Float32)) + rV_fp32.store(rV_fp32.load() * dvb_f32_val) + if row < sub_seq_len: + for i in cutlass.range_constexpr(bv_num_cols_per_wg): + db_val += rV_fp32[i] + + pipeline_load_v.consumer_release(load_v_consumer_state) + load_v_consumer_state.advance() + + # ── dv2 epilogue: dv2 = dvb * beta, cast to bf16, store to gmem ── + dvb_f32_rmem = cute.make_rmem_tensor((bv_num_cols_per_wg,), Float32) + dvb_f32_rmem.store(dvb_f32_val * beta_val) + + dvb_bf16_rmem = cute.make_rmem_tensor((bv_num_cols_per_wg,), self.io_dtype) + dvb_bf16_rmem.store(dvb_f32_rmem.load().to(self.io_dtype)) + + # bf16 vector → i32 vector for store_256b (8 i32 = 16 bf16 = 32 bytes per store). + dvb_bf16_val = dvb_bf16_rmem.load() + dvb_i32_vec = reinterpret_cast( + dvb_bf16_val, self.io_dtype, bv_num_cols_per_wg, Int32 + ) + # bv_num_cols bf16 = bv_num_cols // 16 stores of 256b each. + num_stores_per_row = bv_num_cols_per_wg // 16 # = 4 for BV=128 + + base_addr = ( + dv2_gmem.iterator + + (tok_offset + tile_idx * self.BT + row) * H * V + + head_idx * V + + v_iter * self.BV + + bv_col_base + ).toint() + if row < sub_seq_len: + for s in cutlass.range_constexpr(num_stores_per_row): + chunk = subvec(dvb_i32_vec, s * 8, 8) + store_256b(base_addr + s * 32, chunk) + + vloop_stage_idx = (vloop_stage_idx + 1) % self.vloop_stage + + # gk_exp = exp2(g) + pipeline_load_g.consumer_wait(load_g_consumer_state) + # write to gn + sGn[local_tidx] = sG_raw[(sub_seq_len - 1, local_tidx, 0)] + self.cuda_wg_sync_barrier.arrive_and_wait() + + # row-major load, match TMEM layout + rG = cute.make_rmem_tensor((self.BK // 4, ), self.g_dtype) + if row < sub_seq_len: + for i in cutlass.range_constexpr(self.BK // 4 // 4): + col_base = bk_col_base + i * 4 + vals = smem_load_f32x4_sw128(sG_raw_ptr, row, col_base) + rG[i * 4 + 0] = vals[0] + rG[i * 4 + 1] = vals[1] + rG[i * 4 + 2] = vals[2] + rG[i * 4 + 3] = vals[3] + else: + rG.fill(Float32(0.0)) + rG_val = rG.load() + rG_exp_val = cute.exp2(rG_val, fastmath=self.use_fast_math) + + # wait for dq, dq=dq*gk_exp*scale, GMEM store + pipeline_mma_dq.consumer_wait(mma_dq_consumer_state) + tcgen05_fence_after() + dq_i32 = tcgen05_ld_32x32b(bk_num_cols_per_wg, TMEM_DQ_ACC_OFF + wg_idx * bk_num_cols_per_wg) + tcgen05_fence_before() + cute.arch.fence_view_async_tmem_load() + + pipeline_mma_dq.consumer_release(mma_dq_consumer_state) + mma_dq_consumer_state.advance() + + dq_f32 = reinterpret_cast(dq_i32, Int32, bk_num_cols_per_wg, Float32) + dq_f32_val = TensorSSA(dq_f32, (bk_num_cols_per_wg,), Float32) + + rDq = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) + rDq.store(dq_f32_val * rG_exp_val * Float32(self.scale)) + + # TODO: store to smem first to reduce register usage + dq_f32_val_store = rDq.load() + dq_i32_vec = reinterpret_cast( + dq_f32_val_store, Float32, bk_num_cols_per_wg, Int32 + ) + dq_base_addr = ( + dq_gmem.iterator + + (tok_offset + tile_idx * self.BT + row) * H * K + + head_idx * K + + bk_col_base + ).toint() + if row < sub_seq_len: + for s in cutlass.range_constexpr(num_stores_f32): + chunk = subvec(dq_i32_vec, s * 8, 8) + store_256b(dq_base_addr + s * 32, chunk) + + # wait for dw + pipeline_mma_dw.consumer_wait(mma_dw_consumer_state) + tcgen05_fence_after() + dw_i32 = tcgen05_ld_32x32b(bk_num_cols_per_wg, TMEM_DW_ACC_OFF + wg_idx * bk_num_cols_per_wg) + tcgen05_fence_before() + cute.arch.fence_view_async_tmem_load() + + pipeline_mma_dw.consumer_release(mma_dw_consumer_state) + mma_dw_consumer_state.advance() + + # dw = -dw, convert to bf16, write to smem + dw_f32 = reinterpret_cast(dw_i32, Int32, bk_num_cols_per_wg, Float32) + dw_f32_val = TensorSSA(dw_f32, (bk_num_cols_per_wg,), Float32) + + dw_bf16_rmem = cute.make_rmem_tensor((bk_num_cols_per_wg,), BFloat16) + if row < sub_seq_len: + dw_bf16_rmem.store((-dw_f32_val).to(BFloat16)) + else: + dw_bf16_rmem.fill(BFloat16(0.0)) + + pipeline_prologue_dw.producer_acquire(prologue_dw_producer_state) + # store bf16x8 each time + dw_smem_num_stores = bk_num_cols_per_wg // 8 + for i in cutlass.range_constexpr(dw_smem_num_stores): + col_base = bk_col_base + i * 8 + chunk = cute.local_tile(dw_bf16_rmem, (8,), (i,)) + smem_store_bf16x8_sw128(sDw_raw_ptr, row, col_base, chunk) + + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + pipeline_prologue_dw.producer_commit(prologue_dw_producer_state) + prologue_dw_producer_state.advance() + + pipeline_load_k.consumer_wait(load_k_consumer_state) + # compute kg = k * gk_exp + rK = cute.make_rmem_tensor((self.BK // 4, ), self.io_dtype) + if row < sub_seq_len: + for i in cutlass.range_constexpr(self.BK // 4 // 8): + col_base = bk_col_base + i * 8 + vals = smem_load_bf16x8_sw128(sK_raw_ptr, row, col_base) + rK[i * 8 + 0] = vals[0] + rK[i * 8 + 1] = vals[1] + rK[i * 8 + 2] = vals[2] + rK[i * 8 + 3] = vals[3] + rK[i * 8 + 4] = vals[4] + rK[i * 8 + 5] = vals[5] + rK[i * 8 + 6] = vals[6] + rK[i * 8 + 7] = vals[7] + else: + rK.fill(BFloat16(0.0)) + rK_fp32 = cute.make_rmem_tensor((self.BK // 4, ), Float32) + rK_fp32.store(rK.load().to(Float32)) + rK_fp32_val = rK_fp32.load() + rKG_val = rK_fp32_val * rG_exp_val + + # write kg to K smem, + # notify dA += dw @ kg^T + rKG_bf16 = cute.make_rmem_tensor((self.BK // 4, ), BFloat16) + rKG_bf16.store(rKG_val.to(BFloat16)) + + pipeline_prologue_kg.producer_acquire(prologue_kg_producer_state) + for i in cutlass.range_constexpr(self.BK // 4 // 8): + col_base = bk_col_base + i * 8 + chunk_kg = cute.local_tile(rKG_bf16, (8,), (i,)) + smem_store_bf16x8_sw128(sK_raw_ptr, row, col_base, chunk_kg) + + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + pipeline_prologue_kg.producer_commit(prologue_kg_producer_state) + prologue_kg_producer_state.advance() + + # wait for dkgb + pipeline_mma_dkgb.consumer_wait(mma_dgkb_consumer_state) + tcgen05_fence_after() + dkgb_i32 = tcgen05_ld_32x32b(bk_num_cols_per_wg, TMEM_DKGB_ACC_OFF + wg_idx * bk_num_cols_per_wg) + tcgen05_fence_before() + cute.arch.fence_view_async_tmem_load() + + pipeline_mma_dkgb.consumer_release(mma_dgkb_consumer_state) + mma_dgkb_consumer_state.advance() + + # db += sum(dkgb * kg, axis=1) + dkgb_f32 = reinterpret_cast(dkgb_i32, Int32, bk_num_cols_per_wg, Float32) + dkgb_f32_val = TensorSSA(dkgb_f32, (bk_num_cols_per_wg,), Float32) + rKgb_kg = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) + rKgb_kg.store(dkgb_f32_val * rKG_val) + + if row < sub_seq_len: + for i in cutlass.range_constexpr(bk_num_cols_per_wg): + db_val += rKgb_kg[i] + # atomic add for each row of db + if row < sub_seq_len: + sDb_row_ptr = cute.make_ptr(Float32, (sDb.iterator + row).toint(), cute.AddressSpace.smem, assumed_align=4) + atomicAdd(sDb_row_ptr, db_val) + self.cuda_wg_sync_barrier.arrive_and_wait() + # store db to GMEM + if local_tidx < sub_seq_len: + db_gmem[(tok_offset + tile_idx * self.BT + local_tidx, (head_idx, Int32(0)))] = sDb[(local_tidx,)] + + # dk = dk * exp2(gn[None, :] - g) + pipeline_mma_dk.consumer_wait(mma_dk_consumer_state) + tcgen05_fence_after() + dk_i32 = tcgen05_ld_32x32b(bk_num_cols_per_wg, TMEM_DK_ACC_OFF + wg_idx * bk_num_cols_per_wg) + tcgen05_fence_before() + cute.arch.fence_view_async_tmem_load() + + pipeline_mma_dk.consumer_release(mma_dk_consumer_state) + mma_dk_consumer_state.advance() + + dk_f32 = reinterpret_cast(dk_i32, Int32, bk_num_cols_per_wg, Float32) + dk_f32_val = TensorSSA(dk_f32, (bk_num_cols_per_wg,), Float32) + + rDk = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) + if row < sub_seq_len: + for i in cutlass.range_constexpr(bk_num_cols_per_wg): + exp_g_gn = cute.exp2(sGn[(bk_col_base + i,)] - rG_val[i], fastmath=self.use_fast_math) + rDk[i] = dk_f32_val[i] * exp_g_gn + else: + rDk.fill(Float32(0.0)) + + # dgk *= exp2(gn) + self.cuda_wg_sync_barrier.arrive_and_wait() + if wg_idx == 0: + sDgk[(local_tidx, )] *= cute.exp2(sGn[(local_tidx,)], fastmath=self.use_fast_math) + + # kdk = k * dk + rKdk = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) + rKdk.store(rK_fp32.load() * rDk.load()) + + # dgk += sum(kdk, axis=0) + # write kdk to G SMEM then do BT-dim reduce + for i in cutlass.range_constexpr(self.BK // 4 // 4): + col_base = bk_col_base + i * 4 + chunk_kdk = cute.local_tile(rKdk, (4,), (i,)) + smem_store_f32x4_sw128(sG_raw_ptr, row, col_base, chunk_kdk) + self.cuda_wg_sync_barrier.arrive_and_wait() + + if wg_idx == 0: + sum = Float32(0.0) + for r in cutlass.range_constexpr(self.BT): + sum += sG_raw[(r, local_tidx, 0)] + sDgk[(local_tidx, )] += sum + + pipeline_load_g.consumer_release(load_g_consumer_state) + load_g_consumer_state.advance() + + # gb = gk_exp * beta[:, None] + rGb = cute.make_rmem_tensor((bk_num_cols_per_wg, ), Float32) + rGb.store(rG_exp_val * beta_val) + + # dk = dk + dkgb * gb + rDk.store(rDk.load() + dkgb_f32_val * rGb.load()) + rDk_val = rDk.load() + dk_i32_vec = reinterpret_cast( + rDk_val, Float32, bk_num_cols_per_wg, Int32 + ) + # GMEM store dk + # 8 fp32 store each time for store_256b + dk_base_addr = ( + dk_gmem.iterator + + (tok_offset + tile_idx * self.BT + row) * H * K + + head_idx * K + + bk_col_base + ).toint() + if row < sub_seq_len: + for s in cutlass.range_constexpr(num_stores_f32): + chunk_dk = subvec(dk_i32_vec, s * 8, 8) + store_256b(dk_base_addr + s * 32, chunk_dk) + + # dg1 = kg * dkgb * beta[:, None], can reuse kg RMEM + rDg = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) + rDg.store(rKG_val * dkgb_f32_val * beta_val) + + pipeline_load_q.consumer_wait(load_q_consumer_state) + # dg2 = q * dq - kdk + dg1 + rQ = cute.make_rmem_tensor((bk_num_cols_per_wg,), self.io_dtype) + if row < sub_seq_len: + for i in cutlass.range_constexpr(self.BK // 4 // 8): + col_base = bk_col_base + i * 8 + vals = smem_load_bf16x8_sw128(sQ_raw_ptr, row, col_base) + rQ[i * 8 + 0] = vals[0] + rQ[i * 8 + 1] = vals[1] + rQ[i * 8 + 2] = vals[2] + rQ[i * 8 + 3] = vals[3] + rQ[i * 8 + 4] = vals[4] + rQ[i * 8 + 5] = vals[5] + rQ[i * 8 + 6] = vals[6] + rQ[i * 8 + 7] = vals[7] + else: + rQ.fill(BFloat16(0.0)) + rDg.store(rQ.load().to(Float32) * dq_f32_val_store + rDg.load() - rKdk.load()) + + self.cuda_wg_sync_barrier.arrive_and_wait() + # dg = dg2 + m_last * dgk, GMEM store dg + if row == sub_seq_len - 1: + for i in cutlass.range_constexpr(bk_num_cols_per_wg): + col = bk_col_base + i + rDg[i] += sDgk[(col,)] + # NOTE: must sync before next wu_iter's `sDgk[local_tidx] = 0` + # init, otherwise WG0 of next iter may overwrite sDgk while + # WG1 of this iter (row == sub_seq_len - 1 lane) is still + # reading sDgk[col] above. This was the source of the + # non-deterministic dg accuracy bug. + self.cuda_wg_sync_barrier.arrive_and_wait() + + rDg_val = rDg.load() + dg_i32_vec = reinterpret_cast( + rDg_val, Float32, bk_num_cols_per_wg, Int32 + ) + # FIXME: after dg store, severe register spill with worse perf! + dg_base_addr = ( + dg_gmem.iterator + + (tok_offset + tile_idx * self.BT + row) * H * K + + head_idx * K + + bk_col_base + ).toint() + if row < sub_seq_len: + for s in cutlass.range_constexpr(num_stores_f32): + chunk_dg = subvec(dg_i32_vec, s * 8, 8) + store_256b(dg_base_addr + s * 32, chunk_dg) + + pipeline_mma_dA.consumer_wait(mma_dA_consumer_state) + tcgen05_fence_after() + dA_i32 = tcgen05_ld_32x32b(bt_num_cols_per_wg, TMEM_DA_ACC_OFF + wg_idx * bt_num_cols_per_wg) + tcgen05_fence_before() + cute.arch.fence_view_async_tmem_load() + + pipeline_mma_dA.consumer_release(mma_dA_consumer_state) + mma_dA_consumer_state.advance() + # NOTE: only release k smem after dA finished, because kg reuses k smem in dA += dw @ kg^T + pipeline_load_k.consumer_release(load_k_consumer_state) + load_k_consumer_state.advance() + + # dA = dA * beta[None, :], apply strict lower-triangular mask. + # Triton reference multiplies by the column beta (`b_beta[None, :]`) + # and keeps only `row > col`. + dA_f32 = reinterpret_cast(dA_i32, Int32, bt_num_cols_per_wg, Float32) + dA_f32_val = TensorSSA(dA_f32, (bt_num_cols_per_wg,), Float32) + rDA = cute.make_rmem_tensor((bt_num_cols_per_wg, ), BFloat16) + for i in cutlass.range_constexpr(bt_num_cols_per_wg): + col = bt_col_base + i + beta_col = sBeta[(col,)] + dA_scaled = (dA_f32_val[i] * beta_col).to(BFloat16) + if col < row: + rDA[i] = dA_scaled + else: + rDA[i] = BFloat16(0.0) + if row >= sub_seq_len: + rDA.fill(BFloat16(0.0)) + + pipeline_prologue_dA2.producer_acquire(prologue_dA2_producer_state) + + for i in cutlass.range_constexpr(bt_num_cols_per_wg // 8): + col_base = bt_col_base + i * 8 + chunk_dA = cute.local_tile(rDA, (8,), (i,)) + smem_store_bf16x8_sw128(sQ_raw_ptr, row, col_base, chunk_dA) + # notify dA2 = dA @ A + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + pipeline_prologue_dA2.producer_commit(prologue_dA2_producer_state) + prologue_dA2_producer_state.advance() + + pipeline_load_beta.consumer_release(load_beta_consumer_state) + load_beta_consumer_state.advance() + + # wait for dA2 + pipeline_mma_dA2.consumer_wait(mma_dA2_consumer_state) + tcgen05_fence_after() + dA2_i32 = tcgen05_ld_32x32b(bt_num_cols_per_wg, TMEM_DA2_ACC_OFF + wg_idx * bt_num_cols_per_wg) + tcgen05_fence_before() + cute.arch.fence_view_async_tmem_load() + + pipeline_mma_dA2.consumer_release(mma_dA2_consumer_state) + mma_dA2_consumer_state.advance() + + pipeline_prologue_dA3.producer_acquire(prologue_dA3_producer_state) + # write dA2 to smem notify dA2 = A @ dA2 + dA2_f32 = reinterpret_cast(dA2_i32, Int32, bt_num_cols_per_wg, Float32) + dA2_f32_val = TensorSSA(dA2_f32, (bt_num_cols_per_wg,), Float32) + rDA2 = cute.make_rmem_tensor((bt_num_cols_per_wg, ), BFloat16) + if row < sub_seq_len: + rDA2.store(dA2_f32_val.to(BFloat16)) + else: + rDA2.fill(BFloat16(0.0)) + for i in cutlass.range_constexpr(bt_num_cols_per_wg // 8): + col_base = bt_col_base + i * 8 + chunk_dA2 = cute.local_tile(rDA2, (8,), (i,)) + smem_store_bf16x8_sw128(sQ_raw_ptr, row, col_base, chunk_dA2) + + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + pipeline_prologue_dA3.producer_commit(prologue_dA3_producer_state) + prologue_dA3_producer_state.advance() + + # wait for dA2 + pipeline_mma_dA3.consumer_wait(mma_dA3_consumer_state) + tcgen05_fence_after() + dA3_i32 = tcgen05_ld_32x32b(bt_num_cols_per_wg, TMEM_DA2_ACC_OFF + wg_idx * bt_num_cols_per_wg) + tcgen05_fence_before() + cute.arch.fence_view_async_tmem_load() + + pipeline_mma_dA3.consumer_release(mma_dA3_consumer_state) + mma_dA3_consumer_state.advance() + # NOTE: release smem Q because we reuse to store bf16 dA + pipeline_load_q.consumer_release(load_q_consumer_state) + load_q_consumer_state.advance() + + # dA = -dA, apply strict lower-triangular mask + dA3_f32 = reinterpret_cast(dA3_i32, Int32, bt_num_cols_per_wg, Float32) + dA3_f32_val = TensorSSA(dA3_f32, (bt_num_cols_per_wg,), Float32) + rDA3 = cute.make_rmem_tensor((bt_num_cols_per_wg, ), Float32) + rDA3.store(-dA3_f32_val) + for i in cutlass.range_constexpr(bt_num_cols_per_wg): + col = bt_col_base + i + if col >= row: + rDA3[i] = Float32(0.0) + rDA3_val = rDA3.load() + dA3_i32_vec = reinterpret_cast( + rDA3_val, Float32, bt_num_cols_per_wg, Int32 + ) + # GMEM store dA + num_stores_dA = bt_num_cols_per_wg // 8 + dA_base_addr = ( + dA_gmem.iterator + + (tok_offset + tile_idx * self.BT + row) * H * BT + + head_idx * BT + + bt_col_base + ).toint() + if row < sub_seq_len: + for s in cutlass.range_constexpr(num_stores_dA): + chunk_dA_store = subvec(dA3_i32_vec, s * 8, 8) + store_256b(dA_base_addr + s * 32, chunk_dA_store) + + # Load loop body + elif warp_idx == self.load_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_others) + + load_A_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.a_stage + ) + load_dv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.vloop_stage + ) + load_h_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.vloop_stage + ) + load_dh_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.vloop_stage + ) + load_do_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.vloop_stage + ) + load_vnew_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.vloop_stage + ) + load_g_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kloop_stage + ) + load_v_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.vloop_stage + ) + load_k_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kloop_stage + ) + load_q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kloop_stage + ) + + vloop_stage_idx = 0 + for wu_iter in cutlass.range(0, num_iters, unroll=0): + work_idx = block_idx_x + wu_iter * grid_dim_x + i_t = work_idx // H # chunk index (global) + head_idx = work_idx % H # head index + + # Decode chunk_indices + batch_idx = chunk_indices[(i_t, 0)] + tile_idx = chunk_indices[(i_t, 1)] + tok_offset = cu_seqlens[(batch_idx,)] + seq_len = cu_seqlens[(batch_idx + 1,)] - tok_offset + sub_seq_len = min(self.BT, seq_len - tile_idx * self.BT) + + # Load A, TODO: double-buffer? + tma_A_v = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_A) + tAsA, tAgA = self._tma_partition_A( + tma_atom_A, + tma_A_v, + sA, + self.dvb_tiler, # [BT, BV, BT] + dvb_tiled_mma, + Int32(0), + head_idx, + ) + pipeline_load_A.producer_acquire(load_A_producer_state) + cute.copy( + tma_atom_A, + tAgA[(None, 0, tile_idx)], + tAsA[(None, load_A_producer_state.index)], + tma_bar_ptr=pipeline_load_A.producer_get_barrier(load_A_producer_state), + ) + load_A_producer_state.advance() + + # V-loop + for v_iter in cutlass.range(self.num_v_tiles): + tma_h_v = cute.domain_offset((0, v_iter * self.BV, (0, 0)), tma_tensor_h) + tHsH, tHgH = self._tma_partition_B( + tma_atom_h, + tma_h_v, + sH, + self.vloop_gemm_tiler, # [BT, BK, BV] + vloop_tiled_mma, + head_idx, i_t + ) + pipeline_load_h.producer_acquire(load_h_producer_state) + cute.copy( + tma_atom_h, + tHgH[(None, 0, 0)], + tHsH[(None, vloop_stage_idx)], + tma_bar_ptr=pipeline_load_h.producer_get_barrier(load_h_producer_state), + ) + load_h_producer_state.advance() + + tma_dh_v = cute.domain_offset((0, v_iter * self.BV, (0, 0)), tma_tensor_dh) + tDHsDH, tDHgDH = self._tma_partition_B( + tma_atom_dh, + tma_dh_v, + sDh, + self.vloop_gemm_tiler, # [BT, BK, BV] + vloop_tiled_mma, + head_idx, i_t + ) + pipeline_load_dh.producer_acquire(load_dh_producer_state) + cute.copy( + tma_atom_dh, + tDHgDH[(None, 0, 0)], + tDHsDH[(None, vloop_stage_idx)], + tma_bar_ptr=pipeline_load_dh.producer_get_barrier(load_dh_producer_state), + ) + load_dh_producer_state.advance() + + tma_do_v = cute.domain_offset((tok_offset, v_iter * self.BV, (0, 0)), tma_tensor_do) + tDOsDo, tDOgDo = self._tma_partition_A( + tma_atom_do, + tma_do_v, + sDo, + self.vloop_gemm_tiler, # [BT, BK, BV] + vloop_tiled_mma, + Int32(0), head_idx, + ) + pipeline_load_do.producer_acquire(load_do_producer_state) + cute.copy( + tma_atom_do, + tDOgDo[(None, tile_idx, 0)], + tDOsDo[(None, vloop_stage_idx)], + tma_bar_ptr=pipeline_load_do.producer_get_barrier(load_do_producer_state), + ) + load_do_producer_state.advance() + + tma_dv_v = cute.domain_offset((tok_offset, v_iter * self.BV, (0, 0)), tma_tensor_dv) + tDVsDv, tDVgDV = self._tma_partition_A( + tma_atom_dv, + tma_dv_v, + sDv, + self.vloop_gemm_tiler, # [BT, BK, BV] + vloop_tiled_mma, + Int32(0), head_idx, + ) + pipeline_load_dv.producer_acquire(load_dv_producer_state) + cute.copy( + tma_atom_dv, + tDVgDV[(None, tile_idx, 0)], + tDVsDv[(None, vloop_stage_idx)], + tma_bar_ptr=pipeline_load_dv.producer_get_barrier(load_dv_producer_state), + ) + load_dv_producer_state.advance() + + tma_v_v = cute.domain_offset((tok_offset, v_iter * self.BV, (0, 0)), tma_tensor_v) + tVsV, tVgV = self._tma_partition_B( + tma_atom_v, + tma_v_v, + sV, + self.dA_vloop_tiler, # [BT, BT, BV] + dA_vloop_tiled_mma, + Int32(0), head_idx, + ) + pipeline_load_v.producer_acquire(load_v_producer_state) + cute.copy( + tma_atom_v, + tVgV[(None, tile_idx, 0)], + tVsV[(None, vloop_stage_idx)], + tma_bar_ptr=pipeline_load_v.producer_get_barrier(load_v_producer_state), + ) + load_v_producer_state.advance() + + # load v_new + tma_vnew_v = cute.domain_offset((tok_offset, v_iter * self.BV, (0, 0)), tma_tensor_vnew) + tVnewsVnew, tVnewgVnew = self._tma_partition_A( + tma_atom_vnew, + tma_vnew_v, + sVnew, + self.vloop_gemm_tiler, # [BT, BK, BV] + vloop_tiled_mma, + Int32(0), head_idx, + ) + pipeline_load_vnew.producer_acquire(load_vnew_producer_state) + cute.copy( + tma_atom_vnew, + tVnewgVnew[(None, tile_idx, 0)], + tVnewsVnew[(None, vloop_stage_idx)], + tma_bar_ptr=pipeline_load_vnew.producer_get_barrier(load_vnew_producer_state), + ) + load_vnew_producer_state.advance() + + vloop_stage_idx = (vloop_stage_idx + 1) % self.vloop_stage + + # Load g + tma_g_v = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_g) + tGsG, tGgG = self._epilog_partition_varlen( + tma_atom_g, + tma_g_v[None, None, (head_idx, Int32(0))], + (self.BT, self.BK), + sG_raw, + ) + pipeline_load_g.producer_acquire(load_g_producer_state) + cute.copy( + tma_atom_g, + tGgG[(None, tile_idx, 0)], + tGsG[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 + tma_bar_ptr=pipeline_load_g.producer_get_barrier(load_g_producer_state), + ) + load_g_producer_state.advance() + + # Load k + tma_k_v = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_k) + tKsK, tKgK = self._epilog_partition_varlen( + tma_atom_k, + tma_k_v[None, None, (head_idx, Int32(0))], + (self.BT, self.BK), + sK_raw, + ) + pipeline_load_k.producer_acquire(load_k_producer_state) + cute.copy( + tma_atom_k, + tKgK[(None, tile_idx, 0)], + tKsK[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 + tma_bar_ptr=pipeline_load_k.producer_get_barrier(load_k_producer_state), + ) + load_k_producer_state.advance() + + tma_q_v = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_q) + tQsQ, tQgQ = self._epilog_partition_varlen( + tma_atom_q, + tma_q_v[None, None, (head_idx, Int32(0))], + (self.BT, self.BK), + sQ_raw, + ) + pipeline_load_q.producer_acquire(load_q_producer_state) + cute.copy( + tma_atom_q, + tQgQ[(None, tile_idx, 0)], + tQsQ[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 + tma_bar_ptr=pipeline_load_q.producer_get_barrier(load_q_producer_state), + ) + load_q_producer_state.advance() + + + # MMA loop body + elif warp_idx == self.mma_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_others) + + load_A_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.a_stage + ) + load_dv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.vloop_stage + ) + mma_dvb_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_stage + ) + load_h_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.vloop_stage + ) + load_dh_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.vloop_stage + ) + load_do_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.vloop_stage + ) + load_vnew_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.vloop_stage + ) + mma_dq_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_stage + ) + mma_dk_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_stage + ) + load_v_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.vloop_stage + ) + mma_dw_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_stage + ) + prologue_dw_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kloop_stage + ) + prologue_kg_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kloop_stage + ) + mma_dgkb_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_stage + ) + mma_dA_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_stage + ) + mma_dA2_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_stage + ) + mma_dA3_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_stage + ) + prologue_dA2_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_stage + ) + prologue_dA3_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_stage + ) + + vloop_stage_idx = 0 + a_stage_idx = 0 + mma_vloop_phase = 0 + for wu_iter in cutlass.range(0, num_iters, unroll=0): + work_idx = block_idx_x + wu_iter * grid_dim_x + i_t = work_idx // H # chunk index (global) + head_idx = work_idx % H # head index + + # Decode chunk_indices + batch_idx = chunk_indices[(i_t, 0)] + tile_idx = chunk_indices[(i_t, 1)] + tok_offset = cu_seqlens[(batch_idx,)] + seq_len = cu_seqlens[(batch_idx + 1,)] - tok_offset + sub_seq_len = min(self.BT, seq_len - tile_idx * self.BT) + + zeros8 = cute.make_rmem_tensor((8, ), dtype=self.io_dtype) + zeros8.fill(BFloat16(0.0)) + + pipeline_load_A.consumer_wait(load_A_consumer_state) + if sub_seq_len < self.BT: + for i in cutlass.range_constexpr(self.BT // 32): + row = i * 32 + lane_idx + if row >= sub_seq_len: + for col in cutlass.range_constexpr(self.BT // 8): + # A tile is MN_SW128 in shared memory; use raw swizzled + # address stores to avoid layout-coordinate ambiguity. + smem_store_bf16x8_sw128(sA_raw, row, col * 8, zeros8) + # Make generic-proxy SMEM stores visible to UMMA async-proxy readers. + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.mma_warp_sync_barrier.arrive_and_wait() + + for v_iter in cutlass.range(self.num_v_tiles): + is_accum = False if v_iter == 0 else True + pipeline_load_h.consumer_wait(load_h_consumer_state) + pipeline_load_do.consumer_wait(load_do_consumer_state) + sDo_raw_ptr = cute.make_ptr( + self.io_dtype, + sDo_ptr_base + vloop_stage_idx * vloop_opA_bytes_per_stage, + cute.AddressSpace.smem, + ) + if sub_seq_len < self.BT: + for i in cutlass.range_constexpr(self.BT // 32): + row = i * 32 + lane_idx + if row >= sub_seq_len: + for col in cutlass.range_constexpr(self.BV // 8): + # dv tile uses the same Swizzle<3,4,3> physical mapping. + smem_store_bf16x8_sw128(sDo_raw_ptr, row, col * 8, zeros8) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.mma_warp_sync_barrier.arrive_and_wait() + + if v_iter == 0: + pipeline_mma_dq.producer_acquire(mma_dq_producer_state) + + # dq+=do@h + sDo_k_cur = sDo_k[(None, None, None, vloop_stage_idx)] + sH_k_cur = sH_k[(None, None, None, vloop_stage_idx)] + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDo_k_cur.iterator, sDo_k_cur.layout, "k")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sH_k_cur.iterator, sH_k_cur.layout, "k")) + desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + mma_ws_ss_m64n128_k_k_call(vloop_opA_smem, desc_a_base, vloop_opB_smem, desc_b_base, TMEM_DQ_ACC_OFF, self.BV, is_accum) + + # TODO: should we add tcgen05.commit and mbar.wait to ensure current dq MMA has been finished? + + pipeline_load_do.consumer_release(load_do_consumer_state) + load_do_consumer_state.advance() + + if v_iter == self.num_v_tiles - 1: + pipeline_mma_dq.producer_commit(mma_dq_producer_state) + mma_dq_producer_state.advance() + + pipeline_load_dv.consumer_wait(load_dv_consumer_state) + sDv_raw = cute.make_ptr( + self.io_dtype, + sDv_ptr_base + vloop_stage_idx * vloop_opA_bytes_per_stage, + cute.AddressSpace.smem, + ) + if sub_seq_len < self.BT: + for i in cutlass.range_constexpr(self.BT // 32): + row = i * 32 + lane_idx + if row >= sub_seq_len: + for col in cutlass.range_constexpr(self.BV // 8): + # dv tile uses the same Swizzle<3,4,3> physical mapping. + smem_store_bf16x8_sw128(sDv_raw, row, col * 8, zeros8) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.mma_warp_sync_barrier.arrive_and_wait() + + # if lane_idx == 0: + # cute.printf("V_iter", v_iter) + # cute.print_tensor(sDv[None, None, None, vloop_stage_idx]) + pipeline_mma_dvb.producer_acquire(mma_dvb_producer_state) + sDv_mn_cur = sDv_mn[(None, None, None, vloop_stage_idx)] + sA_mn_cur = sA_mn[(None, None, None, a_stage_idx)] + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sA_mn_cur.iterator, sA_mn_cur.layout, "mn")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDv_mn_cur.iterator, sDv_mn_cur.layout, "mn")) + desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + mma_ws_ss_m64n64_mn_mn_call(A_mn_opA_smem, desc_a_base, dv_mn_opB_smem, desc_b_base, TMEM_FLEX_OFF, self.BT) + + pipeline_mma_dvb.producer_commit(mma_dvb_producer_state) + mma_dvb_producer_state.advance() + + # dw += dv @ h + if v_iter == 0: + pipeline_mma_dw.producer_acquire(mma_dw_producer_state) + + sDv_k_cur = sDv_k[(None, None, None, vloop_stage_idx)] + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDv_k_cur.iterator, sDv_k_cur.layout, "k")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sH_k_cur.iterator, sH_k_cur.layout, "k")) + desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + mma_ws_ss_m64n128_k_k_call(vloop_opA_smem, desc_a_base, vloop_opB_smem, desc_b_base, TMEM_DW_ACC_OFF, self.BV, is_accum) + + # dA += dv @ v^T + pipeline_load_v.consumer_wait(load_v_consumer_state) + sV_raw = cute.make_ptr( + self.io_dtype, sV_ptr_base + vloop_stage_idx * v_opB_bytes_per_stage, cute.AddressSpace.smem + ) + if sub_seq_len < self.BT: + for i in cutlass.range_constexpr(self.BT // 32): + row = i * 32 + lane_idx + if row >= sub_seq_len: + for col in cutlass.range_constexpr(self.BV // 8): + # dv tile uses the same Swizzle<3,4,3> physical mapping. + smem_store_bf16x8_sw128(sV_raw, row, col * 8, zeros8) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.mma_warp_sync_barrier.arrive_and_wait() + if v_iter == 0: + pipeline_mma_dA.producer_acquire(mma_dA_producer_state) + + sV_k_cur = sV_k[(None, None, None, vloop_stage_idx)] + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDv_k_cur.iterator, sDv_k_cur.layout, "k")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sV_k_cur.iterator, sV_k_cur.layout, "k")) + desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + mma_ws_ss_m64n64_k_k_call(vloop_opA_smem, desc_a_base, v_opB_smem, desc_b_base, TMEM_DA_ACC_OFF, self.BV, is_accum) + + # dv pipeline calls tcgen05.commit for dv@h and dv@v^T + pipeline_load_dv.consumer_release(load_dv_consumer_state) + load_dv_consumer_state.advance() + + if v_iter == self.num_v_tiles - 1: + pipeline_mma_dw.producer_commit(mma_dw_producer_state) + mma_dw_producer_state.advance() + + pipeline_load_h.consumer_release(load_h_consumer_state) + load_h_consumer_state.advance() + + pipeline_load_v.consumer_release(load_v_consumer_state) + load_v_consumer_state.advance() + + # dk += v_new @ dh + pipeline_load_vnew.consumer_wait(load_vnew_consumer_state) + sDvnew_raw_ptr = cute.make_ptr( + self.io_dtype, + sVnew_ptr_base + vloop_stage_idx * vloop_opA_bytes_per_stage, + cute.AddressSpace.smem, + ) + if sub_seq_len < self.BT: + for i in cutlass.range_constexpr(self.BT // 32): + row = i * 32 + lane_idx + if row >= sub_seq_len: + for col in cutlass.range_constexpr(self.BV // 8): + # dv tile uses the same Swizzle<3,4,3> physical mapping. + smem_store_bf16x8_sw128(sDvnew_raw_ptr, row, col * 8, zeros8) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.mma_warp_sync_barrier.arrive_and_wait() + + pipeline_load_dh.consumer_wait(load_dh_consumer_state) + if v_iter == 0: + pipeline_mma_dk.producer_acquire(mma_dk_producer_state) + + sVnew_k_cur = sVnew_k[(None, None, None, vloop_stage_idx)] + sDh_k_cur = sDh_k[(None, None, None, vloop_stage_idx)] + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sVnew_k_cur.iterator, sVnew_k_cur.layout, "k")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDh_k_cur.iterator, sDh_k_cur.layout, "k")) + desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + mma_ws_ss_m64n128_k_k_call(vloop_opA_smem, desc_a_base, vloop_opB_smem, desc_b_base, TMEM_DK_ACC_OFF, self.BV, is_accum) + + # vnew pipeline calls tcgen05.commit + pipeline_load_vnew.consumer_release(load_vnew_consumer_state) + load_vnew_consumer_state.advance() + + if v_iter == self.num_v_tiles - 1: + pipeline_mma_dk.producer_commit(mma_dk_producer_state) + mma_dk_producer_state.advance() + + pipeline_load_dh.consumer_release(load_dh_consumer_state) + load_dh_consumer_state.advance() + + # add tcgen05.commit and mbar.wait to make sure dq/dk/dw MMA finished + umma_arrive(bar_mma_done_vloop_ptr + 0) + mbarrier_wait(bar_mma_done_vloop_ptr + 0, mma_vloop_phase) + mma_vloop_phase ^= 1 + + vloop_stage_idx = (vloop_stage_idx + 1) % self.vloop_stage + + pipeline_prologue_dw.consumer_wait(prologue_dw_consumer_state) + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + # dkgb = A @ dw + pipeline_mma_dkgb.producer_acquire(mma_dgkb_producer_state) + sA_mn_cur = sA_mn[(None, None, None, a_stage_idx)] + sDw_mn_cur = sDw_mn[(None, None, None, 0)] + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sA_mn_cur.iterator, sA_mn_cur.layout, "mn")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDw_mn_cur.iterator, sDw_mn_cur.layout, "mn")) + desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + mma_ws_ss_m64n128_mn_mn_call(A_mn_opA_smem, desc_a_base, dw_mn_opB_smem, desc_b_base, TMEM_DKGB_ACC_OFF, self.BT) + + pipeline_mma_dkgb.producer_commit(mma_dgkb_producer_state) + mma_dgkb_producer_state.advance() + + pipeline_prologue_kg.consumer_wait(prologue_kg_consumer_state) + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + # dA += dw @ kg^T + sDw_k_cur = sDw_k[(None, None, None, 0)] + sKG_k_cur = sKG_k[(None, None, None, 0)] + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDw_k_cur.iterator, sDw_k_cur.layout, "k")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sKG_k_cur.iterator, sKG_k_cur.layout, "k")) + desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + mma_ws_ss_m64n64_k_k_call(dw_k_opA_smem, desc_a_base, kg_k_opB_smem, desc_b_base, TMEM_DA_ACC_OFF, self.BK, True) + + pipeline_mma_dA.producer_commit(mma_dA_producer_state) + mma_dA_producer_state.advance() + pipeline_prologue_kg.consumer_release(prologue_kg_consumer_state) + prologue_kg_consumer_state.advance() + + pipeline_prologue_dw.consumer_release(prologue_dw_consumer_state) + prologue_dw_consumer_state.advance() + + # dA2 = dA @ A + pipeline_mma_dA2.producer_acquire(mma_dA2_producer_state) + pipeline_prologue_dA2.consumer_wait(prologue_dA2_consumer_state) + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + + sDA_k_cur = sDA_k[(None, None, None, 0)] + sA_k_cur = sA_k[(None, None, None, 0)] + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDA_k_cur.iterator, sDA_k_cur.layout, "k")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sA_k_cur.iterator, sA_k_cur.layout, "k")) + desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + mma_ws_ss_m64n64_k_k_call(dA_k_opA_smem, desc_a_base, A_k_opB_smem, desc_b_base, TMEM_DA2_ACC_OFF, self.BT) + + pipeline_mma_dA2.producer_commit(mma_dA2_producer_state) + mma_dA2_producer_state.advance() + pipeline_prologue_dA2.consumer_release(prologue_dA2_consumer_state) + prologue_dA2_consumer_state.advance() + + # dA3 = A @ dA2 + pipeline_mma_dA3.producer_acquire(mma_dA3_producer_state) + pipeline_prologue_dA3.consumer_wait(prologue_dA3_consumer_state) + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + + sA_mn_cur = sA_mn[(None, None, None, 0)] + sDA_mn_cur = sDA_mn[(None, None, None, 0)] + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sA_mn_cur.iterator, sA_mn_cur.layout, "mn")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDA_mn_cur.iterator, sDA_mn_cur.layout, "mn")) + desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + mma_ws_ss_m64n64_mn_mn_call(A_mn_opA_smem, desc_a_base, dA_mn_opB_smem, desc_b_base, TMEM_DA2_ACC_OFF, self.BT) + + pipeline_mma_dA3.producer_commit(mma_dA3_producer_state) + mma_dA3_producer_state.advance() + pipeline_prologue_dA3.consumer_release(prologue_dA3_consumer_state) + prologue_dA3_consumer_state.advance() + + pipeline_load_A.consumer_release(load_A_consumer_state) + load_A_consumer_state.advance() + + a_stage_idx = (a_stage_idx + 1) % self.a_stage + + # Load aux loop body + elif warp_idx in self.aux_warp_ids: + cute.arch.setmaxregister_decrease(self.num_regs_others) + tidx = thread_idx - (self.threads_per_cta - 64) + + load_beta_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, 1 + ) + + for wu_iter in cutlass.range(0, num_iters, unroll=0): + work_idx = block_idx_x + wu_iter * grid_dim_x + i_t = work_idx // H # chunk index (global) + head_idx = work_idx % H # head index + + # Decode chunk_indices + batch_idx = chunk_indices[(i_t, 0)] + tile_idx = chunk_indices[(i_t, 1)] + tok_offset = cu_seqlens[(batch_idx,)] + seq_len = cu_seqlens[(batch_idx + 1,)] - tok_offset + sub_seq_len = min(self.BT, seq_len - tile_idx * self.BT) + + pipeline_load_beta.producer_acquire(load_beta_producer_state) + beta_f32 = Float32(0.0) + if tidx < sub_seq_len: + beta_f32 = Float32(beta_gmem[(tok_offset + tile_idx * self.BT + tidx, (head_idx, Int32(0)))]) + sBeta[(tidx, )] = beta_f32 + + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + pipeline_load_beta.producer_commit(load_beta_producer_state) + load_beta_producer_state.advance() + + # ===================== TMEM cleanup ===================== + tmem.relinquish_alloc_permit() + self.tmem_dealloc_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr, TMEM_TOTAL) + + @cute.jit + def _tma_partition_A(self, tma_atom, tma_tensor, smem, tile_shape, tiled_mma, batch_idx, hidx): + """Partition a TMA tensor as MMA A-operand (M,K dims). + + ``tma_tensor`` should already have domain_offset applied for varlen. + + For tile_shape = (BT, BK, BV) = (M, N, K): + coord = (None, 0, None) — slices out the N-tile axis (mode 1) at 0, + leaving mode 0 (M=BT) and mode 2 (K=BV) free for TMA to iterate. + + Returns (tXsX, tXgX) — SMEM partition and GMEM coordinate partition. + """ + coord = (None, 0, None) + gX = cute.local_tile(tma_tensor, cute.slice_(tile_shape, coord), (None, None, (hidx, batch_idx))) + thr_mma = tiled_mma.get_slice(0) + tCgX = thr_mma.partition_A(gX) + tXsX, tXgX = cpasync.tma_partition( + tma_atom, + 0, + cute.make_layout(1), + cute.group_modes(smem, 0, 3), + cute.group_modes(tCgX, 0, 3), + ) + return tXsX, tXgX + + @cute.jit + def _tma_partition_B(self, tma_atom, tma_tensor, smem, tile_shape, tiled_mma, batch_idx, hidx): + """Partition a TMA tensor as MMA B-operand (N,K dims). + + Mirrors the identical helper in recompute_wu.py / fwd_o.py. + ``tma_tensor`` should already have domain_offset applied for varlen. + + For tile_shape = (BT, BK, BV) = (M, N, K): + coord = (0, None, None) — slices out the M-tile axis (mode 0) at 0, + leaving mode 1 (N=BK) and mode 2 (K=BV) free for TMA to iterate. + + Returns (tXsX, tXgX) — SMEM partition and GMEM coordinate partition. + """ + coord = (0, None, None) + gX = cute.local_tile(tma_tensor, cute.slice_(tile_shape, coord), (None, None, (hidx, batch_idx))) + thr_mma = tiled_mma.get_slice(0) + tCgX = thr_mma.partition_B(gX) + tXsX, tXgX = cpasync.tma_partition( + tma_atom, + 0, + cute.make_layout(1), + cute.group_modes(smem, 0, 3), + cute.group_modes(tCgX, 0, 3), + ) + return tXsX, tXgX + + @cute.jit + def _epilog_partition_varlen(self, atom, gC_2d, epi_tile, sC): + """Partition for varlen epilog TMA load (2D tensor with domain_offset). + + Uses local_tile instead of flat_divide to correctly preserve TMA basis + stride coordinates through domain_offset. Matches Flash Attention's + pattern: slice mode2 → domain_offset(2D) → local_tile → tma_partition. + + Uses (None, None) to keep all tile-count modes, producing the same + rank as _epilog_partition (flat_divide) so copy indexing is unchanged. + """ + gC_tiled = cute.local_tile(gC_2d, epi_tile, (None, None)) + sC_g = cute.group_modes(sC, 0, 2) + gC_g = cute.group_modes(gC_tiled, 0, 2) + bSG_sC, bSG_gC = cpasync.tma_partition( + atom, + 0, + cute.make_layout(1), + sC_g, + gC_g, + ) + return bSG_sC, bSG_gC + + +# ===================================================================== +# Compilation & Cache +# ===================================================================== + +_bwd_wy_kernel_cache: dict = {} + + +def _compile_bwd_wy_variant(H, K, V, scale, chunk_size, beta_dtype, use_fast_math): + """Compile one ChunkKdaBwdWyDqkgFused kernel variant. + + Uses make_fake_compact_tensor and make_fake_stream for compilation with + TVM-FFI. At runtime, torch tensors are passed directly (zero-copy). + Uses sym_int() for dynamic B, T, NT dimensions. + """ + kernel_obj = ChunkKdaBwdWyDqkgFused( + chunk_size=chunk_size, + head_dim_k=K, + head_dim_v=V, + scale=scale, + beta_dtype=beta_dtype, + use_fast_math=use_fast_math, + ) + + sym_b = cute.sym_int() # T (non-varlen) or T_total (varlen) + sym_nt = cute.sym_int() # NT_total + sym_cu = cute.sym_int() # cu_seqlens size + sym_ci = cute.sym_int() # chunk_indices rows + + BT = chunk_size + + # only support varlen for real-world use cases + # varlen: data tensors are [1, T_total, H, ...] + q_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) + k_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) + v_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, V), stride_order=(3, 2, 1, 0), assumed_align=128) + vnew_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, V), stride_order=(3, 2, 1, 0), assumed_align=128) + g_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) + beta_fake = make_fake_compact_tensor(beta_dtype, (1, sym_b, H), stride_order=(2, 1, 0), assumed_align=128) + A_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, BT), stride_order=(3, 2, 1, 0), assumed_align=128) + do_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, V), stride_order=(3, 2, 1, 0), assumed_align=128) + dv_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, V), stride_order=(3, 2, 1, 0), assumed_align=128) + + dq_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) + dk_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) + dv2_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, V), stride_order=(3, 2, 1, 0), assumed_align=128) + dg_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) + db_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H), stride_order=(2, 1, 0), assumed_align=128) + dA_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H, BT), stride_order=(3, 2, 1, 0), assumed_align=128) + + h_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_nt, H, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128) + dh_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_nt, H, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128) + + cu_fake = make_fake_compact_tensor(cutlass.Int32, (sym_cu,), assumed_align=128) + ci_fake = make_fake_compact_tensor(cutlass.Int32, (sym_ci, 2), stride_order=(1, 0), assumed_align=128) + stream_fake = make_fake_stream(use_tvm_ffi_env_stream=True) + + compiled_fn = cute.compile( + kernel_obj, + # Inputs + q_fake, + k_fake, + v_fake, + vnew_fake, + g_fake, + beta_fake, + A_fake, + h_fake, + do_fake, + dh_fake, + dv_fake, + # Outputs + dq_fake, + dk_fake, + dv2_fake, + dg_fake, + db_fake, + dA_fake, + # Metadata + cu_fake, + ci_fake, + (Int32(1), Int32(1), Int32(H), Int32(K), Int32(V)), + Int32(1), # total_nt dummy + stream_fake, + options=COMPILE_OPTIONS, + ) + return compiled_fn + + +def _get_compiled_bwd_wy(H, K, V, scale, chunk_size, beta_dtype): + """Get a compiled ChunkKdaBwdWyDqkgFused kernel with on-demand (lazy) compilation. + + Cache key: (H, K, V, scale, chunk_size, beta_dtype, USE_FAST_MATH) + """ + key = (H, K, V, scale, chunk_size, beta_dtype, USE_FAST_MATH) + if key not in _bwd_wy_kernel_cache: + _bwd_wy_kernel_cache[key] = _compile_bwd_wy_variant( + H, + K, + V, + scale, + chunk_size, + _torch_to_cutlass_dtype[beta_dtype], + USE_FAST_MATH, + ) + return _bwd_wy_kernel_cache[key] + + +# ===================================================================== +# Python API (FLA-compatible) +# ===================================================================== + +_bwd_wy_dummy_cu_seqlens = None +_bwd_wy_dummy_chunk_indices = None + + +def chunk_kda_bwd_wy_dqkg_fused( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + h: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + dv: torch.Tensor, + scale: float | None = None, + cu_seqlens: torch.Tensor | None = None, + chunk_size: int = 64, + chunk_indices: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + ChunkKdaBwdWyDqkgFused — FLA-compatible Python API. + + Computes backward gradients dq, dk, dv2, db, dg, dA for the KDA + chunkwise delta-rule backward pass using the CuTe DSL Blackwell kernel. + + Returns: + (dq, dk, dv2, db, dg, dA) matching FLA's chunk_kda_bwd_wy_dqkg_fused output order. + """ + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + + B, T, H, K = q.shape + V = v.shape[3] + BT = chunk_size + beta_dtype = beta.dtype + + if scale is None: + scale = K**-0.5 + + assert cu_seqlens is not None and chunk_indices is not None + # Ensure cu_seqlens is int32 + assert cu_seqlens.dtype == torch.int32, "cu_seqlens must be int32" + assert q.dim() == 4 and q.shape[0] == 1 + T_total = q.shape[1] + num_seqs = cu_seqlens.shape[0] - 1 + total_nt_val = chunk_indices.shape[0] + ps = (Int32(num_seqs), Int32(T_total), Int32(H), Int32(K), Int32(V)) + + # Allocate output tensors + dq = torch.empty_like(q, dtype=torch.float32) + dk = torch.empty_like(k, dtype=torch.float32) + dv2 = torch.empty_like(v) # bf16 + dg = torch.empty_like(g, dtype=torch.float32) + db = torch.empty(B, T, H, dtype=torch.float32, device=q.device) + dA = torch.empty(B, T, H, BT, dtype=torch.float32, device=q.device) + + compiled_fn = _get_compiled_bwd_wy( + H, + K, + V, + scale, + chunk_size, + beta_dtype, + ) + + # TVM-FFI call + compiled_fn( + # Inputs + q, + k, + v, + v_new, + g, + beta, + A, + h, + do, + dh, + dv, + # Outputs + dq, + dk, + dv2, + dg, + db, + dA, + # Metadata + cu_seqlens, + chunk_indices, + ps, + Int32(total_nt_val), + ) + + return dq, dk, dv2, db, dg, dA + + +# ===================================================================== +# Main (test entry point) +# ===================================================================== + + +def main(): + parser = argparse.ArgumentParser(description="Chunk KDA BWD WY DqKG Fused kernel test") + parser.add_argument("--B", type=int, default=1) + parser.add_argument("--T", type=int, default=64) + parser.add_argument("--H", type=int, default=1) + parser.add_argument("--K", type=int, default=128) + parser.add_argument("--V", type=int, default=128) + parser.add_argument("--scale", type=float, default=None) + parser.add_argument("--chunk_size", type=int, default=64) + args = parser.parse_args() + + if args.scale is None: + args.scale = args.K**-0.5 + B, T, H, K, V = args.B, args.T, args.H, args.K, args.V + BT = args.chunk_size + seq_lens = [63, 63, 63] + seq_lens = [64] + total_len = sum(seq_lens) + T = total_len + scale = args.scale + NT = (T + BT - 1) // BT + dtype, device = torch.bfloat16, "cuda" + cu_seqlens = torch.tensor(_exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + print(f"Config: B={B}, T={T}, H={H}, K={K}, V={V}, BT={BT}, scale={scale:.4f}") + print(f" Chunks per seq: {NT}, Total chunks: {B * NT}") + print(f" BK={64}, BV={64}, NK={K // 64}, NV={V // 64}") + + # Generate test data + torch.manual_seed(42) + q = torch.randn(B, T, H, K, dtype=dtype, device=device) + k = torch.randn(B, T, H, K, dtype=dtype, device=device) + v = torch.randn(B, T, H, V, dtype=dtype, device=device) + v_new = torch.randn(B, T, H, V, dtype=dtype, device=device) + g = torch.randn(B, T, H, K, dtype=torch.float32, device=device) * 0.1 + beta = torch.randn(B, T, H, dtype=torch.bfloat16, device=device) + A = torch.randn(B, T, H, BT, dtype=dtype, device=device) * 0.1 + h = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 + do_t = torch.randn(B, T, H, V, dtype=dtype, device=device) + dh = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 + dv = torch.randn(B, T, H, V, dtype=dtype, device=device) + + print("\n=== Compilation Test ===") + try: + dq, dk, dv2, db, dg, dA = chunk_kda_bwd_wy_dqkg_fused( + q=q, + k=k, + v=v, + v_new=v_new, + g=g, + beta=beta, + A=A, + h=h, + do=do_t, + dh=dh, + dv=dv, + cu_seqlens=cu_seqlens, + scale=scale, + chunk_size=BT, + ) + torch.cuda.synchronize() + # do_slice = do_t[0, :, 1, :].to(torch.float32) + # h_slice = h[0, 0, 1, :, :].to(torch.float32) + # dq_ref = do_slice @ h_slice.T + import pdb;pdb.set_trace() + # torch.testing.assert_close(dq_ref, dq[0,:,1,:], rtol=1e-2, atol=1e-2) + print(f" dq shape: {dq.shape}, dtype: {dq.dtype}") + print(f" dk shape: {dk.shape}, dtype: {dk.dtype}") + print(f" dv2 shape: {dv2.shape}, dtype: {dv2.dtype}") + print(f" dg shape: {dg.shape}, dtype: {dg.dtype}") + print(f" db shape: {db.shape}, dtype: {db.dtype}") + print(f" dA shape: {dA.shape}, dtype: {dA.dtype}") + except Exception as e: + import traceback + + print(f" ERROR: {e}") + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/cula/ops/intrinsics_sm100.py b/cula/ops/intrinsics_sm100.py new file mode 100644 index 0000000..e1352f2 --- /dev/null +++ b/cula/ops/intrinsics_sm100.py @@ -0,0 +1,359 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NVVM wrappers for SM100 (Blackwell) Tensor Memory intrinsics. + +Provides low-level, CuteDSL-compatible helpers that move data between +Tensor Memory (TMEM) and registers / shared memory via the native +``nvvm.tcgen05.*`` MLIR ops. + +**T2R / R2T** – ``tcgen05.ld`` / ``tcgen05.st`` with ``.32x32b`` shape. +**S2T** – ``tcgen05.cp`` with ``.128x256b`` shape (SMEM → TMEM) +PTX reference +------------- + tcgen05.ld.sync.aligned.32x32b.xN.b32 {r0, ..., rN-1}, [taddr]; + tcgen05.st.sync.aligned.32x32b.xN.b32 [taddr], {r0, ..., rN-1}; + +where ``N ∈ {2, 4, 8, 16, 32, 64, 128}`` and each ``r`` is a 32-bit +register. ``taddr`` encodes both the TMEM column index (bits [15:0]) +and the lane index (bits [31:16]). + +See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld + +Usage inside a ``@cute.kernel`` or ``@cute.jit`` function:: + + from cula.ops.intrinsics_sm100 import ( + tcgen05_ld_32x32b, tcgen05_st_32x32b, + reinterpret_cast, subvec, store_256b, + ) + from cutlass.cute.typing import Float32, Int32 + + # Load 32 × 32-bit values from TMEM → opaque vector<32 x i32> + vec_i32 = tcgen05_ld_32x32b(32, taddr) + + # Zero-cost reinterpret as f32 (single vector.bitcast, no instructions) + vec_f32 = reinterpret_cast(vec_i32, Int32, 32, Float32) + + # Store to global via store_256b (4 × 256-bit stores) + # store_256b takes vector<8 x i32>, so reinterpret back and slice + vec_i32_back = reinterpret_cast(vec_f32, Float32, 32, Int32) + for chunk in range(4): # 32 / 8 = 4 chunks + store_256b(gmem_addr + chunk * 32, subvec(vec_i32_back, chunk * 8, 8)) + + # Store back to TMEM + tcgen05_st_32x32b(32, taddr, vec_i32_back) +""" + +__all__ = [ + "tcgen05_ld_32x32b", + "tcgen05_st_32x32b", + "tcgen05_cp_128x256b", + "reinterpret_cast", + "subvec", + "store_256b", +] + +import cutlass.cute as cute +from cutlass._mlir import ir as _ir_mod +from cutlass._mlir.dialects import arith as _arith +from cutlass._mlir.dialects import llvm +from cutlass._mlir.dialects import nvvm as _nvvm +from cutlass._mlir.dialects import vector as _vector +from cutlass.cute.typing import Int32 +from cutlass.cutlass_dsl import dsl_user_op + +from cula.ops.ptx_umma_ext import Tcgen05SmemDescriptor + + +def _to_ir(val, loc=None, ip=None): + """Extract raw MLIR IR value from a CuteDSL wrapper.""" + return val.ir_value(loc=loc, ip=ip) if hasattr(val, "ir_value") else val + + +# --------------------------------------------------------------------------- +# tcgen05.ld.sync.aligned.32x32b.xN.b32 (via nvvm.tcgen05.ld) +# --------------------------------------------------------------------------- + + +@cute.jit +def tcgen05_ld_32x32b(num: int, taddr: int): + """Load *num* × 32-bit values from TMEM → an opaque ``vector``. + + ``num`` must be a **compile-time constant** in {2, 4, 8, 16, 32, 64, 128}. + Returns a single opaque MLIR vector value (``vector``). + + Use :func:`reinterpret_cast` to reinterpret the element type (zero-cost), + and :func:`subvec` to slice a contiguous sub-vector. + + Parameters + ---------- + num : int + Number of 32-bit registers to load. Must be a compile-time constant. + taddr : int + TMEM address (bits [31:16] = lane, bits [15:0] = column). + """ + + @dsl_user_op + def _do(addr_val, *, loc=None, ip=None): + i32_ty = _ir_mod.IntegerType.get_signless(32) + ptr6_ty = llvm.PointerType.get(address_space=6) + tmem_ptr = llvm.inttoptr(ptr6_ty, _to_ir(addr_val, loc, ip), loc=loc, ip=ip) + vec_i32_ty = _ir_mod.VectorType.get([num], i32_ty) + return _nvvm.tcgen05_ld( + res=vec_i32_ty, + shape=_nvvm.Tcgen05LdStShape.SHAPE_32X32B, + num=num, + tmem_addr=tmem_ptr, + loc=loc, + ip=ip, + ) + + return _do(Int32(taddr)) + + +# --------------------------------------------------------------------------- +# tcgen05.st.sync.aligned.32x32b.xN.b32 (via nvvm.tcgen05.st) +# --------------------------------------------------------------------------- + + +@cute.jit +def tcgen05_st_32x32b(num: int, taddr: int, vec): + """Store *num* × 32-bit values from an opaque vector → TMEM. + + ``num`` must be a **compile-time constant** in {2, 4, 8, 16, 32, 64, 128}. + + Parameters + ---------- + num : int + Number of 32-bit registers to store. Must be a compile-time constant. + taddr : int + TMEM address (bits [31:16] = lane, bits [15:0] = column). + vec : opaque vector + An opaque ``vector`` value (from :func:`tcgen05_ld_32x32b` + or :func:`reinterpret_cast`). + """ + + @dsl_user_op + def _do(addr_val, vec_val, *, loc=None, ip=None): + ptr6_ty = llvm.PointerType.get(address_space=6) + tmem_ptr = llvm.inttoptr(ptr6_ty, _to_ir(addr_val, loc, ip), loc=loc, ip=ip) + _nvvm.tcgen05_st( + shape=_nvvm.Tcgen05LdStShape.SHAPE_32X32B, + num=num, + tmem_addr=tmem_ptr, + r=_to_ir(vec_val, loc, ip), + loc=loc, + ip=ip, + ) + + _do(Int32(taddr), vec) + + +# --------------------------------------------------------------------------- +# reinterpret_cast (zero-cost vector.bitcast) +# --------------------------------------------------------------------------- + + +@cute.jit +def reinterpret_cast(vec, src_type, src_num, tgt_type): + """Zero-cost reinterpret of a vector's element type (single ``vector.bitcast``). + + Analogous to C++ ``reinterpret_cast``: no instructions emitted, just + re-labels the bits. The total bit-width is preserved: + ``src_num * src_type.width == tgt_num * tgt_type.width``. + + Parameters + ---------- + vec : opaque vector + Source vector (e.g. ``vector`` from :func:`tcgen05_ld_32x32b`). + src_type : CuTeDSL type + Element type of *vec* (e.g. ``Int32``). + src_num : int + Number of elements in *vec* (compile-time constant). + tgt_type : CuTeDSL type + Desired element type (e.g. ``Float32``, ``BFloat16``, ``Float16``). + + Returns + ------- + opaque vector + ``vector`` where ``M = src_num * src_type.width // tgt_type.width``. + + Examples + -------- + :: + + vec_i32 = tcgen05_ld_32x32b(8, taddr) # vector<8 x i32> + vec_f32 = reinterpret_cast(vec_i32, Int32, 8, Float32) # vector<8 x f32> + vec_bf16 = reinterpret_cast(vec_i32, Int32, 8, BFloat16) # vector<16 x bf16> + vec_back = reinterpret_cast(vec_bf16, BFloat16, 16, Int32) # vector<8 x i32> + """ + tgt_num = src_num * src_type.width // tgt_type.width + + @dsl_user_op + def _do(v, *, loc=None, ip=None): + tgt_vec_ty = _ir_mod.VectorType.get([tgt_num], tgt_type.mlir_type) + return _vector.bitcast(tgt_vec_ty, _to_ir(v, loc, ip), loc=loc, ip=ip) + + return _do(vec) + + +# --------------------------------------------------------------------------- +# subvec (extract a contiguous sub-vector) +# --------------------------------------------------------------------------- + + +@cute.jit +def subvec(vec, offset, size): + """Extract a contiguous sub-vector (``vector.extract_strided_slice``). + + Parameters + ---------- + vec : opaque vector + Source vector. + offset : int + Starting element index (compile-time constant). + size : int + Number of elements to extract (compile-time constant). + + Returns + ------- + opaque vector + ``vector``. + """ + + @dsl_user_op + def _do(v, *, loc=None, ip=None): + ir_v = _to_ir(v, loc, ip) + elem_ty = _ir_mod.VectorType(ir_v.type).element_type + res_ty = _ir_mod.VectorType.get([size], elem_ty) + return _vector.extract_strided_slice( + res_ty, + ir_v, + offsets=[offset], + sizes=[size], + strides=[1], + loc=loc, + ip=ip, + ) + + return _do(vec) + + +# --------------------------------------------------------------------------- +# st.global.L1::no_allocate.v8.f32 (256-bit direct R2G store) +# --------------------------------------------------------------------------- + +_STORE_256B_ASM = "st.global.L1::no_allocate.v8.f32 [$0], {$1, $2, $3, $4, $5, $6, $7, $8};" +_STORE_256B_CONSTRAINTS = "l,r,r,r,r,r,r,r,r" + + +@cute.jit +def store_256b(gmem_ptr, vec): + """Store 256 bits (8 × 32-bit) to global memory, bypassing L1 allocation. + + Issues ``st.global.L1::no_allocate.v8.f32`` with ``"r"`` (integer register) + constraints — type-agnostic, just like C++ ``reinterpret_cast``. + + Parameters + ---------- + gmem_ptr : pointer + Global-memory destination address (must be 32-byte aligned). + vec : opaque vector + A ``vector<8 x i32>`` (use :func:`subvec` to slice from a larger vector). + """ + + @dsl_user_op + def _do(addr, v, *, loc=None, ip=None): + i32_ty = _ir_mod.IntegerType.get_signless(32) + ir_v = _to_ir(v, loc, ip) + elems = [ + _vector.extractelement( + ir_v, + position=_arith.constant(i32_ty, i, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + for i in range(8) + ] + operands = [_to_ir(addr, loc, ip)] + elems + llvm.inline_asm( + _ir_mod.Type.parse("!llvm.void"), + operands, + _STORE_256B_ASM, + _STORE_256B_CONSTRAINTS, + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + _do(gmem_ptr, vec) + + +# --------------------------------------------------------------------------- +# tcgen05.cp.cta_group::1.128x256b (via nvvm.tcgen05.cp) +# --------------------------------------------------------------------------- + + +@cute.jit +def tcgen05_cp_128x256b(taddr: int, smem_desc: Tcgen05SmemDescriptor): + """Async copy SMEM → TMEM with shape ``128x256b`` (``cta_group::1``). + + Issues ``tcgen05.cp.cta_group::1.128x256b [taddr], s-desc;`` + via the native ``nvvm.tcgen05.cp`` MLIR op. + + The instruction copies a 128-row × 256-bit tile from shared memory + (described by *smem_desc*) into Tensor Memory at *taddr*. The copy + is **asynchronous** — use ``tcgen05.commit`` + ``mbarrier.wait`` to + synchronize. + + PTX reference + ------------- + tcgen05.cp.cta_group::1.128x256b [taddr], s-desc; + + Parameters + ---------- + taddr : int + TMEM destination address (uint32, passed as ``!llvm.ptr<6>``). + smem_desc : Tcgen05SmemDescriptor + 64-bit SMEM matrix descriptor (same format as ``tcgen05.mma`` + descriptors — see ``Tcgen05SmemDescriptor``). + """ + + @dsl_user_op + def _do(addr_val, desc_val, *, loc=None, ip=None): + ptr6_ty = llvm.PointerType.get(address_space=6) + tmem_ptr = llvm.inttoptr(ptr6_ty, _to_ir(addr_val, loc, ip), loc=loc, ip=ip) + _nvvm.tcgen05_cp( + shape=_nvvm.Tcgen05CpShape.SHAPE_128x256b, + taddr=tmem_ptr, + smem_desc=_to_ir(desc_val, loc, ip), + cta_group=_nvvm.Tcgen05GroupKind.CTA_1, + loc=loc, + ip=ip, + ) + + _do(Int32(taddr), smem_desc.desc_i64[0]) + +@cute.jit +def tcgen05_fence_before(): + """tcgen05.fence::before_thread_sync — non-blocking ordering fence.""" + _nvvm.tcgen05_fence(kind=_nvvm.Tcgen05FenceKind.BEFORE_THREAD_SYNC) + + +@cute.jit +def tcgen05_fence_after(): + """tcgen05.fence::after_thread_sync — non-blocking ordering fence.""" + _nvvm.tcgen05_fence(kind=_nvvm.Tcgen05FenceKind.AFTER_THREAD_SYNC) \ No newline at end of file diff --git a/cula/ops/ptx_umma_ext.py b/cula/ops/ptx_umma_ext.py new file mode 100644 index 0000000..4426361 --- /dev/null +++ b/cula/ops/ptx_umma_ext.py @@ -0,0 +1,958 @@ +"""CuteDSL UMMA extension wrappers for SM100 (Blackwell) ``tcgen05.mma``. + +CuteDSL's high-level ``cute.gemm()`` / ``make_tiled_mma()`` API does not +expose all ``tcgen05.mma`` instruction variants. This module provides +low-level wrappers for the two categories currently needed: + +1. **Masked MMA** – SS and TS forms with the 128-bit ``disable-output-lane`` + mask operand (``{m0, m1, m2, m3}``). Implemented via the native + ``nvvm.tcgen05_mma`` MLIR op with its ``write_disable_mask`` parameter + (``vector<4xi32>``). + +2. **Weight-stationary (WS) MMA** – ``tcgen05.mma.ws`` SS / TS forms for + both ``kind::tf32`` and ``kind::f16``. Implemented via + ``llvm.inline_asm``. + +---------------------------------------------------------------------- +PTX instruction forms +---------------------------------------------------------------------- +SS (SMEM A, SMEM B): + tcgen05.mma.cta_group::1.kind::tf32 [tmem_c], desc_a, desc_b, + desc_val, {m0,m1,m2,m3}, p; + +TS (TMEM A, SMEM B): + tcgen05.mma.cta_group::1.kind::tf32 [tmem_c], [tmem_a], desc_b, + desc_val, {m0,m1,m2,m3}, p; + +WS_SS (weight-stationary, SMEM A, SMEM B): + tcgen05.mma.ws.cta_group::1.kind::tf32 [tmem_c], desc_a, desc_b, + desc_val, p; + tcgen05.mma.ws.cta_group::1.kind::f16 [tmem_c], desc_a, desc_b, + desc_val, p; + +WS_TS (weight-stationary, TMEM A, SMEM B): + tcgen05.mma.ws.cta_group::1.kind::tf32 [tmem_c], [tmem_a], desc_b, + desc_val, p; + tcgen05.mma.ws.cta_group::1.kind::f16 [tmem_c], [tmem_a], desc_b, + desc_val, p; + +---------------------------------------------------------------------- +Disable-output-lane mask layout (4 × uint32 = 128 bits) +---------------------------------------------------------------------- +Each uint32 covers 32 M-dimension rows (8 rows × 4 elements per group). + 0x00000000 → group is ACTIVE (output written) + 0xFFFFFFFF → group is DISABLED (output suppressed) + +Predefined SS mask constants (SMEM A variants): + SS_NO_MASK = (0, 0, 0, 0) all rows active + SS_MASK0 = (0, 0xFF…, 0, 0xFF…) odd groups disabled + SS_MASK1 = (0xFF…, 0, 0xFF…, 0) even groups disabled + SS_MASK2 = (0xFF…, 0xFF…, 0, 0xFF…) group 2 only active + SS_MASK3 = (0xFF…, 0xFF…, 0xFF…, 0) group 3 only active + +Predefined TS mask constants (TMEM A variants): + TS_NO_MASK = (0, 0, 0, 0) all rows active + TS_MASK0 = (0, 0xFF…, 0xFF…, 0xFF…) group 0 only active + TS_MASK1 = (0xFF…, 0, 0xFF…, 0xFF…) group 1 only active + TS_MASK2 = (0xFF…, 0xFF…, 0, 0xFF…) group 2 only active + TS_MASK3 = (0xFF…, 0xFF…, 0xFF…, 0) group 3 only active + TS_MASK02 = (0, 0xFF…, 0, 0xFF…) groups 0,2 only active + TS_MASK13 = (0xFF…, 0, 0xFF…, 0) groups 1,3 only active + +Public API (all decorated with @cute.jit) +---------------------------------------------------------------------- +Descriptor helpers (call inside @cute.jit): + Tcgen05SmemDescriptor — 64-bit SMEM descriptor object + initialize_tcgen05_descriptor — fill descriptor bitfields + +Low-level primitives (pass mask words explicitly): + tcgen05mma_ss(desc_a, desc_b, tmem_c, desc_val, scale_out, + mask0, mask1, mask2, mask3) + tcgen05mma_ts(tmem_a, desc_b, tmem_c, desc_val, scale_out, + mask0, mask1, mask2, mask3) + tcgen05mma_ws_ss_tf32(desc_a, desc_b, tmem_c, desc_val, scale_out) + tcgen05mma_ws_ts_tf32(tmem_a, desc_b, tmem_c, desc_val, scale_out) + tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, desc_val, scale_out) + tcgen05mma_ws_ts_f16(tmem_a, desc_b, tmem_c, desc_val, scale_out) + +Named convenience wrappers (pre-set masks, pass only MMA operands): + tcgen05mma_ss_no_mask / tcgen05mma_ss_mask0 / …mask1 / …mask2 / …mask3 + tcgen05mma_ts_no_mask / tcgen05mma_ts_mask0 / …mask1 / …mask2 / …mask3 + tcgen05mma_ts_mask02 / tcgen05mma_ts_mask13 +""" + +__all__ = [ + # descriptor helpers + "Tcgen05SmemDescriptor", + "initialize_tcgen05_descriptor", + # low-level primitives + "tcgen05mma_ss", + "tcgen05mma_ts", + "tcgen05mma_ws_ss_tf32", + "tcgen05mma_ws_ts_tf32", + "tcgen05mma_ws_ss_f16", + "tcgen05mma_ws_ts_f16", + # SS named wrappers + "tcgen05mma_ss_no_mask", + "tcgen05mma_ss_mask0", + "tcgen05mma_ss_mask1", + "tcgen05mma_ss_mask2", + "tcgen05mma_ss_mask3", + # TS named wrappers + "tcgen05mma_ts_no_mask", + "tcgen05mma_ts_mask0", + "tcgen05mma_ts_mask1", + "tcgen05mma_ts_mask2", + "tcgen05mma_ts_mask3", + "tcgen05mma_ts_mask02", + "tcgen05mma_ts_mask13", + # collector enums (re-exported for convenience) + "CollectorBBuffer", + "CollectorOp", +] + +import cutlass +import cutlass.cute as cute +from cutlass._mlir import ir +from cutlass._mlir.dialects import arith as _arith +from cutlass._mlir.dialects import llvm +from cutlass._mlir.dialects import nvvm as _nvvm +from cutlass.cutlass_dsl import dsl_user_op + +# Re-export collector enums for caller convenience. +CollectorBBuffer = _nvvm.Tcgen05MMACollectorBBuffer +CollectorOp = _nvvm.Tcgen05MMACollectorOp + +# --------------------------------------------------------------------------- +# Mask constants (4 × uint32). 0 = ACTIVE, 0xFFFFFFFF = DISABLED. +# --------------------------------------------------------------------------- +_ALL_ACTIVE = 0x00000000 +_ALL_OFF = 0xFFFFFFFF + +# SS masks (SMEM A, SMEM B) +SS_NO_MASK = (_ALL_ACTIVE, _ALL_ACTIVE, _ALL_ACTIVE, _ALL_ACTIVE) +SS_MASK0 = (_ALL_ACTIVE, _ALL_OFF, _ALL_ACTIVE, _ALL_OFF) # {0,F,0,F} +SS_MASK1 = (_ALL_OFF, _ALL_ACTIVE, _ALL_OFF, _ALL_ACTIVE) # {F,0,F,0} +SS_MASK2 = (_ALL_OFF, _ALL_OFF, _ALL_ACTIVE, _ALL_OFF) # {F,F,0,F} +SS_MASK3 = (_ALL_OFF, _ALL_OFF, _ALL_OFF, _ALL_ACTIVE) # {F,F,F,0} + +# TS masks (TMEM A, SMEM B) +TS_NO_MASK = (_ALL_ACTIVE, _ALL_ACTIVE, _ALL_ACTIVE, _ALL_ACTIVE) +TS_MASK0 = (_ALL_ACTIVE, _ALL_OFF, _ALL_OFF, _ALL_OFF) # {0,F,F,F} +TS_MASK1 = (_ALL_OFF, _ALL_ACTIVE, _ALL_OFF, _ALL_OFF) # {F,0,F,F} +TS_MASK2 = (_ALL_OFF, _ALL_OFF, _ALL_ACTIVE, _ALL_OFF) # {F,F,0,F} +TS_MASK3 = (_ALL_OFF, _ALL_OFF, _ALL_OFF, _ALL_ACTIVE) # {F,F,F,0} +TS_MASK02 = (_ALL_ACTIVE, _ALL_OFF, _ALL_ACTIVE, _ALL_OFF) # {0,F,0,F} +TS_MASK13 = (_ALL_OFF, _ALL_ACTIVE, _ALL_OFF, _ALL_ACTIVE) # {F,0,F,0} + + +# --------------------------------------------------------------------------- +# Tcgen05SmemDescriptor — 64-bit SMEM descriptor stored as 2×Int32 +# --------------------------------------------------------------------------- + + +class Tcgen05SmemDescriptor: + """64-bit shared-memory descriptor for tcgen05 MMA (Blackwell / SM100). + + The descriptor encodes SMEM base address, leading/stride byte offsets, + swizzle mode, and other fields required by the ``tcgen05.mma`` PTX + instruction to locate a matrix tile in shared memory. + + 64-bit layout (PTX ISA Table 40):: + + Bit 63 Bit 0 + ┌──────────┬────────┬─────┬──────────┬────┬──────────┬──────┬──────────────┐ + │ 63 61 │ 60 53 │ 52 │ 51 49 │ 48 │ 45 32 │31 30 │ 29 16│15 14│ 13 0│ + │layout_typ│ reservd│l_abs│base_offst│ 46 │ SBO │ rsvd │ LBO │rsvd │start_adr│ + │ (3 bit) │ (8 bit)│(1b) │ (3 bit) │=0b001│(14 bit)│(2 b) │(14 bit)│(2b) │(14 bit) │ + └──────────┴────────┴─────┴──────────┴────┴──────────┴──────┴────────┴─────┴─────────┘ + + Field descriptions: + + - **start_address** [bits 0-13]: SMEM base pointer, encoded as + ``smem_ptr >> 4`` (16-byte aligned). The hardware reconstructs the + full address as ``encoded_value << 4``. + + - **LBO** (Leading Byte Offset) [bits 16-29]: distance in bytes between + consecutive elements along the leading dimension, encoded as + ``lbo_bytes >> 4``. When ``lbo_mode=1`` this is an absolute byte + address rather than a relative offset. + + - **SBO** (Stride Byte Offset) [bits 32-45]: distance in bytes between + consecutive elements along the stride dimension, encoded as + ``sbo_bytes >> 4``. + + - **version** [bits 46-48]: fixed constant ``0b001`` (= 1). + + - **base_offset** [bits 49-51]: 3-bit alignment correction when the + SMEM tile does not start at a natural swizzle-pattern boundary + (1024B for 128B swizzle, 512B for 64B, 256B for 32B). + Computed as ``(start_addr >> 7) & 0x7``. Usually 0. + + - **lbo_mode** (leading_abs) [bit 52]: 0 → LBO is a relative byte + offset; 1 → LBO is an absolute byte address. + + - **layout_type** (swizzle_mode) [bits 61-63]: + - 0 = SWIZZLE_NONE + - 1 = SWIZZLE_128B_BASE32B (128-byte pattern, 32-byte atom) + - 2 = SWIZZLE_128B (128-byte pattern) + - 4 = SWIZZLE_64B (64-byte pattern) + - 6 = SWIZZLE_32B (32-byte pattern) + + Storage: two Int32 registers (desc[0] = low 32 bits, desc[1] = high 32 + bits), recast to a single Int64 for the PTX ``l``-constraint operand. + + Usage inside a @cute.jit kernel:: + + desc = Tcgen05SmemDescriptor() + initialize_tcgen05_descriptor(desc, smem_ptr, lbo, sbo, 0, True, swizzle) + """ + + def __init__(self, desc_64: cute.Int64 = None): + # desc[0]: low 32 bits → start_address[0:14] | LBO[16:30] + # desc[1]: high 32 bits → SBO[0:14] | version[14:16] | base_offset[17:20] + # | lbo_mode[20] | layout_type[29:32] + self.desc = cute.make_rmem_tensor((2,), dtype=cutlass.Int32) + # Alias the 2×i32 as 1×i64 for PTX "l" constraint (64-bit operand) + self.desc_i64 = cute.make_tensor(cute.recast_ptr(self.desc.iterator, dtype=cute.Int64), (1,)) + if desc_64 is not None: + self.desc_i64[0] = desc_64 + + def __add__(self, byte_offset): + """Return a new descriptor offset by ``byte_offset`` bytes. + + Only the start_address field (bits 0-13 of desc[0]) is modified. + Since it is stored in 16-byte units, we add ``byte_offset >> 4``. + All other fields (LBO, SBO, swizzle, etc.) are copied unchanged. + """ + res = cute.make_rmem_tensor((2,), dtype=cutlass.Int32) + res_i64 = cute.make_tensor(cute.recast_ptr(res.iterator, dtype=cute.Int64), (1,)) + res[0] = self.desc[0] + (byte_offset >> 4) # adjust start_address + res[1] = self.desc[1] # high word unchanged + return Tcgen05SmemDescriptor(res_i64[0]) + + +# --------------------------------------------------------------------------- +# initialize_tcgen05_descriptor +# --------------------------------------------------------------------------- + + +def initialize_tcgen05_descriptor( + desc, + start_address, + leading_byte_offset, + stride_byte_offset, + base_offset, + leading_abs, + swizzle_mode, +): + """Pack SMEM descriptor bitfields into *desc* (a Tcgen05SmemDescriptor). + + Constructs the 64-bit descriptor in two 32-bit halves (desc[0] and desc[1]). + All address/offset fields must be pre-divided by 16 (``>> 4``) before + passing, because the hardware stores them in 16-byte granularity. + + Low 32 bits — desc[0]:: + + ┌────────────────┬──────┬──────────────────┐ + │ bits 29…16 │15…14 │ bits 13…0 │ + │ LBO (14 bits) │ rsvd │ start_addr >> 4 │ + └────────────────┴──────┴──────────────────┘ + + - [0:14) start_address >> 4 — SMEM tile base pointer in 16B units. + - [14:16) reserved (0). + - [16:30) leading_byte_offset — LBO in 16B units (caller passes >> 4). + + High 32 bits — desc[1]:: + + ┌────────┬────────┬─────┬──────────┬────────┬──────────────────┐ + │ 31…29 │ 28…21 │ 20 │ 19…17 │ 16…14 │ bits 13…0 │ + │ layout │ rsvd │l_abs│base_off │version │ SBO (14 bits) │ + │ (3 bit)│ (8 bit)│(1b) │ (3 bit) │=0b001 │ │ + └────────┴────────┴─────┴──────────┴────────┴──────────────────┘ + + - [0:14) stride_byte_offset — SBO in 16B units (caller passes >> 4). + - [14:16) version = 1 (fixed constant 0b001, only bit 14 set). + - [17:20) base_offset & 0x7 — swizzle alignment correction. + Typically 0. Non-zero when the tile doesn't start at + the natural swizzle boundary (1024B/512B/256B). + - [20:21) lbo_mode — 0 = LBO is relative offset, 1 = absolute address. + - [29:32) layout_type (swizzle_mode & 0x7): + 0 = SWIZZLE_NONE + 1 = SWIZZLE_128B_BASE32B (Swizzle<2,5,2>) + 2 = SWIZZLE_128B (Swizzle<3,4,3>) + 4 = SWIZZLE_64B (Swizzle<2,4,3>) + 6 = SWIZZLE_32B (Swizzle<1,4,3>) + + Args: + desc: Tcgen05SmemDescriptor to fill. + start_address: CuTeDSL Pointer to the SMEM tile start. + leading_byte_offset: Leading-dimension byte offset, already >> 4. + stride_byte_offset: Stride byte offset, already >> 4. + base_offset: Swizzle alignment correction (raw int, bits 17-19). + leading_abs: Bool — True → LBO is absolute address. + swizzle_mode: Swizzle layout_type integer (bits 29-31). + """ + # Encode start_address: take SMEM pointer, shift right by 4 to get 16B units + ptr_val = start_address.toint() >> 4 + + # --- Low 32 bits (desc[0]) --- + # bits [0:14) = start_address >> 4 + # bits [16:30) = leading_byte_offset (already in 16B units) + desc.desc[0] = cutlass.Int32(ptr_val) | cutlass.Int32(cutlass.Int32(leading_byte_offset) << 16) + + # --- High 32 bits (desc[1]) --- + # bits [0:14) = stride_byte_offset (already in 16B units) + # bit [14] = version = 1 (fixed) + # bits [17:20) = base_offset & 0x7 (swizzle alignment correction) + # bit [20] = lbo_mode (0=relative, 1=absolute) + # bits [29:32) = layout_type (swizzle mode) + desc.desc[1] = ( + cutlass.Int32(stride_byte_offset) + | cutlass.Int32(1 << 14) # version = 1 + | cutlass.Int32(cutlass.Int32(base_offset & 0x7) << 17) + | cutlass.Int32(cutlass.Int32(int(leading_abs)) << 20) + | cutlass.Int32(cutlass.Int32(swizzle_mode & 0x7) << 29) + ) + + +# --------------------------------------------------------------------------- +# Internal helper +# --------------------------------------------------------------------------- + + +def _ir(val, loc=None, ip=None): + """Extract raw MLIR IR value from a CuTeDSL wrapper.""" + return val.ir_value(loc=loc, ip=ip) if hasattr(val, "ir_value") else val + + +# =========================================================================== +# Low-level primitives +# =========================================================================== + +# --------------------------------------------------------------------------- +# tcgen05mma_ss — SMEM A, SMEM B (non-warp-specialised) +# --------------------------------------------------------------------------- + + +@cute.jit +def tcgen05mma_ss( + desc_a: Tcgen05SmemDescriptor, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, + mask0: int, + mask1: int, + mask2: int, + mask3: int, +): + """Issue ``tcgen05.mma.cta_group::1.kind::tf32`` with SMEM operands. + + ``mask{0-3}`` are the four uint32 words of the 128-bit + ``disable-output-lane`` mask (0=active, 0xFFFFFFFF=disabled). + + Caller must ensure single-thread execution (e.g. via ``elect_one``); + no internal ``elect.sync`` is performed. + + Args: + desc_a: 64-bit SMEM descriptor for matrix A. + desc_b: 64-bit SMEM descriptor for matrix B. + tmem_c: TMEM base address (uint32) for accumulators C/D. + desc_val: High 32 bits of the UMMA instruction descriptor (idescE>>32). + scale_out: 1 → accumulate into C, 0 → overwrite C (clear accumulators). + mask0-3: Four uint32 words of the disable-output-lane mask. + """ + + @dsl_user_op + def _do(c_val, da_val, db_val, dv_val, sc_val, m0_val, m1_val, m2_val, m3_val, *, loc=None, ip=None): + ptr6_ty = llvm.PointerType.get(address_space=6) + i32_ty = ir.IntegerType.get_signless(32) + i1_ty = ir.IntegerType.get_signless(1) + vec4i32_ty = ir.VectorType.get([4], i32_ty) + + c_ir = _ir(c_val, loc, ip) + d_ptr = llvm.inttoptr(ptr6_ty, c_ir, loc=loc, ip=ip) + da_ir = _ir(da_val, loc, ip) # i64 SMEM descriptor + db_ir = _ir(db_val, loc, ip) # i64 SMEM descriptor + dv_ir = _ir(dv_val, loc, ip) + sc_ir = _ir(sc_val, loc, ip) + enable_d = _arith.trunci(i1_ty, sc_ir, loc=loc, ip=ip) + + m0_ir = _ir(m0_val, loc, ip) + m1_ir = _ir(m1_val, loc, ip) + m2_ir = _ir(m2_val, loc, ip) + m3_ir = _ir(m3_val, loc, ip) + + undef = llvm.mlir_undef(vec4i32_ty, loc=loc, ip=ip) + idx0 = _arith.constant(i32_ty, 0, loc=loc, ip=ip) + idx1 = _arith.constant(i32_ty, 1, loc=loc, ip=ip) + idx2 = _arith.constant(i32_ty, 2, loc=loc, ip=ip) + idx3 = _arith.constant(i32_ty, 3, loc=loc, ip=ip) + v = llvm.InsertElementOp(undef, m0_ir, idx0, loc=loc, ip=ip) + v = llvm.InsertElementOp(v, m1_ir, idx1, loc=loc, ip=ip) + v = llvm.InsertElementOp(v, m2_ir, idx2, loc=loc, ip=ip) + mask = llvm.InsertElementOp(v, m3_ir, idx3, loc=loc, ip=ip) + + _nvvm.tcgen05_mma( + mma_kind=_nvvm.Tcgen05MMAKind.TF32, + cta_group=_nvvm.Tcgen05GroupKind.CTA_1, + d=d_ptr, + a=da_ir, + b=db_ir, + idesc=dv_ir, + enable_input_d=enable_d, + write_disable_mask=mask, + loc=loc, + ip=ip, + ) + + _do( + cutlass.Int32(tmem_c), + desc_a.desc_i64[0], + desc_b.desc_i64[0], + cutlass.Int32(desc_val), + cutlass.Int32(scale_out), + cutlass.Int32(mask0), + cutlass.Int32(mask1), + cutlass.Int32(mask2), + cutlass.Int32(mask3), + ) + + +# --------------------------------------------------------------------------- +# tcgen05mma_ts — TMEM A, SMEM B (non-warp-specialised) +# --------------------------------------------------------------------------- + + +@cute.jit +def tcgen05mma_ts( + tmem_a: int, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, + mask0: int, + mask1: int, + mask2: int, + mask3: int, +): + """Issue ``tcgen05.mma.cta_group::1.kind::tf32`` with TMEM A operand. + + Matrix A is read from TMEM via indirect addressing ``[tmem_a]``. + Matrix B is read from SMEM via descriptor. + Caller must ensure single-thread execution (e.g. via ``elect_one``). + + Args: + tmem_a: TMEM base address (uint32) for matrix A. + desc_b: 64-bit SMEM descriptor for matrix B. + tmem_c: TMEM base address (uint32) for accumulators C/D. + desc_val: High 32 bits of the UMMA instruction descriptor (idescE>>32). + scale_out: 1 → accumulate into C, 0 → overwrite C. + mask0-3: Four uint32 words of the disable-output-lane mask. + """ + + @dsl_user_op + def _do(c_val, a_val, db_val, dv_val, sc_val, m0_val, m1_val, m2_val, m3_val, *, loc=None, ip=None): + ptr6_ty = llvm.PointerType.get(address_space=6) + i32_ty = ir.IntegerType.get_signless(32) + i1_ty = ir.IntegerType.get_signless(1) + vec4i32_ty = ir.VectorType.get([4], i32_ty) + + c_ir = _ir(c_val, loc, ip) + a_ir = _ir(a_val, loc, ip) + d_ptr = llvm.inttoptr(ptr6_ty, c_ir, loc=loc, ip=ip) + a_ptr = llvm.inttoptr(ptr6_ty, a_ir, loc=loc, ip=ip) + b_ir = _ir(db_val, loc, ip) + dv_ir = _ir(dv_val, loc, ip) + sc_ir = _ir(sc_val, loc, ip) + enable_d = _arith.trunci(i1_ty, sc_ir, loc=loc, ip=ip) + + m0_ir = _ir(m0_val, loc, ip) + m1_ir = _ir(m1_val, loc, ip) + m2_ir = _ir(m2_val, loc, ip) + m3_ir = _ir(m3_val, loc, ip) + + undef = llvm.mlir_undef(vec4i32_ty, loc=loc, ip=ip) + idx0 = _arith.constant(i32_ty, 0, loc=loc, ip=ip) + idx1 = _arith.constant(i32_ty, 1, loc=loc, ip=ip) + idx2 = _arith.constant(i32_ty, 2, loc=loc, ip=ip) + idx3 = _arith.constant(i32_ty, 3, loc=loc, ip=ip) + v = llvm.InsertElementOp(undef, m0_ir, idx0, loc=loc, ip=ip) + v = llvm.InsertElementOp(v, m1_ir, idx1, loc=loc, ip=ip) + v = llvm.InsertElementOp(v, m2_ir, idx2, loc=loc, ip=ip) + mask = llvm.InsertElementOp(v, m3_ir, idx3, loc=loc, ip=ip) + + _nvvm.tcgen05_mma( + mma_kind=_nvvm.Tcgen05MMAKind.TF32, + cta_group=_nvvm.Tcgen05GroupKind.CTA_1, + d=d_ptr, + a=a_ptr, + b=b_ir, + idesc=dv_ir, + enable_input_d=enable_d, + write_disable_mask=mask, + loc=loc, + ip=ip, + ) + + _do( + cutlass.Int32(tmem_c), + cutlass.Int32(tmem_a), + desc_b.desc_i64[0], + cutlass.Int32(desc_val), + cutlass.Int32(scale_out), + cutlass.Int32(mask0), + cutlass.Int32(mask1), + cutlass.Int32(mask2), + cutlass.Int32(mask3), + ) + + +# --------------------------------------------------------------------------- +# tcgen05mma_ws_ss_tf32 — weight-stationary, SMEM A, SMEM B, kind::tf32 +# --------------------------------------------------------------------------- + + +@cute.jit +def tcgen05mma_ws_ss_tf32( + desc_a: Tcgen05SmemDescriptor, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, + collector_b_buffer=None, + collector_op=None, +): + """Issue ``tcgen05.mma.ws.cta_group::1.kind::tf32`` (weight-stationary form). + + This variant does NOT take a ``disable-output-lane`` mask; the + optional ``zero-column-mask-desc`` operand is omitted. + + Args: + desc_a: 64-bit SMEM descriptor for matrix A. + desc_b: 64-bit SMEM descriptor for matrix B. + tmem_c: TMEM base address (uint32) for accumulators C/D. + desc_val: High 32 bits of the UMMA instruction descriptor (idescE>>32). + scale_out: 1 → accumulate, 0 → overwrite. + collector_b_buffer: Optional ``CollectorBBuffer`` enum (B0–B3). + Defaults to None (hardware default: ``b0::discard``). + collector_op: Optional ``CollectorOp`` enum (FILL/USE/LASTUSE/DISCARD). + Defaults to None (hardware default: discard). + """ + + @dsl_user_op + def _do(c_val, da_val, db_val, dv_val, sc_val, *, loc=None, ip=None): + ptr6_ty = llvm.PointerType.get(address_space=6) + i1_ty = ir.IntegerType.get_signless(1) + + c_ir = _ir(c_val, loc, ip) + d_ptr = llvm.inttoptr(ptr6_ty, c_ir, loc=loc, ip=ip) + da_ir = _ir(da_val, loc, ip) + db_ir = _ir(db_val, loc, ip) + dv_ir = _ir(dv_val, loc, ip) + sc_ir = _ir(sc_val, loc, ip) + enable_d = _arith.trunci(i1_ty, sc_ir, loc=loc, ip=ip) + + _nvvm.tcgen05_mma_ws( + mma_kind=_nvvm.Tcgen05MMAKind.TF32, + d=d_ptr, + a=da_ir, + b=db_ir, + idesc=dv_ir, + enable_input_d=enable_d, + collector_b_buffer=collector_b_buffer, + collector_op=collector_op, + loc=loc, + ip=ip, + ) + + _do( + cutlass.Int32(tmem_c), + desc_a.desc_i64[0], + desc_b.desc_i64[0], + cutlass.Int32(desc_val), + cutlass.Int32(scale_out), + ) + + +# --------------------------------------------------------------------------- +# tcgen05mma_ws_ss_f16 — weight-stationary, SMEM A, SMEM B, kind::f16 +# --------------------------------------------------------------------------- + + +@cute.jit +def tcgen05mma_ws_ss_f16( + desc_a: Tcgen05SmemDescriptor, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, + collector_b_buffer=None, + collector_op=None, +): + """Issue ``tcgen05.mma.ws.cta_group::1.kind::f16`` (weight-stationary form). + + Same as the tf32 variant but uses ``.kind::f16`` for half-precision + input types (f16 / bf16). K dimension is 16 instead of 8. + + This variant does NOT take a ``disable-output-lane`` mask; the + optional ``zero-column-mask-desc`` operand is omitted. + + Args: + desc_a: 64-bit SMEM descriptor for matrix A. + desc_b: 64-bit SMEM descriptor for matrix B. + tmem_c: TMEM base address (uint32) for accumulators C/D. + desc_val: High 32 bits of the UMMA instruction descriptor (idescE>>32). + scale_out: 1 → accumulate, 0 → overwrite. + collector_b_buffer: Optional ``CollectorBBuffer`` enum (B0–B3). + Defaults to None (hardware default: ``b0::discard``). + collector_op: Optional ``CollectorOp`` enum (FILL/USE/LASTUSE/DISCARD). + Defaults to None (hardware default: discard). + """ + + @dsl_user_op + def _do(c_val, da_val, db_val, dv_val, sc_val, *, loc=None, ip=None): + ptr6_ty = llvm.PointerType.get(address_space=6) + i1_ty = ir.IntegerType.get_signless(1) + + c_ir = _ir(c_val, loc, ip) + d_ptr = llvm.inttoptr(ptr6_ty, c_ir, loc=loc, ip=ip) + da_ir = _ir(da_val, loc, ip) + db_ir = _ir(db_val, loc, ip) + dv_ir = _ir(dv_val, loc, ip) + sc_ir = _ir(sc_val, loc, ip) + enable_d = _arith.trunci(i1_ty, sc_ir, loc=loc, ip=ip) + + _nvvm.tcgen05_mma_ws( + mma_kind=_nvvm.Tcgen05MMAKind.F16, + d=d_ptr, + a=da_ir, + b=db_ir, + idesc=dv_ir, + enable_input_d=enable_d, + collector_b_buffer=collector_b_buffer, + collector_op=collector_op, + loc=loc, + ip=ip, + ) + + _do( + cutlass.Int32(tmem_c), + desc_a.desc_i64[0], + desc_b.desc_i64[0], + cutlass.Int32(desc_val), + cutlass.Int32(scale_out), + ) + + +# --------------------------------------------------------------------------- +# tcgen05mma_ws_ts_tf32 — weight-stationary, TMEM A, SMEM B, kind::tf32 +# --------------------------------------------------------------------------- + + +@cute.jit +def tcgen05mma_ws_ts_tf32( + tmem_a: int, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, + collector_b_buffer=None, + collector_op=None, +): + """Issue ``tcgen05.mma.ws.cta_group::1.kind::tf32`` with TMEM A (weight-stationary). + + Matrix A is read from TMEM via indirect addressing ``[tmem_a]``. + Matrix B is read from SMEM via descriptor. + This variant does NOT take a ``disable-output-lane`` mask; the + optional ``zero-column-mask-desc`` operand is omitted. + + Args: + tmem_a: TMEM base address (uint32) for matrix A. + desc_b: 64-bit SMEM descriptor for matrix B. + tmem_c: TMEM base address (uint32) for accumulators C/D. + desc_val: High 32 bits of the UMMA instruction descriptor (idescE>>32). + scale_out: 1 → accumulate, 0 → overwrite. + collector_b_buffer: Optional ``CollectorBBuffer`` enum (B0–B3). + Defaults to None (hardware default: ``b0::discard``). + collector_op: Optional ``CollectorOp`` enum (FILL/USE/LASTUSE/DISCARD). + Defaults to None (hardware default: discard). + """ + + @dsl_user_op + def _do(c_val, a_val, db_val, dv_val, sc_val, *, loc=None, ip=None): + ptr6_ty = llvm.PointerType.get(address_space=6) + i1_ty = ir.IntegerType.get_signless(1) + + c_ir = _ir(c_val, loc, ip) + d_ptr = llvm.inttoptr(ptr6_ty, c_ir, loc=loc, ip=ip) + a_ir = _ir(a_val, loc, ip) + a_ptr = llvm.inttoptr(ptr6_ty, a_ir, loc=loc, ip=ip) + db_ir = _ir(db_val, loc, ip) + dv_ir = _ir(dv_val, loc, ip) + sc_ir = _ir(sc_val, loc, ip) + enable_d = _arith.trunci(i1_ty, sc_ir, loc=loc, ip=ip) + + _nvvm.tcgen05_mma_ws( + mma_kind=_nvvm.Tcgen05MMAKind.TF32, + d=d_ptr, + a=a_ptr, + b=db_ir, + idesc=dv_ir, + enable_input_d=enable_d, + collector_b_buffer=collector_b_buffer, + collector_op=collector_op, + loc=loc, + ip=ip, + ) + + _do( + cutlass.Int32(tmem_c), + cutlass.Int32(tmem_a), + desc_b.desc_i64[0], + cutlass.Int32(desc_val), + cutlass.Int32(scale_out), + ) + + +# --------------------------------------------------------------------------- +# tcgen05mma_ws_ts_f16 — weight-stationary, TMEM A, SMEM B, kind::f16 +# --------------------------------------------------------------------------- + + +@cute.jit +def tcgen05mma_ws_ts_f16( + tmem_a: int, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, + collector_b_buffer=None, + collector_op=None, +): + """Issue ``tcgen05.mma.ws.cta_group::1.kind::f16`` with TMEM A (weight-stationary). + + Same as the tf32 variant but uses ``.kind::f16`` for half-precision + input types (f16 / bf16). K dimension is 16 instead of 8. + + Matrix A is read from TMEM via indirect addressing ``[tmem_a]``. + Matrix B is read from SMEM via descriptor. + This variant does NOT take a ``disable-output-lane`` mask; the + optional ``zero-column-mask-desc`` operand is omitted. + + Args: + tmem_a: TMEM base address (uint32) for matrix A. + desc_b: 64-bit SMEM descriptor for matrix B. + tmem_c: TMEM base address (uint32) for accumulators C/D. + desc_val: High 32 bits of the UMMA instruction descriptor (idescE>>32). + scale_out: 1 → accumulate, 0 → overwrite. + collector_b_buffer: Optional ``CollectorBBuffer`` enum (B0–B3). + Defaults to None (hardware default: ``b0::discard``). + collector_op: Optional ``CollectorOp`` enum (FILL/USE/LASTUSE/DISCARD). + Defaults to None (hardware default: discard). + """ + + @dsl_user_op + def _do(c_val, a_val, db_val, dv_val, sc_val, *, loc=None, ip=None): + ptr6_ty = llvm.PointerType.get(address_space=6) + i1_ty = ir.IntegerType.get_signless(1) + + c_ir = _ir(c_val, loc, ip) + d_ptr = llvm.inttoptr(ptr6_ty, c_ir, loc=loc, ip=ip) + a_ir = _ir(a_val, loc, ip) + a_ptr = llvm.inttoptr(ptr6_ty, a_ir, loc=loc, ip=ip) + db_ir = _ir(db_val, loc, ip) + dv_ir = _ir(dv_val, loc, ip) + sc_ir = _ir(sc_val, loc, ip) + enable_d = _arith.trunci(i1_ty, sc_ir, loc=loc, ip=ip) + + _nvvm.tcgen05_mma_ws( + mma_kind=_nvvm.Tcgen05MMAKind.F16, + d=d_ptr, + a=a_ptr, + b=db_ir, + idesc=dv_ir, + enable_input_d=enable_d, + collector_b_buffer=collector_b_buffer, + collector_op=collector_op, + loc=loc, + ip=ip, + ) + + _do( + cutlass.Int32(tmem_c), + cutlass.Int32(tmem_a), + desc_b.desc_i64[0], + cutlass.Int32(desc_val), + cutlass.Int32(scale_out), + ) + + +# =========================================================================== +# Named convenience wrappers +# =========================================================================== +# These call the low-level primitives with pre-set mask constants so callers +# do not need to repeat the literal values. Signature: same as the base +# function but without the mask0-3 args. + +# --------------------------------------------------------------------------- +# SS named wrappers (SMEM A) +# --------------------------------------------------------------------------- + + +@cute.jit +def tcgen05mma_ss_no_mask( + desc_a: Tcgen05SmemDescriptor, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """SS MMA with no output-lane disable (all rows active).""" + tcgen05mma_ss(desc_a, desc_b, tmem_c, desc_val, scale_out, SS_NO_MASK[0], SS_NO_MASK[1], SS_NO_MASK[2], SS_NO_MASK[3]) + + +@cute.jit +def tcgen05mma_ss_mask0( + desc_a: Tcgen05SmemDescriptor, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """SS MMA: mask={0, 0xF…, 0, 0xF…} — groups 0,2 active (1,3 disabled).""" + tcgen05mma_ss(desc_a, desc_b, tmem_c, desc_val, scale_out, SS_MASK0[0], SS_MASK0[1], SS_MASK0[2], SS_MASK0[3]) + + +@cute.jit +def tcgen05mma_ss_mask1( + desc_a: Tcgen05SmemDescriptor, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """SS MMA: mask={0xF…, 0, 0xF…, 0} — groups 1,3 active (0,2 disabled).""" + tcgen05mma_ss(desc_a, desc_b, tmem_c, desc_val, scale_out, SS_MASK1[0], SS_MASK1[1], SS_MASK1[2], SS_MASK1[3]) + + +@cute.jit +def tcgen05mma_ss_mask2( + desc_a: Tcgen05SmemDescriptor, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """SS MMA: mask={0xF…, 0xF…, 0, 0xF…} — group 2 only active.""" + tcgen05mma_ss(desc_a, desc_b, tmem_c, desc_val, scale_out, SS_MASK2[0], SS_MASK2[1], SS_MASK2[2], SS_MASK2[3]) + + +@cute.jit +def tcgen05mma_ss_mask3( + desc_a: Tcgen05SmemDescriptor, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """SS MMA: mask={0xF…, 0xF…, 0xF…, 0} — group 3 only active.""" + tcgen05mma_ss(desc_a, desc_b, tmem_c, desc_val, scale_out, SS_MASK3[0], SS_MASK3[1], SS_MASK3[2], SS_MASK3[3]) + + +# --------------------------------------------------------------------------- +# TS named wrappers (TMEM A) +# --------------------------------------------------------------------------- + + +@cute.jit +def tcgen05mma_ts_no_mask( + tmem_a: int, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """TS MMA with no output-lane disable (all rows active).""" + tcgen05mma_ts(tmem_a, desc_b, tmem_c, desc_val, scale_out, TS_NO_MASK[0], TS_NO_MASK[1], TS_NO_MASK[2], TS_NO_MASK[3]) + + +@cute.jit +def tcgen05mma_ts_mask0( + tmem_a: int, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """TS MMA: mask={0, 0xF…, 0xF…, 0xF…} — group 0 only active.""" + tcgen05mma_ts(tmem_a, desc_b, tmem_c, desc_val, scale_out, TS_MASK0[0], TS_MASK0[1], TS_MASK0[2], TS_MASK0[3]) + + +@cute.jit +def tcgen05mma_ts_mask1( + tmem_a: int, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """TS MMA: mask={0xF…, 0, 0xF…, 0xF…} — group 1 only active.""" + tcgen05mma_ts(tmem_a, desc_b, tmem_c, desc_val, scale_out, TS_MASK1[0], TS_MASK1[1], TS_MASK1[2], TS_MASK1[3]) + + +@cute.jit +def tcgen05mma_ts_mask2( + tmem_a: int, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """TS MMA: mask={0xF…, 0xF…, 0, 0xF…} — group 2 only active.""" + tcgen05mma_ts(tmem_a, desc_b, tmem_c, desc_val, scale_out, TS_MASK2[0], TS_MASK2[1], TS_MASK2[2], TS_MASK2[3]) + + +@cute.jit +def tcgen05mma_ts_mask3( + tmem_a: int, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """TS MMA: mask={0xF…, 0xF…, 0xF…, 0} — group 3 only active.""" + tcgen05mma_ts(tmem_a, desc_b, tmem_c, desc_val, scale_out, TS_MASK3[0], TS_MASK3[1], TS_MASK3[2], TS_MASK3[3]) + + +@cute.jit +def tcgen05mma_ts_mask02( + tmem_a: int, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """TS MMA: mask={0, 0xF…, 0, 0xF…} — groups 0,2 active (1,3 disabled). + + Used in the KDA intra-chunk backward kernel for the QK/KG phase where + only even row-groups of the M tile contribute to the triangular region. + """ + tcgen05mma_ts(tmem_a, desc_b, tmem_c, desc_val, scale_out, TS_MASK02[0], TS_MASK02[1], TS_MASK02[2], TS_MASK02[3]) + + +@cute.jit +def tcgen05mma_ts_mask13( + tmem_a: int, + desc_b: Tcgen05SmemDescriptor, + tmem_c: int, + desc_val: int, + scale_out: int, +): + """TS MMA: mask={0xF…, 0, 0xF…, 0} — groups 1,3 active (0,2 disabled). + + Used in the KDA intra-chunk backward kernel for the QK/KG phase where + only odd row-groups of the M tile contribute to the triangular region. + """ + tcgen05mma_ts(tmem_a, desc_b, tmem_c, desc_val, scale_out, TS_MASK13[0], TS_MASK13[1], TS_MASK13[2], TS_MASK13[3]) From 59d02eeb0bb9e0ac6066f45489193911ab7b34a8 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Tue, 5 May 2026 10:22:38 +0800 Subject: [PATCH 02/26] integrate and pass test --- benchmarks/bench_kda_fwd_bwd_e2e.py | 5 +- cula/kda/chunk_bwd.py | 5 +- cula/ops/chunk_wy_dqkg_sm100.py | 100 ++-- tests/test_ptx_umma_masked.py | 323 +++++++++++ tests/test_ptx_umma_ws.py | 848 ++++++++++++++++++++++++++++ 5 files changed, 1231 insertions(+), 50 deletions(-) create mode 100644 tests/test_ptx_umma_masked.py create mode 100644 tests/test_ptx_umma_ws.py diff --git a/benchmarks/bench_kda_fwd_bwd_e2e.py b/benchmarks/bench_kda_fwd_bwd_e2e.py index c6b4117..2275bce 100644 --- a/benchmarks/bench_kda_fwd_bwd_e2e.py +++ b/benchmarks/bench_kda_fwd_bwd_e2e.py @@ -221,8 +221,11 @@ def check_determinism(num_seqs=5, T=512, iters=20): ref = run_kda_e2e_with_grads(**common, fn=cula_chunk_kda) for i in range(iters): out = run_kda_e2e_with_grads(**common, fn=cula_chunk_kda) - for name in ("o", "ht", "dq", "dk", "dv", "dg", "dbeta", "dh0"): + for name in ("o", "ht", "dq", "dk", "dv", "dg", "dh0"): assert torch.equal(out[name], ref[name]), f"[determinism] cuLA {name} mismatch at iter {i}" + for name in ("dbeta",): + # NOTE: for db, kernel uses atomic add which can cause non-determinism, so we use a looser check here + torch.testing.assert_close(out[name], ref[name], rtol=1e-5, atol=1e-5), f"db mismatch at iter {i}" return True diff --git a/cula/kda/chunk_bwd.py b/cula/kda/chunk_bwd.py index 859b6be..83f7924 100644 --- a/cula/kda/chunk_bwd.py +++ b/cula/kda/chunk_bwd.py @@ -37,6 +37,7 @@ import cula.cudac as cula_cuda from cula.kda.chunk_intra import chunk_kda_bwd_intra +from cula.ops.chunk_wy_dqkg_sm100 import chunk_kda_bwd_wy_dqkg_fused as chunk_kda_bwd_wy_dqkg_fused_cutedsl from cula.utils import prepare_uniform_cu_seqlens _delta_h_mod = importlib.import_module("cula.ops.chunk_delta_h") @@ -554,7 +555,7 @@ def chunk_kda_bwd( transpose_state_layout=transpose_state_layout, ) - dq, dk, dv, db, dg, dAkk = chunk_kda_bwd_wy_dqkg_fused( + dq, dk, dv, db, dg, dAkk = chunk_kda_bwd_wy_dqkg_fused_cutedsl( q=q, k=k, v=v, @@ -570,7 +571,7 @@ def chunk_kda_bwd( cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, - transpose_state_layout=transpose_state_layout, + # transpose_state_layout=transpose_state_layout, ) dq, dk, db, dg = chunk_kda_bwd_intra( diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index f24c185..98775a6 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -27,7 +27,7 @@ from cutlass.cute.typing import Float32, Int32, Int64, BFloat16 from fla.ops.utils import prepare_chunk_indices -from cula.utils import USE_FAST_MATH, assert_blackwell +from cula.utils import USE_FAST_MATH, assert_blackwell, prepare_uniform_cu_seqlens from cula.ops.intrinsics_sm100 import ( tcgen05_fence_before, @@ -415,7 +415,7 @@ def __init__( g_dtype: type[cutlass.Numeric] = cutlass.Float32, beta_dtype: type[cutlass.Numeric] = cutlass.Float32, scale: float = 1.0, - min_occupancy: int = 1, # FIXME: change to 2, bug exists for accuracy + min_occupancy: int = 1, use_fast_math: bool = True, ): assert chunk_size == 64, "chunk_size must be 64" @@ -1779,7 +1779,7 @@ def kernel( self.cuda_wg_sync_barrier.arrive_and_wait() pipeline_load_beta.consumer_wait(load_beta_consumer_state) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") beta_val = sBeta[(row,)] db_val = Float32(0.0) @@ -1963,7 +1963,7 @@ def kernel( chunk = cute.local_tile(dw_bf16_rmem, (8,), (i,)) smem_store_bf16x8_sw128(sDw_raw_ptr, row, col_base, chunk) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") pipeline_prologue_dw.producer_commit(prologue_dw_producer_state) prologue_dw_producer_state.advance() @@ -2000,7 +2000,7 @@ def kernel( chunk_kg = cute.local_tile(rKG_bf16, (8,), (i,)) smem_store_bf16x8_sw128(sK_raw_ptr, row, col_base, chunk_kg) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") pipeline_prologue_kg.producer_commit(prologue_kg_producer_state) prologue_kg_producer_state.advance() @@ -2072,7 +2072,7 @@ def kernel( if wg_idx == 0: sum = Float32(0.0) - for r in cutlass.range_constexpr(self.BT): + for r in cutlass.range(self.BT, unroll_full=True): sum += sG_raw[(r, local_tidx, 0)] sDgk[(local_tidx, )] += sum @@ -2190,7 +2190,7 @@ def kernel( chunk_dA = cute.local_tile(rDA, (8,), (i,)) smem_store_bf16x8_sw128(sQ_raw_ptr, row, col_base, chunk_dA) # notify dA2 = dA @ A - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") pipeline_prologue_dA2.producer_commit(prologue_dA2_producer_state) prologue_dA2_producer_state.advance() @@ -2221,7 +2221,7 @@ def kernel( chunk_dA2 = cute.local_tile(rDA2, (8,), (i,)) smem_store_bf16x8_sw128(sQ_raw_ptr, row, col_base, chunk_dA2) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") pipeline_prologue_dA3.producer_commit(prologue_dA3_producer_state) prologue_dA3_producer_state.advance() @@ -2586,10 +2586,7 @@ def kernel( # address stores to avoid layout-coordinate ambiguity. smem_store_bf16x8_sw128(sA_raw, row, col * 8, zeros8) # Make generic-proxy SMEM stores visible to UMMA async-proxy readers. - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, - ) + cute.arch.fence_proxy("async.shared", space="cta") self.mma_warp_sync_barrier.arrive_and_wait() for v_iter in cutlass.range(self.num_v_tiles): @@ -2608,10 +2605,7 @@ def kernel( for col in cutlass.range_constexpr(self.BV // 8): # dv tile uses the same Swizzle<3,4,3> physical mapping. smem_store_bf16x8_sw128(sDo_raw_ptr, row, col * 8, zeros8) - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, - ) + cute.arch.fence_proxy("async.shared", space="cta") self.mma_warp_sync_barrier.arrive_and_wait() if v_iter == 0: @@ -2648,10 +2642,7 @@ def kernel( for col in cutlass.range_constexpr(self.BV // 8): # dv tile uses the same Swizzle<3,4,3> physical mapping. smem_store_bf16x8_sw128(sDv_raw, row, col * 8, zeros8) - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, - ) + cute.arch.fence_proxy("async.shared", space="cta") self.mma_warp_sync_barrier.arrive_and_wait() # if lane_idx == 0: @@ -2692,10 +2683,7 @@ def kernel( for col in cutlass.range_constexpr(self.BV // 8): # dv tile uses the same Swizzle<3,4,3> physical mapping. smem_store_bf16x8_sw128(sV_raw, row, col * 8, zeros8) - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, - ) + cute.arch.fence_proxy("async.shared", space="cta") self.mma_warp_sync_barrier.arrive_and_wait() if v_iter == 0: pipeline_mma_dA.producer_acquire(mma_dA_producer_state) @@ -2735,10 +2723,7 @@ def kernel( for col in cutlass.range_constexpr(self.BV // 8): # dv tile uses the same Swizzle<3,4,3> physical mapping. smem_store_bf16x8_sw128(sDvnew_raw_ptr, row, col * 8, zeros8) - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, - ) + cute.arch.fence_proxy("async.shared", space="cta") self.mma_warp_sync_barrier.arrive_and_wait() pipeline_load_dh.consumer_wait(load_dh_consumer_state) @@ -2772,7 +2757,7 @@ def kernel( vloop_stage_idx = (vloop_stage_idx + 1) % self.vloop_stage pipeline_prologue_dw.consumer_wait(prologue_dw_consumer_state) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") # dkgb = A @ dw pipeline_mma_dkgb.producer_acquire(mma_dgkb_producer_state) sA_mn_cur = sA_mn[(None, None, None, a_stage_idx)] @@ -2787,7 +2772,7 @@ def kernel( mma_dgkb_producer_state.advance() pipeline_prologue_kg.consumer_wait(prologue_kg_consumer_state) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") # dA += dw @ kg^T sDw_k_cur = sDw_k[(None, None, None, 0)] sKG_k_cur = sKG_k[(None, None, None, 0)] @@ -2808,7 +2793,7 @@ def kernel( # dA2 = dA @ A pipeline_mma_dA2.producer_acquire(mma_dA2_producer_state) pipeline_prologue_dA2.consumer_wait(prologue_dA2_consumer_state) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") sDA_k_cur = sDA_k[(None, None, None, 0)] sA_k_cur = sA_k[(None, None, None, 0)] @@ -2826,7 +2811,7 @@ def kernel( # dA3 = A @ dA2 pipeline_mma_dA3.producer_acquire(mma_dA3_producer_state) pipeline_prologue_dA3.consumer_wait(prologue_dA3_consumer_state) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") sA_mn_cur = sA_mn[(None, None, None, 0)] sDA_mn_cur = sDA_mn[(None, None, None, 0)] @@ -2873,10 +2858,7 @@ def kernel( beta_f32 = Float32(beta_gmem[(tok_offset + tile_idx * self.BT + tidx, (head_idx, Int32(0)))]) sBeta[(tidx, )] = beta_f32 - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, - ) + cute.arch.fence_proxy("async.shared", space="cta") pipeline_load_beta.producer_commit(load_beta_producer_state) load_beta_producer_state.advance() @@ -3101,13 +3083,16 @@ def chunk_kda_bwd_wy_dqkg_fused( Returns: (dq, dk, dv2, db, dg, dA) matching FLA's chunk_kda_bwd_wy_dqkg_fused output order. """ - if chunk_indices is None and cu_seqlens is not None: - chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) - B, T, H, K = q.shape V = v.shape[3] BT = chunk_size beta_dtype = beta.dtype + device = q.device + + if cu_seqlens is None: + cu_seqlens = prepare_uniform_cu_seqlens(B, T, device, torch.int32) + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if scale is None: scale = K**-0.5 @@ -3115,19 +3100,18 @@ def chunk_kda_bwd_wy_dqkg_fused( assert cu_seqlens is not None and chunk_indices is not None # Ensure cu_seqlens is int32 assert cu_seqlens.dtype == torch.int32, "cu_seqlens must be int32" - assert q.dim() == 4 and q.shape[0] == 1 - T_total = q.shape[1] + T_total = B * T num_seqs = cu_seqlens.shape[0] - 1 total_nt_val = chunk_indices.shape[0] ps = (Int32(num_seqs), Int32(T_total), Int32(H), Int32(K), Int32(V)) # Allocate output tensors - dq = torch.empty_like(q, dtype=torch.float32) - dk = torch.empty_like(k, dtype=torch.float32) - dv2 = torch.empty_like(v) # bf16 - dg = torch.empty_like(g, dtype=torch.float32) - db = torch.empty(B, T, H, dtype=torch.float32, device=q.device) - dA = torch.empty(B, T, H, BT, dtype=torch.float32, device=q.device) + dq = torch.empty(1, T_total, H, K, dtype=torch.float32, device=device) + dk = torch.empty(1, T_total, H, K, dtype=torch.float32, device=device) + dv2 = torch.empty(1, T_total, H, V, dtype=torch.bfloat16, device=device) + dg = torch.empty(1, T_total, H, K, dtype=torch.float32, device=device) + db = torch.empty(1, T_total, H, dtype=torch.float32, device=device) + dA = torch.empty(1, T_total, H, BT, dtype=torch.float32, device=device) compiled_fn = _get_compiled_bwd_wy( H, @@ -3138,6 +3122,19 @@ def chunk_kda_bwd_wy_dqkg_fused( beta_dtype, ) + if B != 1: + q = q.reshape(1, T_total, H, K) + k = k.reshape(1, T_total, H, K) + v = v.reshape(1, T_total, H, V) + v_new = v_new.reshape(1, T_total, H, V) + g = g.reshape(1, T_total, H, K) + beta = beta.reshape(1, T_total, H) + A = A.reshape(1, T_total, H, BT) + h = h.reshape(1, total_nt_val, H, K, V) + do = do.reshape(1, T_total, H, V) + dh = dh.reshape(1, total_nt_val, H, K, V) + dv = dv.reshape(1, T_total, H, V) + # TVM-FFI call compiled_fn( # Inputs @@ -3166,6 +3163,15 @@ def chunk_kda_bwd_wy_dqkg_fused( Int32(total_nt_val), ) + # rearrange back + if B != 1: + dq = dq.reshape(B, T, H, K) + dk = dk.reshape(B, T, H, K) + dv2 = dv2.reshape(B, T, H, V) + dg = dg.reshape(B, T, H, K) + db = db.reshape(B, T, H) + dA = dA.reshape(B, T, H, BT) + return dq, dk, dv2, db, dg, dA diff --git a/tests/test_ptx_umma_masked.py b/tests/test_ptx_umma_masked.py new file mode 100644 index 0000000..356575e --- /dev/null +++ b/tests/test_ptx_umma_masked.py @@ -0,0 +1,323 @@ +""" +Standalone CuteDSL test for ptx_umma_masked.py inline PTX MMA wrappers. + +Tests: + 1. tcgen05mma_ss_no_mask -- M=64, N=64, K=8, TF32, all rows active → matches torch.mm + 2. tcgen05mma_ss_mask0 -- groups 0,2 active (rows 0-15, 32-47), groups 1,3 disabled + 3. tcgen05mma_ss_mask1 -- groups 1,3 active (rows 16-31, 48-63), groups 0,2 disabled + +SMEM layout: + A: swizzled K-major (Swizzle<1,4,3>, SWIZZLE_32B), descriptor LBO=1 SBO=16 layout=6 + B: swizzled MN-major (Swizzle<2,5,2>, SWIZZLE_128B_BASE32B), LBO=64 SBO=32 layout=1 + Data is loaded with M-major mapping for A, direct row-major for B. + +All descriptor values are computed via make_umma_smem_desc / smem_descriptor_to_int +(proven correct in test_umma_ptx_jit.py). Wrapped in Tcgen05SmemDescriptor for API +compatibility with ptx_umma_masked.py convenience wrappers. +""" + +import pathlib +import sys + +sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import torch +from cutlass.cute.arch import ( + elect_one, + mbarrier_init, + mbarrier_init_fence, + mbarrier_wait, + sync_threads, +) +from cutlass.cute.nvgpu import tcgen05 +from cutlass.cute.nvgpu.tcgen05 import ( + Pack, + Repetition, + make_umma_smem_desc, + smem_descriptor_to_int, +) +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.typing import Float32, Int32, Int64, TFloat32 + +from cula.ops.ptx_umma_ext import ( + Tcgen05SmemDescriptor, + tcgen05mma_ss_mask0, + tcgen05mma_ss_mask1, + tcgen05mma_ss_no_mask, +) + +M_DIM, N_DIM, K_DIM = 64, 64, 8 +TMEM_COLS = 64 + +IDESC_M64_N64 = (4 << 24) | (8 << 17) | (1 << 16) | (2 << 10) | (2 << 7) | (1 << 4) +assert IDESC_M64_N64 == 0x4110910 + + +class _Kernel: + def __init__(self, mask_mode: str = "none"): + self.mask_mode = mask_mode + + @cute.kernel + def kernel(self, A_in: cute.Tensor, B_in: cute.Tensor, C_out: cute.Tensor): + """ + For mask_mode == "none": single-phase MMA, all rows written. + For mask_mode == "mask0"/"mask1": two-phase MMA: + Phase 1 - full no-mask MMA with A_in used as A_zero (passed by caller as zeros), + scale_out=0 → zeroes TMEM for all rows. + Phase 2 - masked MMA using B column from B_in (same B, new A from A_in second half). + + To keep the interface simple, for mask tests A_in is [A_zero (64×8) || A_real (64×8)] + concatenated to shape (128, 8). Phase 1 loads rows [0:64], phase 2 loads rows [64:128]. + """ + M, N, K = M_DIM, N_DIM, K_DIM + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + smem = utils.SmemAllocator() + tmem_hold_ptr = smem.allocate(Int32) + mbar_ptr = smem.allocate(Int64, byte_alignment=8) + + # Build tiled_mma to get correct SMEM layout + tiled_mma = sm100_utils.make_trivial_tiled_mma( + TFloat32, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + Float32, + tcgen05.CtaGroup.ONE, + (M, N), + ) + mma_tiler = (M, N, K) + + # Allocate swizzled SMEM + a_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, mma_tiler, TFloat32, 1) + b_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler, TFloat32, 1) + bufferA = smem.allocate_tensor( + element_type=TFloat32, + layout=a_smem_layout.outer, + byte_alignment=128, + swizzle=a_smem_layout.inner, + ) + bufferB = smem.allocate_tensor( + element_type=TFloat32, + layout=b_smem_layout.outer, + byte_alignment=128, + swizzle=b_smem_layout.inner, + ) + bufA_s0 = bufferA[(None, None, None, 0)] + bufB_s0 = bufferB[(None, None, None, 0)] + + if tidx == cutlass.Int32(0): + mbarrier_init(mbar_ptr, 1) + mbarrier_init_fence() + + # gA_all: flat view of input A tensor (either M*K or 2M*K elements) + # For no_mask: caller passes (M,K) → total = M*K + # For mask tests: caller passes (2M,K) → total = 2M*K; row_offset selects which half + if cutlass.const_expr(self.mask_mode != "none"): + gA_all = cute.make_tensor(A_in.iterator, cute.make_layout(2 * M_DIM * K_DIM)) + else: + gA_all = cute.make_tensor(A_in.iterator, cute.make_layout(M_DIM * K_DIM)) + + # Load B once (shared by both phases) + gB_flat = cute.make_tensor(B_in.iterator, cute.make_layout(K * N)) + for step in cutlass.range(K * N // 128, unroll_full=False): + idx = tidx + step * 128 + bufB_s0[idx] = gB_flat[idx] + sync_threads() + + # TMEM allocation + alloc_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=128) + tmem = utils.TmemAllocator( + tmem_hold_ptr, + barrier_for_retrieve=alloc_bar, + allocator_warp_id=0, + ) + tmem.allocate(TMEM_COLS) + tmem.wait_for_alloc() + tmem_ptr_f32 = tmem.retrieve_ptr(Float32) + + acc_shape = tiled_mma.partition_shape_C((M, N)) + acc_shape_staged = cute.append(acc_shape, 1) + tCtAcc = cute.make_tensor(tmem_ptr_f32, tiled_mma.make_fragment_C(acc_shape_staged).layout) + tmem_col_buf = cute.make_tensor(tmem_hold_ptr, cute.make_layout(1)) + tmem_col = tmem_col_buf[0] + + # Build descriptors + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(bufA_s0.iterator, bufA_s0.layout, "k")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(bufB_s0.iterator, bufB_s0.layout, "mn")) + desc_a = Tcgen05SmemDescriptor(desc_a_i64) + desc_b = Tcgen05SmemDescriptor(desc_b_i64) + + if cutlass.const_expr(self.mask_mode != "none"): + # Phase 1: Load A_zero (first M rows of gA_all = all zeros), no_mask MMA → zero TMEM + for step in cutlass.range(M_DIM * K_DIM // 128, unroll_full=False): + smem_idx = tidx + step * 128 + m = smem_idx % M_DIM + k = smem_idx // M_DIM + bufA_s0[smem_idx] = gA_all[m * K_DIM + k] # row_offset=0 + sync_threads() + if warp_idx == cutlass.Int32(0): + tcgen05mma_ss_no_mask(desc_a, desc_b, tmem_col, IDESC_M64_N64, 0) + with elect_one(): + tcgen05.commit(mbar_ptr, cta_group=tcgen05.CtaGroup.ONE) + mbarrier_wait(mbar_ptr, 0) + sync_threads() + # Re-arm mbar for second MMA + if tidx == cutlass.Int32(0): + mbarrier_init(mbar_ptr, 1) + mbarrier_init_fence() + + # Phase 2: Load A_real (rows M..2M of gA_all = real data), masked MMA + for step in cutlass.range(M_DIM * K_DIM // 128, unroll_full=False): + smem_idx = tidx + step * 128 + m = smem_idx % M_DIM + k = smem_idx // M_DIM + bufA_s0[smem_idx] = gA_all[(M_DIM + m) * K_DIM + k] # row_offset=M + sync_threads() + if warp_idx == cutlass.Int32(0): + if cutlass.const_expr(self.mask_mode == "mask0"): + tcgen05mma_ss_mask0(desc_a, desc_b, tmem_col, IDESC_M64_N64, 0) + elif cutlass.const_expr(self.mask_mode == "mask1"): + tcgen05mma_ss_mask1(desc_a, desc_b, tmem_col, IDESC_M64_N64, 0) + with elect_one(): + tcgen05.commit(mbar_ptr, cta_group=tcgen05.CtaGroup.ONE) + mbarrier_wait(mbar_ptr, 0) + sync_threads() + else: + # Simple single-phase no_mask + for step in cutlass.range(M_DIM * K_DIM // 128, unroll_full=False): + smem_idx = tidx + step * 128 + m = smem_idx % M_DIM + k = smem_idx // M_DIM + bufA_s0[smem_idx] = gA_all[m * K_DIM + k] + sync_threads() + if warp_idx == cutlass.Int32(0): + tcgen05mma_ss_no_mask(desc_a, desc_b, tmem_col, IDESC_M64_N64, 0) + with elect_one(): + tcgen05.commit(mbar_ptr, cta_group=tcgen05.CtaGroup.ONE) + mbarrier_wait(mbar_ptr, 0) + sync_threads() + + # T2R: TMEM → RMEM + t2r_atom = cute.make_copy_atom(tcgen05.Ld16x256bOp(Repetition(8), Pack.NONE), Float32) + fake_smem = cute.make_tensor(cute.make_ptr(Float32, 0, cute.AddressSpace.smem), cute.make_layout((M, N))) + tCtAcc_flat = tCtAcc[((None, None), 0, 0, None)] + tiled_t2r = tcgen05.make_tmem_copy(t2r_atom, tCtAcc_flat[(None, None, 0)]) + thr_t2r = tiled_t2r.get_slice(tidx) + tTR_tAcc = thr_t2r.partition_S(tCtAcc_flat) + tTR_sDummy = thr_t2r.partition_D(fake_smem) + tTR_rAcc = cute.make_rmem_tensor(tTR_sDummy.shape, Float32) + + cute.copy(tiled_t2r, tTR_tAcc[(None, None, None, 0)], tTR_rAcc) + cute.arch.fence_view_async_tmem_load() + + # R2G: RMEM → GMEM (row-major) + gC = cute.make_tensor(C_out.iterator, cute.make_layout((M, N), stride=(N, 1))) + tTR_gC = thr_t2r.partition_D(gC) + cute.copy(cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), tTR_rAcc, tTR_gC) + + sync_threads() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr_f32, TMEM_COLS) + + @cute.jit + def _launch(self, A: cute.Tensor, B: cute.Tensor, C: cute.Tensor, stream): + self.kernel(A, B, C).launch(grid=(1, 1, 1), block=(128, 1, 1), stream=stream) + + def run(self, A_cpu, B_cpu): + """ + For no_mask: A_cpu is (M, K). + For mask tests: A_cpu is (2M, K) where [0:M] = zeros, [M:2M] = real A. + """ + A_gpu = A_cpu.contiguous().float().cuda() + B_gpu = B_cpu.contiguous().float().cuda() + C_gpu = torch.zeros(M_DIM, N_DIM, dtype=torch.float32, device="cuda") + stream = cutlass_torch.default_stream() + self._launch(from_dlpack(A_gpu), from_dlpack(B_gpu), from_dlpack(C_gpu), stream) + torch.cuda.synchronize() + return C_gpu.cpu() + + +def test_ss_no_mask(): + print("\n=== Test 1: tcgen05mma_ss_no_mask (all rows active) ===") + torch.manual_seed(42) + A = torch.randn(M_DIM, K_DIM) + B = torch.randn(K_DIM, N_DIM) + ref = torch.mm(A, B) + got = _Kernel("none").run(A, B) + rel = (got - ref).abs().max().item() / (ref.abs().max().item() + 1e-8) + print(f" got[0,:4]={got[0, :4].tolist()}") + print(f" ref[0,:4]={ref[0, :4].tolist()}") + print(f" max_rel_err={rel:.4f}") + assert rel < 0.02, f"FAIL: rel={rel:.4f}" + print(" PASSED") + + +def _run_masked(mask_mode, A_real, B): + """ + Two-phase: phase1 = no_mask with A_zero (rows 0..M-1 of combined A = zeros), + phase2 = masked with A_real (rows M..2M-1 of combined A). + Combined A is shape (2M, K): [zeros || A_real]. + """ + A_combined = torch.cat([torch.zeros_like(A_real), A_real], dim=0) + return _Kernel(mask_mode).run(A_combined, B) + + +def test_ss_mask0(): + print("\n=== Test 2: tcgen05mma_ss_mask0 ===") + torch.manual_seed(0) + A = torch.randn(M_DIM, K_DIM) + B = torch.randn(K_DIM, N_DIM) + ref = torch.mm(A, B) # expected for active rows + got = _run_masked("mask0", A, B) + + # SS_MASK0 = (0, 0xFF..., 0, 0xFF...) → mask words 1,3 disable rows 16-31 and 48-63 + # Active: rows 0-15 and 32-47, Disabled: rows 16-31 and 48-63 + active_rows = list(range(0, 16)) + list(range(32, 48)) + masked_rows = list(range(16, 32)) + list(range(48, 64)) + + rel_active = (got[active_rows] - ref[active_rows]).abs().max().item() / (ref[active_rows].abs().max().item() + 1e-8) + zero_max = got[masked_rows].abs().max().item() + + print(f" active rows 0-15,32-47: max_rel_err={rel_active:.4f} (expect <0.02)") + print(f" masked rows 16-31,48-63: max_abs={zero_max:.4f} (expect 0.0)") + assert rel_active < 0.02, f"FAIL active rows: {rel_active:.4f}" + assert zero_max == 0.0, f"FAIL masked rows not zero: {zero_max}" + print(" PASSED") + + +def test_ss_mask1(): + print("\n=== Test 3: tcgen05mma_ss_mask1 ===") + torch.manual_seed(7) + A = torch.randn(M_DIM, K_DIM) + B = torch.randn(K_DIM, N_DIM) + ref = torch.mm(A, B) + got = _run_masked("mask1", A, B) + + # SS_MASK1 = (0xFF..., 0, 0xFF..., 0) → mask words 0,2 disable rows 0-15 and 32-47 + # Active: rows 16-31 and 48-63, Disabled: rows 0-15 and 32-47 + active_rows = list(range(16, 32)) + list(range(48, 64)) + masked_rows = list(range(0, 16)) + list(range(32, 48)) + + rel_active = (got[active_rows] - ref[active_rows]).abs().max().item() / (ref[active_rows].abs().max().item() + 1e-8) + zero_max = got[masked_rows].abs().max().item() + + print(f" active rows 16-31,48-63: max_rel_err={rel_active:.4f} (expect <0.02)") + print(f" masked rows 0-15,32-47: max_abs={zero_max:.4f} (expect 0.0)") + assert rel_active < 0.02, f"FAIL active rows: {rel_active:.4f}" + assert zero_max == 0.0, f"FAIL masked rows not zero: {zero_max}" + print(" PASSED") + + +if __name__ == "__main__": + test_ss_no_mask() + test_ss_mask0() + test_ss_mask1() + print("\n=== All tests passed! ===") diff --git a/tests/test_ptx_umma_ws.py b/tests/test_ptx_umma_ws.py new file mode 100644 index 0000000..ea9acf0 --- /dev/null +++ b/tests/test_ptx_umma_ws.py @@ -0,0 +1,848 @@ +""" +Standalone CuteDSL test for tcgen05.mma.ws (weight-stationary) inline PTX wrappers. + +Tests: + 1. tcgen05mma_ws_ss_tf32 -- WS mode, SMEM A × SMEM B → TMEM C, kind::tf32 + 2. tcgen05mma_ws_ts_tf32 -- WS mode, TMEM A × SMEM B → TMEM C, kind::tf32 + 3. tcgen05mma_ws_ss_f16 -- WS mode, SMEM A × SMEM B → TMEM C, kind::f16 + 4. tcgen05mma_ws_ts_f16 -- WS mode, TMEM A × SMEM B → TMEM C, kind::f16 + +For the WS_TS test, matrix A is first loaded into TMEM via an SS MMA (identity- +like multiplication), then used as the A operand for the WS TS MMA. To keep +things simple we use a two-TMEM-column approach: + - tmem region 0: accumulator for both phases + - tmem region 1: holds A data for TS phase (populated via R2T store) + +SMEM layout follows the same conventions as test_ptx_umma_masked.py. +""" + +import pathlib +import sys + +sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import torch +from cutlass.cute.arch import ( + elect_one, + mbarrier_init, + mbarrier_init_fence, + mbarrier_wait, + sync_threads, +) +from cutlass.cute.nvgpu import tcgen05 +from cutlass.cute.nvgpu.tcgen05 import ( + make_umma_smem_desc, + smem_descriptor_to_int, +) +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.tensor import TensorSSA +from cutlass.cute.typing import Float16, Float32, Int32, Int64, TFloat32, BFloat16 + +from cula.ops.intrinsics_sm100 import ( + store_256b, + subvec, + tcgen05_cp_128x256b, + reinterpret_cast, + tcgen05_ld_32x32b, + tcgen05_st_32x32b, +) +from cula.ops.ptx_umma_ext import ( + CollectorBBuffer, + CollectorOp, + Tcgen05SmemDescriptor, + tcgen05mma_ws_ss_f16, + tcgen05mma_ws_ss_tf32, + tcgen05mma_ws_ts_f16, +) + +M_DIM, N_DIM = 64, 64 +# TODO: support arbitrary K +K_DIM_TF32 = 8 # kind::tf32 → K>=8, tile size +A_K_STEP_BYTES_TF32 = M_DIM * 8 * 4 # smem offset for each K-atom in operand A +B_K_STEP_BYTES_TF32 = N_DIM * 8 * 4 # smem offset for each K-atom in operand B +K_DIM_F16 = 128 # default after sweep +# NOTE: per-K-atom byte offsets are derived from the SMEM layout at runtime +# (see _WsSsF16Kernel) so K_DIM_F16 can be any multiple of 16. The layout's +# k_iter mode becomes hierarchical at K≥128 (e.g. (4, K/64):(16, 4096) for A), +# which the layout-based offset computation handles transparently. + +# Instruction descriptor for M=64, N=64, TF32, dense, TransposeB=1 +# Bits: M>>4=4 at [24:28], N>>3=8 at [17:22], TransposeB at [16], +# btype=tf32(2) at [10:12], atype=tf32(2) at [7:9], dtype=f32(1) at [4:5] +IDESC_TF32_M64_N64 = (4 << 24) | (8 << 17) | (1 << 16) | (2 << 10) | (2 << 7) | (1 << 4) +assert IDESC_TF32_M64_N64 == 0x4110910 + +# Instruction descriptor for M=64, N=64, BF16, dense, TransposeB=1 +# Bits: M>>4=4 at [24:28], N>>3=8 at [17:22], TransposeB at [16], +# btype=bf16(1) at [10:12], atype=bf16(1) at [7:9], dtype=f32(1) at [4:5] +IDESC_F16_M64_N64 = (4 << 24) | (8 << 17) | (1 << 16) | (1 << 10) | (1 << 7) | (1 << 4) +assert IDESC_F16_M64_N64 == 0x4110490 + +# Instruction descriptor for M=64, N=128, BF16, dense, TransposeB=1 +# Bits: M>>4=4 at [24:28], N>>3=16 at [17:22], TransposeB at [16], +# btype=bf16(1) at [10:12], atype=bf16(1) at [7:9], dtype=f32(1) at [4:5] +IDESC_F16_M64_N128 = (4 << 24) | (16 << 17) | (1 << 16) | (1 << 10) | (1 << 7) | (1 << 4) +assert IDESC_F16_M64_N128 == 0x4210490 + + +# ===================================================================== +# Test 1: tcgen05mma_ws_ss_tf32 (weight-stationary, SMEM A, SMEM B, tf32) +# ===================================================================== + + +class _WsSsTf32Kernel: + @cute.kernel + def kernel(self, A_in: cute.Tensor, B_in: cute.Tensor, C_out: cute.Tensor): + M, N, K = M_DIM, N_DIM, K_DIM_TF32 + ACC_NUM_COLS = N // 2 + NUM_COLS = ACC_NUM_COLS + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + smem = utils.SmemAllocator() + tmem_hold_ptr = smem.allocate(Int32) + mbar_ptr = smem.allocate(Int64, byte_alignment=8) + + # --- SMEM layouts via sm100_utils (handles swizzle correctly for TF32) --- + # NOTE: we use non-ws mode TiledMMA for creating smem layout in a easy way, + # because smem layouts of ws mode and non-ws mode are the same + non_ws_tiled_mma = sm100_utils.make_trivial_tiled_mma( + TFloat32, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + Float32, + tcgen05.CtaGroup.ONE, + (M, N), + ) + mma_tiler = (M, N, K) + a_smem_layout = sm100_utils.make_smem_layout_a(non_ws_tiled_mma, mma_tiler, TFloat32, 1) + b_smem_layout = sm100_utils.make_smem_layout_b(non_ws_tiled_mma, mma_tiler, TFloat32, 1) + bufferA = smem.allocate_tensor( + element_type=TFloat32, + layout=a_smem_layout.outer, + byte_alignment=128, + swizzle=a_smem_layout.inner, + ) + bufferB = smem.allocate_tensor( + element_type=TFloat32, + layout=b_smem_layout.outer, + byte_alignment=128, + swizzle=b_smem_layout.inner, + ) + bufA_s0 = bufferA[(None, None, None, 0)] + bufB_s0 = bufferB[(None, None, None, 0)] + + if tidx == cutlass.Int32(0): + mbarrier_init(mbar_ptr, 1) + mbarrier_init_fence() + + # Load A (row-major input → K-major swizzled SMEM) and B + gA_flat = cute.make_tensor(A_in.iterator, cute.make_layout(M * K)) + gB_flat = cute.make_tensor(B_in.iterator, cute.make_layout(K * N)) + + for step in cutlass.range(M * K // 128, unroll_full=False): + smem_idx = tidx + step * 128 + m = smem_idx % M + k = smem_idx // M + bufA_s0[smem_idx] = gA_flat[m * K + k] + for step in cutlass.range(K * N // 128, unroll_full=False): + idx = tidx + step * 128 + bufB_s0[idx] = gB_flat[idx] + sync_threads() + + # --- TMEM allocation --- + alloc_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=128) + tmem = utils.TmemAllocator( + tmem_hold_ptr, + barrier_for_retrieve=alloc_bar, + allocator_warp_id=0, + ) + tmem.allocate(NUM_COLS) + tmem.wait_for_alloc() + tmem_ptr_f32 = tmem.retrieve_ptr(Float32) + + tmem_col_buf = cute.make_tensor(tmem_hold_ptr, cute.make_layout(1)) + tmem_col = tmem_col_buf[0] + + # Build SMEM descriptors (rank-2 vec_mode layout required) + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(bufA_s0.iterator, bufA_s0.layout, "k")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(bufB_s0.iterator, bufB_s0.layout, "mn")) + desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + + # Issue WS SS MMA (scale_out=0 → D = A*B, not accumulate) + if warp_idx == cutlass.Int32(0): + with elect_one(): + for ks in cutlass.range_constexpr(K // 8): + scale = 0 if ks == 0 else 1 + desc_a = desc_a_base + (ks * A_K_STEP_BYTES_TF32) + desc_b = desc_b_base + (ks * B_K_STEP_BYTES_TF32) + tcgen05mma_ws_ss_tf32(desc_a, desc_b, tmem_col, IDESC_TF32_M64_N64, scale) + tcgen05.commit(mbar_ptr, cta_group=tcgen05.CtaGroup.ONE) + mbarrier_wait(mbar_ptr, 0) + sync_threads() + + # T2R → R2G: tcgen05_ld directly into store_256b (type-agnostic, like C++ reinterpret_cast) + vec_i32 = tcgen05_ld_32x32b(ACC_NUM_COLS, tmem_col) + cute.arch.fence_view_async_tmem_load() + + # 1. reinterpret_cast to f32 (zero-cost bitcast) + # vec_f32 = reinterpret_cast(vec_i32, Int32, ACC_NUM_COLS, Float32) + + # 2. TensorSSA wrap → .to(BFloat16) (real CUDA core CVT) + # regs = TensorSSA(vec_f32, (ACC_NUM_COLS,), Float32) + + # Debug print: thread 0, first 4 register values + # if tidx == cutlass.Int32(0): + # cute.printf("[T2R] tid=0, regs[0..3] = %f, %f, %f, %f", + # regs[0], regs[1], regs[2], regs[3]) + + # R2G via store_256b (4 × 256-bit stores per thread) + # Layout E (column-major warp order): + # warp0->(M0,N0), warp1->(M1,N0), warp2->(M0,N1), warp3->(M1,N1) + lane_idx = tidx % 32 + row = (warp_idx % 2) * 32 + lane_idx + col_base = (warp_idx // 2) * 32 + base_addr = (C_out.iterator + row * N + col_base).toint() + for chunk in cutlass.range_constexpr(ACC_NUM_COLS // 8): + store_256b(base_addr + chunk * 32, subvec(vec_i32, chunk * 8, 8)) + + sync_threads() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr_f32, NUM_COLS) + + @cute.jit + def _launch(self, A: cute.Tensor, B: cute.Tensor, C: cute.Tensor, stream): + self.kernel(A, B, C).launch(grid=(1, 1, 1), block=(128, 1, 1), stream=stream) + + def run(self, A_cpu, B_cpu): + assert K_DIM_TF32 == 8, "TODO: support larger K-dimension" + A_gpu = A_cpu.contiguous().float().cuda() + B_gpu = B_cpu.contiguous().float().cuda() + C_gpu = torch.zeros(M_DIM, N_DIM, dtype=torch.float32, device="cuda") + stream = cutlass_torch.default_stream() + self._launch(from_dlpack(A_gpu), from_dlpack(B_gpu), from_dlpack(C_gpu), stream) + torch.cuda.synchronize() + return C_gpu.cpu() + + +# ===================================================================== +# Test 2: tcgen05mma_ws_ts_tf32 (weight-stationary, TMEM A, SMEM B, tf32) +# ===================================================================== + +# TODO +class _WsTsTf32Kernel: + """Two-step test: + - Region 0 (tmem_a_region): populate A into TMEM via S2T copy + - Region 1 (tmem_c_region): result of WS TS MMA (tmem_a × B → C) + """ + + @cute.kernel + def kernel(self, A_in: cute.Tensor, B_in: cute.Tensor, C_out: cute.Tensor): + pass + + @cute.jit + def _launch(self, A: cute.Tensor, B: cute.Tensor, C: cute.Tensor, stream): + self.kernel(A, B, C).launch(grid=(1, 1, 1), block=(128, 1, 1), stream=stream) + + def run(self, A_cpu, B_cpu): + raise NotImplementedError( + "TODO: implement this test following the pattern of _WsSsTf32Kernel, but with the two-step approach described in the docstring" + ) + + +# ===================================================================== +# Test 3: tcgen05mma_ws_ss_f16 (weight-stationary, SMEM A, SMEM B, f16) +# ===================================================================== + + +class _WsSsF16Kernel: + def __init__(self, M: int, N: int, K: int): + self.M = M + self.N = N + self.K = K + if N == 64: + self.idesc = IDESC_F16_M64_N64 + elif N == 128: + self.idesc = IDESC_F16_M64_N128 + else: + raise ValueError(f"Unsupported N={N} for F16 IDESC (expected 64 or 128)") + + @cute.kernel + def kernel(self, A_in: cute.Tensor, B_in: cute.Tensor, C_out: cute.Tensor): + M, N, K = self.M, self.N, self.K + idesc = self.idesc + ACC_NUM_COLS = N // 2 + NUM_COLS = ACC_NUM_COLS + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + smem = utils.SmemAllocator() + tmem_hold_ptr = smem.allocate(Int32) + mbar_ptr = smem.allocate(Int64, byte_alignment=8) + + # Create MMA SMEM Layouts + # NOTE: we use non-ws mode TiledMMA for creating smem layout in a easy way, + # because smem layouts of ws mode and non-ws mode are the same + mma_tiler = (M, N, K) + non_ws_tiled_mma = sm100_utils.make_trivial_tiled_mma( + BFloat16, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + Float32, + tcgen05.CtaGroup.ONE, + (M, N), + ) + a_smem_layout = sm100_utils.make_smem_layout_a(non_ws_tiled_mma, mma_tiler, BFloat16, 1) + b_smem_layout = sm100_utils.make_smem_layout_b(non_ws_tiled_mma, mma_tiler, BFloat16, 1) + bufferA = smem.allocate_tensor( + element_type=BFloat16, + layout=a_smem_layout.outer, + byte_alignment=128, + swizzle=a_smem_layout.inner, + ) + + bufferB = smem.allocate_tensor( + element_type=BFloat16, + layout=b_smem_layout.outer, + byte_alignment=128, + swizzle=b_smem_layout.inner, + ) + + bufA_s0 = bufferA[(None, None, None, 0)] + bufB_s0 = bufferB[(None, None, None, 0)] + + if tidx == cutlass.Int32(0): + mbarrier_init(mbar_ptr, 1) + mbarrier_init_fence() + + # Load A (row-major input → K-major swizzled SMEM) + gA_flat = cute.make_tensor(A_in.iterator, cute.make_layout(M * K)) + gB_flat = cute.make_tensor(B_in.iterator, cute.make_layout(K * N)) + + for step in cutlass.range(M * K // 128, unroll_full=False): + smem_idx = tidx + step * 128 + m = smem_idx % M + k = smem_idx // M + bufA_s0[smem_idx] = gA_flat[m * K + k] + for step in cutlass.range(K * N // 128, unroll_full=False): + idx = tidx + step * 128 + bufB_s0[idx] = gB_flat[idx] + sync_threads() + + # --- TMEM allocation --- + alloc_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=128) + tmem = utils.TmemAllocator( + tmem_hold_ptr, + barrier_for_retrieve=alloc_bar, + allocator_warp_id=0, + ) + tmem.allocate(NUM_COLS) + tmem.wait_for_alloc() + tmem_ptr_f32 = tmem.retrieve_ptr(Float32) + + tmem_col_buf = cute.make_tensor(tmem_hold_ptr, cute.make_layout(1)) + tmem_col = tmem_col_buf[0] + + # Build SMEM descriptors (rank-2 vec_mode layout required) + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(bufA_s0.iterator, bufA_s0.layout, "k")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(bufB_s0.iterator, bufB_s0.layout, "mn")) + desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + + # Per-K-atom byte offsets are derived from the (unswizzled) outer layout + # so we transparently handle every K size: + # K∈{16,32,64} → A k_iter is single-mode, uniform stride + # K≥128 → A k_iter is hierarchical e.g. (4,K/64):(16,4096) + # B is always uniform stride=1024 elem + # Coord ((0,0), 0, ks, 0) into outer layout gives the linear elem offset + # of the ks-th MMA-K atom; * sizeof(elem) → byte offset to add to desc. + ELEM_BYTES_F16 = BFloat16.width // 8 + a_outer = a_smem_layout.outer + b_outer = b_smem_layout.outer + + # Issue WS SS MMA (scale_out=0 → D = A*B, not accumulate) + if warp_idx == cutlass.Int32(0): + with elect_one(): + for ks in cutlass.range_constexpr(K // 16): + scale = 0 if ks == 0 else 1 + a_off = cute.crd2idx(((0, 0), 0, ks, 0), a_outer) * ELEM_BYTES_F16 + b_off = cute.crd2idx(((0, 0), 0, ks, 0), b_outer) * ELEM_BYTES_F16 + desc_a = desc_a_base + a_off + desc_b = desc_b_base + b_off + tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_col, idesc, scale) + tcgen05.commit(mbar_ptr, cta_group=tcgen05.CtaGroup.ONE) + mbarrier_wait(mbar_ptr, 0) + sync_threads() + + # T2R + # Layout E (M=64, ws mode): 128 lanes, 32 columns + # ref: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e + # .32x32b.x32 loads all 32 columns → 32 FP32 regs per thread + # Layout: warp0->(M0,N0), warp1->(M0,N1), warp2->(M1,N0), warp3->(M1,N1) + # for 64x64 Acc, each warp process 32x32, with 128 lanes in TMEM all used + + vec_i32 = tcgen05_ld_32x32b(ACC_NUM_COLS, tmem_col) + cute.arch.fence_view_async_tmem_load() + + # =======DEBUG======== + # # 1. reinterpret_cast to f32 (zero-cost bitcast) + # vec_f32 = reinterpret_cast(vec_i32, Int32, ACC_NUM_COLS, Float32) + + # # 2. TensorSSA wrap → .to(BFloat16) (real CUDA core CVT) + # regs = TensorSSA(vec_f32, (ACC_NUM_COLS,), Float32) + + # # Debug print: thread 0, first 4 register values + # if tidx == cutlass.Int32(0): + # cute.printf("[T2R] tid=0, regs[0..3] = %f, %f, %f, %f", + # regs[0], regs[1], regs[2], regs[3]) + + # R2G via store_256b (4 × 256-bit stores per thread) + # Layout E (column-major warp order): + # warp0->(M0,N0), warp1->(M1,N0), warp2->(M0,N1), warp3->(M1,N1) + # in each warp, each thread process one row, T0->[0, 0:31], T1->[1, 0:31], ..., T31->[31, 0:31] + lane_idx = tidx % 32 + row = (warp_idx % 2) * M // 2 + lane_idx # M0 or M1 + col_base = (warp_idx // 2) * ACC_NUM_COLS # N0 or N1 + # 32 regs = 4 chunks of 8 × 32-bit each (256 bits) + base_addr = (C_out.iterator + row * N + col_base).toint() + for chunk in cutlass.range_constexpr(ACC_NUM_COLS // 8): + store_256b(base_addr + chunk * 32, subvec(vec_i32, chunk * 8, 8)) + + sync_threads() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr_f32, NUM_COLS) + + @cute.jit + def _launch(self, A: cute.Tensor, B: cute.Tensor, C: cute.Tensor, stream): + self.kernel(A, B, C).launch(grid=(1, 1, 1), block=(128, 1, 1), stream=stream) + + def run(self, A_cpu, B_cpu): + M, N = self.M, self.N + A_gpu = A_cpu.cuda().to(torch.bfloat16).contiguous() + B_gpu = B_cpu.cuda().to(torch.bfloat16).contiguous() + C_gpu = torch.zeros(M, N, dtype=torch.float32, device="cuda") + stream = cutlass_torch.default_stream() + self._launch(from_dlpack(A_gpu), from_dlpack(B_gpu), from_dlpack(C_gpu), stream) + torch.cuda.synchronize() + return C_gpu.cpu() + + +# ===================================================================== +# Test 4: tcgen05mma_ws_ts_f16 (weight-stationary, TMEM A, SMEM B, f16) +# ===================================================================== + + +# TODO: support with first S2R then R2T with tcgen05.st +class _WsTsF16Kernel: + """Two-step test (same strategy as _WsTsTf32Kernel but with kind::f16): + - Region 0 (tmem_a_region): populate A into TMEM via S2T copy + - Region 1 (tmem_c_region): result of WS TS MMA (tmem_a × B → C) + """ + def __init__(self, M: int, N: int, K: int): + self.M = M + self.N = N + self.K = K + if N == 64: + self.idesc = IDESC_F16_M64_N64 + elif N == 128: + self.idesc = IDESC_F16_M64_N128 + else: + raise ValueError(f"Unsupported N={N} for F16 IDESC (expected 64 or 128)") + self.wg_sync_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=128, + ) + + @cute.kernel + def kernel(self, A_in: cute.Tensor, B_in: cute.Tensor, C_out: cute.Tensor): + M, N, K = self.M, self.N, self.K + idesc = self.idesc + ACC_NUM_COLS = N // 2 + # Layout E (M=64 + .ws mode) for A operand in TMEM: + # - 4 warps × 16 active lanes (lane<16) cover M=64 rows + # warp w → rows [w*16 .. w*16+15] + # - Each active lane stores its full K-row as bf16 pairs in 32-bit cells + # - Per K-atom (16 BF16): 8 TMEM cols (16 active lanes × 8 dwords = 128 BF16 = 64 rows × 2 BF16-pairs) + # wait: 16 lanes × 8 cols × 2 bf16/col = 256 BF16 per warp per K-atom; 4 warps × 256 = 1024 BF16 = 64×16 ✓ + # - Total cols for A^T[64, K]: K/2 cols + # ref: chunk_kda_bwd_sm100_tmem.cuh::pack_AT_to_tmem_bf16, FlashMLA SM100_UTCCP_128dp256bit + OPA_NUM_COLS = K // 2 # 8 cols per K-atom of 16 BF16 + TMEM_OPA = 0 + TMEM_ACC = TMEM_OPA + OPA_NUM_COLS + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = tidx % 32 + + smem = utils.SmemAllocator() + tmem_hold_ptr = smem.allocate(Int32) + mbar_ptr = smem.allocate(Int64, byte_alignment=8) + + # Create MMA SMEM Layouts + # NOTE: we use non-ws mode TiledMMA for creating smem layout in a easy way, + # because smem layouts of ws mode and non-ws mode are the same + mma_tiler = (M, N, K) + non_ws_tiled_mma = sm100_utils.make_trivial_tiled_mma( + BFloat16, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + Float32, + tcgen05.CtaGroup.ONE, + (M, N), + ) + a_smem_layout = sm100_utils.make_smem_layout_a(non_ws_tiled_mma, mma_tiler, BFloat16, 1) + b_smem_layout = sm100_utils.make_smem_layout_b(non_ws_tiled_mma, mma_tiler, BFloat16, 1) + bufferA = smem.allocate_tensor( + element_type=BFloat16, + layout=a_smem_layout.outer, + byte_alignment=128, + swizzle=a_smem_layout.inner, + ) + + bufferB = smem.allocate_tensor( + element_type=BFloat16, + layout=b_smem_layout.outer, + byte_alignment=128, + swizzle=b_smem_layout.inner, + ) + + bufA_s0 = bufferA[(None, None, None, 0)] + bufB_s0 = bufferB[(None, None, None, 0)] + + if tidx == cutlass.Int32(0): + mbarrier_init(mbar_ptr, 1) + mbarrier_init_fence() + + # Load A (row-major input → K-major swizzled SMEM) + gA_flat = cute.make_tensor(A_in.iterator, cute.make_layout(M * K)) + gB_flat = cute.make_tensor(B_in.iterator, cute.make_layout(K * N)) + + for step in cutlass.range(M * K // 128, unroll_full=False): + smem_idx = tidx + step * 128 + m = smem_idx % M + k = smem_idx // M + bufA_s0[smem_idx] = gA_flat[m * K + k] + for step in cutlass.range(K * N // 128, unroll_full=False): + idx = tidx + step * 128 + bufB_s0[idx] = gB_flat[idx] + sync_threads() + + # --- TMEM allocation --- + alloc_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=128) + tmem = utils.TmemAllocator( + tmem_hold_ptr, + barrier_for_retrieve=alloc_bar, + allocator_warp_id=0, + ) + tmem.allocate(512) + tmem.wait_for_alloc() + tmem_ptr_f32 = tmem.retrieve_ptr(Float32) + + rA = cute.make_rmem_tensor((self.K, ), BFloat16) + # Layout E A-operand pack: warp w lane l (l<16) → A^T row = w*16 + l, full K-row. + # Inactive lanes (l>=16) store zeros (warp-collective tcgen05.st requires all 32 lanes). + row_active = warp_idx * 16 + lane_idx # only valid when lane_idx < 16 + is_active = lane_idx < cutlass.Int32(16) + for j in cutlass.range_constexpr(self.K): + if is_active: + rA[j] = bufA_s0[row_active * self.K + j] + else: + rA[j] = BFloat16(0.0) + + rA_val = rA.load() + rA_i32_val = reinterpret_cast(rA_val, BFloat16, self.K, Int32) + # Store K/2 dwords per active lane → K/2 TMEM cols per warp subpart + tcgen05_st_32x32b(self.K // 2, TMEM_OPA, rA_i32_val) + cute.arch.fence_view_async_tmem_store() + + self.wg_sync_barrier.arrive_and_wait() + + # Build SMEM descriptors (rank-2 vec_mode layout required) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(bufB_s0.iterator, bufB_s0.layout, "mn")) + desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) + tmem_a_base = TMEM_OPA + + # Per-K-atom byte offsets are derived from the (unswizzled) outer layout + # so we transparently handle every K size: + # K∈{16,32,64} → A k_iter is single-mode, uniform stride + # K≥128 → A k_iter is hierarchical e.g. (4,K/64):(16,4096) + # B is always uniform stride=1024 elem + # Coord ((0,0), 0, ks, 0) into outer layout gives the linear elem offset + # of the ks-th MMA-K atom; * sizeof(elem) → byte offset to add to desc. + ELEM_BYTES_F16 = BFloat16.width // 8 + b_outer = b_smem_layout.outer + + # Issue WS SS MMA (scale_out=0 → D = A*B, not accumulate) + if warp_idx == cutlass.Int32(0): + with elect_one(): + for ks in cutlass.range_constexpr(K // 16): + scale = 0 if ks == 0 else 1 + b_off = cute.crd2idx(((0, 0), 0, ks, 0), b_outer) * ELEM_BYTES_F16 + desc_b = desc_b_base + b_off + tmem_a = tmem_a_base + ks * 8 # 8 TMEM cols per K-atom of 16 BF16 + tcgen05mma_ws_ts_f16(tmem_a, desc_b, TMEM_ACC, idesc, scale) + tcgen05.commit(mbar_ptr, cta_group=tcgen05.CtaGroup.ONE) + mbarrier_wait(mbar_ptr, 0) + sync_threads() + + # T2R + # Layout E (M=64, ws mode): 128 lanes, 32 columns + # ref: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e + # .32x32b.x32 loads all 32 columns → 32 FP32 regs per thread + # Layout: warp0->(M0,N0), warp1->(M0,N1), warp2->(M1,N0), warp3->(M1,N1) + # for 64x64 Acc, each warp process 32x32, with 128 lanes in TMEM all used + + vec_i32 = tcgen05_ld_32x32b(ACC_NUM_COLS, TMEM_ACC) + cute.arch.fence_view_async_tmem_load() + + # =======DEBUG======== + # # 1. reinterpret_cast to f32 (zero-cost bitcast) + # vec_f32 = reinterpret_cast(vec_i32, Int32, ACC_NUM_COLS, Float32) + + # # 2. TensorSSA wrap → .to(BFloat16) (real CUDA core CVT) + # regs = TensorSSA(vec_f32, (ACC_NUM_COLS,), Float32) + + # # Debug print: thread 0, first 4 register values + # if tidx == cutlass.Int32(0): + # cute.printf("[T2R] tid=0, regs[0..3] = %f, %f, %f, %f", + # regs[0], regs[1], regs[2], regs[3]) + + # R2G via store_256b (4 × 256-bit stores per thread) + # Layout E C-out (column-major warp order): + # warp0->(M0,N0), warp1->(M1,N0), warp2->(M0,N1), warp3->(M1,N1) + col_base = (warp_idx // 2) * ACC_NUM_COLS # N0 or N1 + out_row = (warp_idx % 2) * (M // 2) + lane_idx # M0 or M1, full 32 lanes used for C + # 32 regs = 4 chunks of 8 × 32-bit each (256 bits) + base_addr = (C_out.iterator + out_row * N + col_base).toint() + for chunk in cutlass.range_constexpr(ACC_NUM_COLS // 8): + store_256b(base_addr + chunk * 32, subvec(vec_i32, chunk * 8, 8)) + + sync_threads() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr_f32, 512) + + @cute.jit + def _launch(self, A: cute.Tensor, B: cute.Tensor, C: cute.Tensor, stream): + self.kernel(A, B, C).launch(grid=(1, 1, 1), block=(128, 1, 1), stream=stream) + + def run(self, A_cpu, B_cpu): + M, N = self.M, self.N + A_gpu = A_cpu.cuda().to(torch.bfloat16).contiguous() + B_gpu = B_cpu.cuda().to(torch.bfloat16).contiguous() + C_gpu = torch.zeros(M, N, dtype=torch.float32, device="cuda") + stream = cutlass_torch.default_stream() + self._launch(from_dlpack(A_gpu), from_dlpack(B_gpu), from_dlpack(C_gpu), stream) + torch.cuda.synchronize() + return C_gpu.cpu() + + +# ===================================================================== +# Test 5: tcgen05mma_ws_ss_tf32 with explicit collector_b_buffer/collector_op +# ===================================================================== + + +class _WsSsTf32CollectorKernel: + """Same as _WsSsTf32Kernel but passes collector_b_buffer=B0, collector_op=DISCARD.""" + + @cute.kernel + def kernel(self, A_in: cute.Tensor, B_in: cute.Tensor, C_out: cute.Tensor): + M, N, K = M_DIM, N_DIM, 8 # default K with 8 + ACC_NUM_COLS = N // 2 + NUM_COLS = ACC_NUM_COLS + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + smem = utils.SmemAllocator() + tmem_hold_ptr = smem.allocate(Int32) + mbar_ptr = smem.allocate(Int64, byte_alignment=8) + + non_ws_tiled_mma = sm100_utils.make_trivial_tiled_mma( + TFloat32, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + Float32, + tcgen05.CtaGroup.ONE, + (M, N), + ) + mma_tiler = (M, N, K) + a_smem_layout = sm100_utils.make_smem_layout_a(non_ws_tiled_mma, mma_tiler, TFloat32, 1) + b_smem_layout = sm100_utils.make_smem_layout_b(non_ws_tiled_mma, mma_tiler, TFloat32, 1) + bufferA = smem.allocate_tensor( + element_type=TFloat32, layout=a_smem_layout.outer, byte_alignment=128, swizzle=a_smem_layout.inner + ) + bufferB = smem.allocate_tensor( + element_type=TFloat32, layout=b_smem_layout.outer, byte_alignment=128, swizzle=b_smem_layout.inner + ) + bufA_s0 = bufferA[(None, None, None, 0)] + bufB_s0 = bufferB[(None, None, None, 0)] + + if tidx == cutlass.Int32(0): + mbarrier_init(mbar_ptr, 1) + mbarrier_init_fence() + + gA_flat = cute.make_tensor(A_in.iterator, cute.make_layout(M * K)) + gB_flat = cute.make_tensor(B_in.iterator, cute.make_layout(K * N)) + for step in cutlass.range(M * K // 128, unroll_full=False): + smem_idx = tidx + step * 128 + m = smem_idx % M + k = smem_idx // M + bufA_s0[smem_idx] = gA_flat[m * K + k] + for step in cutlass.range(K * N // 128, unroll_full=False): + idx = tidx + step * 128 + bufB_s0[idx] = gB_flat[idx] + sync_threads() + + alloc_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=128) + tmem = utils.TmemAllocator(tmem_hold_ptr, barrier_for_retrieve=alloc_bar, allocator_warp_id=0) + tmem.allocate(NUM_COLS) + tmem.wait_for_alloc() + tmem_ptr_f32 = tmem.retrieve_ptr(Float32) + tmem_col_buf = cute.make_tensor(tmem_hold_ptr, cute.make_layout(1)) + tmem_col = tmem_col_buf[0] + + desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(bufA_s0.iterator, bufA_s0.layout, "k")) + desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(bufB_s0.iterator, bufB_s0.layout, "mn")) + desc_a = Tcgen05SmemDescriptor(desc_a_i64) + desc_b = Tcgen05SmemDescriptor(desc_b_i64) + + if warp_idx == cutlass.Int32(0): + with elect_one(): + tcgen05mma_ws_ss_tf32( + desc_a, + desc_b, + tmem_col, + IDESC_TF32_M64_N64, + 0, + collector_b_buffer=CollectorBBuffer.B0, + collector_op=CollectorOp.DISCARD, + ) + tcgen05.commit(mbar_ptr, cta_group=tcgen05.CtaGroup.ONE) + mbarrier_wait(mbar_ptr, 0) + sync_threads() + + vec_i32 = tcgen05_ld_32x32b(NUM_COLS, tmem_col) + cute.arch.fence_view_async_tmem_load() + lane_idx = tidx % 32 + row = (warp_idx % 2) * M // 2 + lane_idx + col_base = (warp_idx // 2) * ACC_NUM_COLS + base_addr = (C_out.iterator + row * N + col_base).toint() + for chunk in cutlass.range_constexpr(ACC_NUM_COLS // 8): + store_256b(base_addr + chunk * 32, subvec(vec_i32, chunk * 8, 8)) + + sync_threads() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr_f32, NUM_COLS) + + @cute.jit + def _launch(self, A: cute.Tensor, B: cute.Tensor, C: cute.Tensor, stream): + self.kernel(A, B, C).launch(grid=(1, 1, 1), block=(128, 1, 1), stream=stream) + + def run(self, A_cpu, B_cpu): + A_gpu = A_cpu.contiguous().float().cuda() + B_gpu = B_cpu.contiguous().float().cuda() + C_gpu = torch.zeros(M_DIM, N_DIM, dtype=torch.float32, device="cuda") + stream = cutlass_torch.default_stream() + self._launch(from_dlpack(A_gpu), from_dlpack(B_gpu), from_dlpack(C_gpu), stream) + torch.cuda.synchronize() + return C_gpu.cpu() + + +# ===================================================================== +# Test functions +# ===================================================================== + + +def test_ws_ss_tf32(): + print("\n=== Test 1: tcgen05mma_ws_ss_tf32 (weight-stationary, SMEM A × SMEM B, tf32) ===") + torch.manual_seed(42) + A = torch.randn(M_DIM, K_DIM_TF32) + B = torch.randn(K_DIM_TF32, N_DIM) + ref = torch.mm(A, B) + got = _WsSsTf32Kernel().run(A, B) + err = (got - ref).abs() + rel = err.max().item() / (ref.abs().max().item() + 1e-8) + max_idx = err.argmax().item() + mi, mj = max_idx // N_DIM, max_idx % N_DIM + print(f" got[0,:4]={got[0, :4].tolist()}") + print(f" ref[0,:4]={ref[0, :4].tolist()}") + print(f" max_rel_err={rel:.4f} at ({mi},{mj}): got={got[mi, mj]:.6f} ref={ref[mi, mj]:.6f}") + assert rel < 0.02, f"FAIL: rel={rel:.4f}" + print(" PASSED") + + +def test_ws_ss_f16(): + print("\n=== Test 3: tcgen05mma_ws_ss_f16 (weight-stationary, SMEM A × SMEM B, f16) ===") + torch.manual_seed(42) + for N in [64, 128]: + for K in [64, 128]: + print(f" --- N={N}, K={K} ---") + A = torch.randn(M_DIM, K) + B = torch.randn(K, N) + ref = torch.mm(A, B) + got = _WsSsF16Kernel(M_DIM, N, K).run(A, B) + err = (got - ref).abs() + rel = err.max().item() / (ref.abs().max().item() + 1e-8) + max_idx = err.argmax().item() + mi, mj = max_idx // N, max_idx % N + print(f" got[0,:4]={got[0, :4].tolist()}") + print(f" ref[0,:4]={ref[0, :4].tolist()}") + print(f" max_rel_err={rel:.4f} at ({mi},{mj}): got={got[mi, mj]:.6f} ref={ref[mi, mj]:.6f}") + assert rel < 0.02, f"FAIL N={N}, K={K}: rel={rel:.4f}" + print(f" PASSED (N={N}, K={K})") + +def test_ws_ts_f16(): + print("\n=== Test 3: tcgen05mma_ws_ts_f16 (weight-stationary, TMEM A × SMEM B, f16) ===") + torch.manual_seed(42) + for N in [64, 128]: + for K in [64, 128]: + print(f" --- N={N}, K={K} ---") + A = torch.randn(M_DIM, K) + B = torch.randn(K, N) + ref = torch.mm(A, B) + got = _WsTsF16Kernel(M_DIM, N, K).run(A, B) + err = (got - ref).abs() + rel = err.max().item() / (ref.abs().max().item() + 1e-8) + max_idx = err.argmax().item() + mi, mj = max_idx // N, max_idx % N + print(f" got[0,:4]={got[0, :4].tolist()}") + print(f" ref[0,:4]={ref[0, :4].tolist()}") + print(f" max_rel_err={rel:.4f} at ({mi},{mj}): got={got[mi, mj]:.6f} ref={ref[mi, mj]:.6f}") + assert rel < 0.02, f"FAIL N={N}, K={K}: rel={rel:.4f}" + print(f" PASSED (N={N}, K={K})") + + +def test_ws_ss_tf32_collector(): + """Explicit collector_b_buffer=B0, collector_op=DISCARD should match default.""" + print("\n=== Test 5: tcgen05mma_ws_ss_tf32 + collector (B0::DISCARD) ===") + torch.manual_seed(42) + A = torch.randn(M_DIM, K_DIM_TF32) + B = torch.randn(K_DIM_TF32, N_DIM) + ref = torch.mm(A, B) + got = _WsSsTf32CollectorKernel().run(A, B) + err = (got - ref).abs() + rel = err.max().item() / (ref.abs().max().item() + 1e-8) + max_idx = err.argmax().item() + mi, mj = max_idx // N_DIM, max_idx % N_DIM + print(f" got[0,:4]={got[0, :4].tolist()}") + print(f" ref[0,:4]={ref[0, :4].tolist()}") + print(f" max_rel_err={rel:.4f} at ({mi},{mj}): got={got[mi, mj]:.6f} ref={ref[mi, mj]:.6f}") + assert rel < 0.02, f"FAIL: rel={rel:.4f}" + print(" PASSED") + + +if __name__ == "__main__": + test_ws_ss_tf32() + test_ws_ss_tf32_collector() + test_ws_ss_f16() + # test_ws_ts_f16() + print("\n=== All tests passed! ===") From 8fc4a4f4bfe1984f24e27220ae69b1a7c64d11e1 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Tue, 5 May 2026 22:12:26 +0800 Subject: [PATCH 03/26] change kdk compute order, better perf --- cula/ops/chunk_wy_dqkg_sm100.py | 42 ++++++++++++++++----------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index 98775a6..f69f037 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -2062,23 +2062,6 @@ def kernel( rKdk = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) rKdk.store(rK_fp32.load() * rDk.load()) - # dgk += sum(kdk, axis=0) - # write kdk to G SMEM then do BT-dim reduce - for i in cutlass.range_constexpr(self.BK // 4 // 4): - col_base = bk_col_base + i * 4 - chunk_kdk = cute.local_tile(rKdk, (4,), (i,)) - smem_store_f32x4_sw128(sG_raw_ptr, row, col_base, chunk_kdk) - self.cuda_wg_sync_barrier.arrive_and_wait() - - if wg_idx == 0: - sum = Float32(0.0) - for r in cutlass.range(self.BT, unroll_full=True): - sum += sG_raw[(r, local_tidx, 0)] - sDgk[(local_tidx, )] += sum - - pipeline_load_g.consumer_release(load_g_consumer_state) - load_g_consumer_state.advance() - # gb = gk_exp * beta[:, None] rGb = cute.make_rmem_tensor((bk_num_cols_per_wg, ), Float32) rGb.store(rG_exp_val * beta_val) @@ -2102,6 +2085,23 @@ def kernel( chunk_dk = subvec(dk_i32_vec, s * 8, 8) store_256b(dk_base_addr + s * 32, chunk_dk) + # dgk += sum(kdk, axis=0) + # write kdk to G SMEM then do BT-dim reduce + for i in cutlass.range_constexpr(self.BK // 4 // 4): + col_base = bk_col_base + i * 4 + chunk_kdk = cute.local_tile(rKdk, (4,), (i,)) + smem_store_f32x4_sw128(sG_raw_ptr, row, col_base, chunk_kdk) + self.cuda_wg_sync_barrier.arrive_and_wait() + + if wg_idx == 0: + sum = Float32(0.0) + for r in cutlass.range(self.BT, unroll_full=True): + sum += sG_raw[(r, local_tidx, 0)] + sDgk[(local_tidx, )] += sum + + pipeline_load_g.consumer_release(load_g_consumer_state) + load_g_consumer_state.advance() + # dg1 = kg * dkgb * beta[:, None], can reuse kg RMEM rDg = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) rDg.store(rKG_val * dkgb_f32_val * beta_val) @@ -2204,9 +2204,6 @@ def kernel( tcgen05_fence_before() cute.arch.fence_view_async_tmem_load() - pipeline_mma_dA2.consumer_release(mma_dA2_consumer_state) - mma_dA2_consumer_state.advance() - pipeline_prologue_dA3.producer_acquire(prologue_dA3_producer_state) # write dA2 to smem notify dA2 = A @ dA2 dA2_f32 = reinterpret_cast(dA2_i32, Int32, bt_num_cols_per_wg, Float32) @@ -2232,6 +2229,9 @@ def kernel( tcgen05_fence_before() cute.arch.fence_view_async_tmem_load() + # release mma dA2 after dA3 is finished, protect DA2 TMEM + pipeline_mma_dA2.consumer_release(mma_dA2_consumer_state) + mma_dA2_consumer_state.advance() pipeline_mma_dA3.consumer_release(mma_dA3_consumer_state) mma_dA3_consumer_state.advance() # NOTE: release smem Q because we reuse to store bf16 dA @@ -2620,8 +2620,6 @@ def kernel( desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) mma_ws_ss_m64n128_k_k_call(vloop_opA_smem, desc_a_base, vloop_opB_smem, desc_b_base, TMEM_DQ_ACC_OFF, self.BV, is_accum) - # TODO: should we add tcgen05.commit and mbar.wait to ensure current dq MMA has been finished? - pipeline_load_do.consumer_release(load_do_consumer_state) load_do_consumer_state.advance() From aaebc987baf22acbe12b3d55df5fb68ebfc8bf5e Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Tue, 5 May 2026 22:36:10 +0800 Subject: [PATCH 04/26] change dgk compute order, 1% perf --- cula/ops/chunk_wy_dqkg_sm100.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index f69f037..bb6c37a 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -2053,11 +2053,6 @@ def kernel( else: rDk.fill(Float32(0.0)) - # dgk *= exp2(gn) - self.cuda_wg_sync_barrier.arrive_and_wait() - if wg_idx == 0: - sDgk[(local_tidx, )] *= cute.exp2(sGn[(local_tidx,)], fastmath=self.use_fast_math) - # kdk = k * dk rKdk = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) rKdk.store(rK_fp32.load() * rDk.load()) @@ -2093,6 +2088,11 @@ def kernel( smem_store_f32x4_sw128(sG_raw_ptr, row, col_base, chunk_kdk) self.cuda_wg_sync_barrier.arrive_and_wait() + # dgk *= exp2(gn) + if wg_idx == 0: + sDgk[(local_tidx, )] *= cute.exp2(sGn[(local_tidx,)], fastmath=self.use_fast_math) + + self.cuda_wg_sync_barrier.arrive_and_wait() if wg_idx == 0: sum = Float32(0.0) for r in cutlass.range(self.BT, unroll_full=True): From e19bccf250cd8a2a8fb26fa4dd62dbdc26dbf571 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Wed, 6 May 2026 15:08:50 +0800 Subject: [PATCH 05/26] increase A stage to 2, 12% latency reduction --- cula/ops/chunk_wy_dqkg_sm100.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index bb6c37a..e6d449e 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -463,7 +463,7 @@ def __init__( # V-loop TMA: 2-stage double buffer self.vloop_stage = 2 self.kloop_stage = 1 - self.a_stage = 1 # TODO: increase to 2 + self.a_stage = 2 self.mma_stage = 1 # ── MMA tiler shapes ── @@ -820,7 +820,7 @@ def __call__( dA2post_tiled_mma, self.dApost_tiler, self.io_dtype, - self.a_stage, + self.mma_stage, ) # opB: A K-major [BT,BT] @@ -836,7 +836,7 @@ def __call__( dA3post_tiled_mma, self.dApost_tiler, self.io_dtype, - self.a_stage, + self.mma_stage, ) # --- Epilogue (non-MMA) layouts --- @@ -1436,6 +1436,8 @@ def kernel( sH_ptr_base = storage.buf_h.data_ptr().toint() sDh_ptr_base = storage.buf_dh.data_ptr().toint() vloop_opB_bytes_per_stage = cute.size_in_bytes(self.io_dtype, vloop_opB_smem_no_stage) + sA_ptr_base = storage.buf_A.data_ptr().toint() + A_bytes_per_stage = cute.size_in_bytes(self.io_dtype, A_mn_opA_smem_no_stage) # NOTE: make_umma_smem_desc requires the iterator to carry the swizzle # (and ≥16B alignment). When constructing a tensor over a ComposedLayout @@ -2312,7 +2314,7 @@ def kernel( seq_len = cu_seqlens[(batch_idx + 1,)] - tok_offset sub_seq_len = min(self.BT, seq_len - tile_idx * self.BT) - # Load A, TODO: double-buffer? + # Load A tma_A_v = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_A) tAsA, tAgA = self._tma_partition_A( tma_atom_A, @@ -2577,6 +2579,11 @@ def kernel( zeros8.fill(BFloat16(0.0)) pipeline_load_A.consumer_wait(load_A_consumer_state) + sA_raw_ptr = cute.make_ptr( + self.io_dtype, + sA_ptr_base + a_stage_idx * A_bytes_per_stage, + cute.AddressSpace.smem, + ) if sub_seq_len < self.BT: for i in cutlass.range_constexpr(self.BT // 32): row = i * 32 + lane_idx @@ -2584,7 +2591,7 @@ def kernel( for col in cutlass.range_constexpr(self.BT // 8): # A tile is MN_SW128 in shared memory; use raw swizzled # address stores to avoid layout-coordinate ambiguity. - smem_store_bf16x8_sw128(sA_raw, row, col * 8, zeros8) + smem_store_bf16x8_sw128(sA_raw_ptr, row, col * 8, zeros8) # Make generic-proxy SMEM stores visible to UMMA async-proxy readers. cute.arch.fence_proxy("async.shared", space="cta") self.mma_warp_sync_barrier.arrive_and_wait() @@ -2794,7 +2801,7 @@ def kernel( cute.arch.fence_proxy("async.shared", space="cta") sDA_k_cur = sDA_k[(None, None, None, 0)] - sA_k_cur = sA_k[(None, None, None, 0)] + sA_k_cur = sA_k[(None, None, None, a_stage_idx)] desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDA_k_cur.iterator, sDA_k_cur.layout, "k")) desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sA_k_cur.iterator, sA_k_cur.layout, "k")) desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) @@ -2811,7 +2818,7 @@ def kernel( pipeline_prologue_dA3.consumer_wait(prologue_dA3_consumer_state) cute.arch.fence_proxy("async.shared", space="cta") - sA_mn_cur = sA_mn[(None, None, None, 0)] + sA_mn_cur = sA_mn[(None, None, None, a_stage_idx)] sDA_mn_cur = sDA_mn[(None, None, None, 0)] desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sA_mn_cur.iterator, sA_mn_cur.layout, "mn")) desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDA_mn_cur.iterator, sDA_mn_cur.layout, "mn")) From 3100e856928e611ba8d2eb5ed6ca7a57a949abee Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Wed, 6 May 2026 22:26:44 +0800 Subject: [PATCH 06/26] tune wg sync, 2.7% latency reduction --- cula/ops/chunk_wy_dqkg_sm100.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index e6d449e..92c1450 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -184,7 +184,6 @@ def smem_load_bf16x8_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32): cute.autovec_copy(smem_t, rmem_t) return rmem_t -# TODO: is this bug for K_TILE != 128? @cute.jit def smem_store_bf16x8_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32, data: cute.Tensor): """ @@ -1773,6 +1772,12 @@ def kernel( seq_len = cu_seqlens[(batch_idx + 1,)] - tok_offset sub_seq_len = min(self.BT, seq_len - tile_idx * self.BT) + # NOTE: must sync before next wu_iter's `sDgk[local_tidx] = 0` + # init, otherwise WG0 of next iter may overwrite sDgk while + # WG1 of this iter (row == sub_seq_len - 1 lane) is still + # reading sDgk[col] above. This was the source of the + # non-deterministic dg accuracy bug. + self.cuda_wg_sync_barrier.arrive_and_wait() # fill db, dgk to 0 if local_tidx < self.BT: sDb[local_tidx] = Float32(0.0) @@ -1888,7 +1893,6 @@ def kernel( pipeline_load_g.consumer_wait(load_g_consumer_state) # write to gn sGn[local_tidx] = sG_raw[(sub_seq_len - 1, local_tidx, 0)] - self.cuda_wg_sync_barrier.arrive_and_wait() # row-major load, match TMEM layout rG = cute.make_rmem_tensor((self.BK // 4, ), self.g_dtype) @@ -2133,12 +2137,6 @@ def kernel( for i in cutlass.range_constexpr(bk_num_cols_per_wg): col = bk_col_base + i rDg[i] += sDgk[(col,)] - # NOTE: must sync before next wu_iter's `sDgk[local_tidx] = 0` - # init, otherwise WG0 of next iter may overwrite sDgk while - # WG1 of this iter (row == sub_seq_len - 1 lane) is still - # reading sDgk[col] above. This was the source of the - # non-deterministic dg accuracy bug. - self.cuda_wg_sync_barrier.arrive_and_wait() rDg_val = rDg.load() dg_i32_vec = reinterpret_cast( From 4dc840cfcaff671e1983b03bdc4a24e51e2b9aac Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Thu, 7 May 2026 10:37:26 +0800 Subject: [PATCH 07/26] move dg store to aux warp, 2.8% latency reduction --- cula/ops/chunk_wy_dqkg_sm100.py | 104 ++++++++++++++++++++++---------- 1 file changed, 71 insertions(+), 33 deletions(-) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index 92c1450..5c1050b 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -123,9 +123,6 @@ def _exclusive_cumsum(a: list[int]): ELEM_BYTES_BF16 = BFloat16.width // 8 -def make_thread_cooperative_group(size: int): - return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) - # ============================================================ # Helpers: _ir, Float32 conversion # ============================================================ @@ -438,13 +435,13 @@ def __init__( self.BK = 128 # K tiling for V-loop GEMM (single K tile) self.BV = 64 # V tiling for V-loop GEMM (single V tile) - # Warp layout: WG0 (4 warps CudaCore+Store) + WG1 (1 MMA + 1 Load + 2 Aux) + # Warp layout: WG0/WG1 (8 CudaCore warps) + WG2 (MMA/Load/Aux/Store) self.threads_per_warp = 32 self.cuda_warp_ids = (0, 1, 2, 3) # WG0: CudaCore + Store self.cuda2_warp_ids = (4, 5, 6, 7) # WG1: CudaCore + Store self.mma_warp_id = 8 # WG2: MMA dispatch self.load_warp_id = 9 # WG2: TMA Load - self.aux_warp_ids = (10, 11) # WG2: Aux/Load Aux + self.aux_warp_ids = (10, 11) # WG2: Aux/Load/Store Aux self.threads_per_cta = self.threads_per_warp * 12 # 384 threads (3 WGs) self.num_regs_cuda = 208 @@ -495,12 +492,8 @@ def __init__( barrier_id=2, num_threads=self.threads_per_cta, ) - self.mma_warp_sync_barrier = pipeline.NamedBarrier( - barrier_id=3, - num_threads=32, - ) self.cuda_wg_sync_barrier = pipeline.NamedBarrier( - barrier_id=4, + barrier_id=3, num_threads=32 * 8, ) self.buffer_align_bytes = 1024 @@ -1001,6 +994,7 @@ class SharedStorage: bar_prologue_kg: cute.struct.MemRange[Int64, self.kloop_stage * 2] bar_prologue_dA2: cute.struct.MemRange[Int64, self.mma_stage * 2] bar_prologue_dA3: cute.struct.MemRange[Int64, self.mma_stage * 2] + bar_store_dg: cute.struct.MemRange[Int64, self.kloop_stage * 2] # TMEM holding buffer tmem_holding_buf: Int32 # A, stage=1, [BT,BT], 8KB @@ -1296,7 +1290,7 @@ def kernel( barrier_storage=storage.bar_load_g.data_ptr(), num_stages=self.kloop_stage, producer_group=make_thread_cooperative_group(len([self.load_warp_id])), - consumer_group=make_thread_cooperative_group(num_cuda_warps_total), + consumer_group=make_thread_cooperative_group(num_cuda_warps_total + len(self.aux_warp_ids)), tx_count=self.tma_bytes_g, ) pipeline_load_v = pipeline.PipelineTmaAsync.create( @@ -1398,6 +1392,12 @@ def kernel( producer_group=make_thread_cooperative_group(len(self.aux_warp_ids) * 32), consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), ) + pipeline_store_dg = pipeline.PipelineAsync.create( + barrier_storage=storage.bar_store_dg.data_ptr(), + num_stages=self.kloop_stage, + producer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), + consumer_group=make_thread_cooperative_group(len(self.aux_warp_ids) * 32), + ) # ===================== TMEM allocation ===================== tmem_alloc_bar = pipeline.NamedBarrier(barrier_id=1, num_threads=self.threads_per_cta) @@ -1740,6 +1740,9 @@ def kernel( prologue_dA3_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.mma_stage ) + store_dg_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kloop_stage + ) wg_idx = tidx // 128 local_tidx = tidx % 128 @@ -2105,9 +2108,6 @@ def kernel( sum += sG_raw[(r, local_tidx, 0)] sDgk[(local_tidx, )] += sum - pipeline_load_g.consumer_release(load_g_consumer_state) - load_g_consumer_state.advance() - # dg1 = kg * dkgb * beta[:, None], can reuse kg RMEM rDg = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) rDg.store(rKG_val * dkgb_f32_val * beta_val) @@ -2138,21 +2138,22 @@ def kernel( col = bk_col_base + i rDg[i] += sDgk[(col,)] - rDg_val = rDg.load() - dg_i32_vec = reinterpret_cast( - rDg_val, Float32, bk_num_cols_per_wg, Int32 - ) - # FIXME: after dg store, severe register spill with worse perf! - dg_base_addr = ( - dg_gmem.iterator - + (tok_offset + tile_idx * self.BT + row) * H * K - + head_idx * K - + bk_col_base - ).toint() + # Stage dg to SMEM first. A dedicated store warp later does + # SMEM -> RMEM -> GMEM with store_256b, keeping GMEM store + # address/vector live ranges out of the high-register CC path. + pipeline_store_dg.producer_acquire(store_dg_producer_state) if row < sub_seq_len: - for s in cutlass.range_constexpr(num_stores_f32): - chunk_dg = subvec(dg_i32_vec, s * 8, 8) - store_256b(dg_base_addr + s * 32, chunk_dg) + for i in cutlass.range_constexpr(bk_num_cols_per_wg // 4): + col_base = bk_col_base + i * 4 + chunk_dg = cute.local_tile(rDg, (4,), (i,)) + smem_store_f32x4_sw128(sG_raw_ptr, row, col_base, chunk_dg) + + cute.arch.fence_acq_rel_cta() + pipeline_store_dg.producer_commit(store_dg_producer_state) + store_dg_producer_state.advance() + + pipeline_load_g.consumer_release(load_g_consumer_state) + load_g_consumer_state.advance() pipeline_mma_dA.consumer_wait(mma_dA_consumer_state) tcgen05_fence_after() @@ -2592,7 +2593,6 @@ def kernel( smem_store_bf16x8_sw128(sA_raw_ptr, row, col * 8, zeros8) # Make generic-proxy SMEM stores visible to UMMA async-proxy readers. cute.arch.fence_proxy("async.shared", space="cta") - self.mma_warp_sync_barrier.arrive_and_wait() for v_iter in cutlass.range(self.num_v_tiles): is_accum = False if v_iter == 0 else True @@ -2611,7 +2611,6 @@ def kernel( # dv tile uses the same Swizzle<3,4,3> physical mapping. smem_store_bf16x8_sw128(sDo_raw_ptr, row, col * 8, zeros8) cute.arch.fence_proxy("async.shared", space="cta") - self.mma_warp_sync_barrier.arrive_and_wait() if v_iter == 0: pipeline_mma_dq.producer_acquire(mma_dq_producer_state) @@ -2646,7 +2645,6 @@ def kernel( # dv tile uses the same Swizzle<3,4,3> physical mapping. smem_store_bf16x8_sw128(sDv_raw, row, col * 8, zeros8) cute.arch.fence_proxy("async.shared", space="cta") - self.mma_warp_sync_barrier.arrive_and_wait() # if lane_idx == 0: # cute.printf("V_iter", v_iter) @@ -2687,7 +2685,7 @@ def kernel( # dv tile uses the same Swizzle<3,4,3> physical mapping. smem_store_bf16x8_sw128(sV_raw, row, col * 8, zeros8) cute.arch.fence_proxy("async.shared", space="cta") - self.mma_warp_sync_barrier.arrive_and_wait() + if v_iter == 0: pipeline_mma_dA.producer_acquire(mma_dA_producer_state) @@ -2727,7 +2725,6 @@ def kernel( # dv tile uses the same Swizzle<3,4,3> physical mapping. smem_store_bf16x8_sw128(sDvnew_raw_ptr, row, col * 8, zeros8) cute.arch.fence_proxy("async.shared", space="cta") - self.mma_warp_sync_barrier.arrive_and_wait() pipeline_load_dh.consumer_wait(load_dh_consumer_state) if v_iter == 0: @@ -2842,6 +2839,12 @@ def kernel( load_beta_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, 1 ) + load_g_store_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kloop_stage + ) + store_dg_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kloop_stage + ) for wu_iter in cutlass.range(0, num_iters, unroll=0): work_idx = block_idx_x + wu_iter * grid_dim_x @@ -2865,6 +2868,41 @@ def kernel( pipeline_load_beta.producer_commit(load_beta_producer_state) load_beta_producer_state.advance() + pipeline_load_g.consumer_wait(load_g_store_consumer_state) + pipeline_store_dg.consumer_wait(store_dg_consumer_state) + + store_lane_row = tidx >> Int32(4) # 0..3 + store_col_base = (tidx & Int32(15)) * Int32(8) # 0,8,...,120 + for row_quad in cutlass.range_constexpr(self.BT // 4): + store_row = row_quad * 4 + store_lane_row + if store_row < sub_seq_len: + vals0 = smem_load_f32x4_sw128(sG_raw_ptr, store_row, store_col_base) + vals1 = smem_load_f32x4_sw128(sG_raw_ptr, store_row, store_col_base + Int32(4)) + dg_store_rmem = cute.make_rmem_tensor((8,), Float32) + dg_store_rmem[0] = vals0[0] + dg_store_rmem[1] = vals0[1] + dg_store_rmem[2] = vals0[2] + dg_store_rmem[3] = vals0[3] + dg_store_rmem[4] = vals1[0] + dg_store_rmem[5] = vals1[1] + dg_store_rmem[6] = vals1[2] + dg_store_rmem[7] = vals1[3] + dg_store_i32_vec = reinterpret_cast( + dg_store_rmem.load(), Float32, 8, Int32 + ) + dg_base_addr = ( + dg_gmem.iterator + + (tok_offset + tile_idx * self.BT + store_row) * H * K + + head_idx * K + + store_col_base + ).toint() + store_256b(dg_base_addr, dg_store_i32_vec) + + pipeline_store_dg.consumer_release(store_dg_consumer_state) + store_dg_consumer_state.advance() + pipeline_load_g.consumer_release(load_g_store_consumer_state) + load_g_store_consumer_state.advance() + # ===================== TMEM cleanup ===================== tmem.relinquish_alloc_permit() self.tmem_dealloc_sync_barrier.arrive_and_wait() From 0c0fa34ac283f01c44f80af5afa670b3e6758739 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Thu, 7 May 2026 11:06:07 +0800 Subject: [PATCH 08/26] add tma store for non-tail chunk --- cula/ops/chunk_wy_dqkg_sm100.py | 93 +++++++++++++++++++++++---------- 1 file changed, 66 insertions(+), 27 deletions(-) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index 5c1050b..316eed4 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -742,6 +742,7 @@ def __call__( # ===================== SMEM layouts ===================== tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group) + tma_store_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() # SS opA layout: do/vnew/dv [BT,BV]=[64,64] K-major vloop_opA_smem = sm100_utils.make_smem_layout_a( @@ -853,6 +854,13 @@ def __call__( 1, ) + dg_epi_smem_layout = sm100_utils.make_smem_layout_epi( + self.g_dtype, + utils.LayoutEnum.ROW_MAJOR, + (self.BT, self.BK), + self.kloop_stage, + ) + # ===================== Cluster layout ===================== cluster_layout = cute.tiled_divide( cute.make_layout(self.cluster_shape_mnk), @@ -953,6 +961,14 @@ def __call__( (self.BT, self.BK), ) + dg_epi_smem_no_stage = cute.select(dg_epi_smem_layout, mode=[0, 1]) + tma_atom_dg, tma_tensor_dg = cpasync.make_tiled_tma_atom( + tma_store_op, + dg, + dg_epi_smem_no_stage, + (self.BT, self.BK), + ) + # ===================== TMA byte counts ===================== self.tma_bytes_A = cute.size_in_bytes(self.io_dtype, A_mn_opA_smem_no_stage) self.tma_bytes_dv = cute.size_in_bytes(self.io_dtype, vloop_opA_smem_no_stage) @@ -1107,6 +1123,8 @@ class SharedStorage: tma_tensor_vnew, tma_atom_q, tma_tensor_q, + tma_atom_dg, + tma_tensor_dg, # SMEM layouts vloop_opA_smem, vloop_opB_smem, @@ -1167,6 +1185,8 @@ def kernel( tma_tensor_vnew: cute.Tensor, tma_atom_q: cute.CopyAtom, tma_tensor_q: cute.Tensor, + tma_atom_dg: cute.CopyAtom, + tma_tensor_dg: cute.Tensor, # SMEM layouts vloop_opA_smem: cute.ComposedLayout, vloop_opB_smem: cute.ComposedLayout, @@ -2148,7 +2168,7 @@ def kernel( chunk_dg = cute.local_tile(rDg, (4,), (i,)) smem_store_f32x4_sw128(sG_raw_ptr, row, col_base, chunk_dg) - cute.arch.fence_acq_rel_cta() + cute.arch.fence_proxy("async.shared", space="cta") pipeline_store_dg.producer_commit(store_dg_producer_state) store_dg_producer_state.advance() @@ -2871,32 +2891,51 @@ def kernel( pipeline_load_g.consumer_wait(load_g_store_consumer_state) pipeline_store_dg.consumer_wait(store_dg_consumer_state) - store_lane_row = tidx >> Int32(4) # 0..3 - store_col_base = (tidx & Int32(15)) * Int32(8) # 0,8,...,120 - for row_quad in cutlass.range_constexpr(self.BT // 4): - store_row = row_quad * 4 + store_lane_row - if store_row < sub_seq_len: - vals0 = smem_load_f32x4_sw128(sG_raw_ptr, store_row, store_col_base) - vals1 = smem_load_f32x4_sw128(sG_raw_ptr, store_row, store_col_base + Int32(4)) - dg_store_rmem = cute.make_rmem_tensor((8,), Float32) - dg_store_rmem[0] = vals0[0] - dg_store_rmem[1] = vals0[1] - dg_store_rmem[2] = vals0[2] - dg_store_rmem[3] = vals0[3] - dg_store_rmem[4] = vals1[0] - dg_store_rmem[5] = vals1[1] - dg_store_rmem[6] = vals1[2] - dg_store_rmem[7] = vals1[3] - dg_store_i32_vec = reinterpret_cast( - dg_store_rmem.load(), Float32, 8, Int32 - ) - dg_base_addr = ( - dg_gmem.iterator - + (tok_offset + tile_idx * self.BT + store_row) * H * K - + head_idx * K - + store_col_base - ).toint() - store_256b(dg_base_addr, dg_store_i32_vec) + tma_dg_v = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_dg) + tDGsDG, tDGgDG = self._epilog_partition_varlen( + tma_atom_dg, + tma_dg_v[None, None, (head_idx, Int32(0))], + (self.BT, self.BK), + sG_raw, + ) + if sub_seq_len < self.BT: + # Tail chunk, direct store + store_lane_row = tidx >> Int32(4) # 0..3 + store_col_base = (tidx & Int32(15)) * Int32(8) # 0,8,...,120 + for row_quad in cutlass.range_constexpr(self.BT // 4): + store_row = row_quad * 4 + store_lane_row + if store_row < sub_seq_len: + vals0 = smem_load_f32x4_sw128(sG_raw_ptr, store_row, store_col_base) + vals1 = smem_load_f32x4_sw128(sG_raw_ptr, store_row, store_col_base + Int32(4)) + dg_store_rmem = cute.make_rmem_tensor((8,), Float32) + dg_store_rmem[0] = vals0[0] + dg_store_rmem[1] = vals0[1] + dg_store_rmem[2] = vals0[2] + dg_store_rmem[3] = vals0[3] + dg_store_rmem[4] = vals1[0] + dg_store_rmem[5] = vals1[1] + dg_store_rmem[6] = vals1[2] + dg_store_rmem[7] = vals1[3] + dg_store_i32_vec = reinterpret_cast( + dg_store_rmem.load(), Float32, 8, Int32 + ) + dg_base_addr = ( + dg_gmem.iterator + + (tok_offset + tile_idx * self.BT + store_row) * H * K + + head_idx * K + + store_col_base + ).toint() + store_256b(dg_base_addr, dg_store_i32_vec) + else: + # Non-tail chunk, TMA store + cute.arch.fence_proxy("async.shared", space="cta") + cute.copy( + tma_atom_dg, + tDGsDG[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 + tDGgDG[(None, tile_idx, 0)], + ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) pipeline_store_dg.consumer_release(store_dg_consumer_state) store_dg_consumer_state.advance() From cf698d9dfd3c960c356d9a9b02a99e74eeb9c262 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Thu, 7 May 2026 11:39:13 +0800 Subject: [PATCH 09/26] store dq to tmem to reduce reg spill, 6.8% latency reduction --- cula/ops/chunk_wy_dqkg_sm100.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index 316eed4..1e6ccd3 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -77,6 +77,7 @@ def _exclusive_cumsum(a: list[int]): TMEM_A_BF16_OFF = 256 # [256,272) 16 cols A_bf16 TS opA (persistent) TMEM_DKGB_ACC_OFF = 272 # [272,336) 64 cols, dkgb fp32 acc TMEM_DA2_ACC_OFF = 336 # [336,368) 32 cols dA fp32 acc, used for dA=dA@A and dA=A@dA +TMEM_DQ_SCALED_OFF = 368 # [368,432) 64 cols dq_scaled (stored for dg) TMEM_TOTAL = 512 # Instruction descriptor for M=64, N=64, BF16, dense, TransposeB=1 @@ -1948,11 +1949,13 @@ def kernel( rDq = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) rDq.store(dq_f32_val * rG_exp_val * Float32(self.scale)) - # TODO: store to smem first to reduce register usage dq_f32_val_store = rDq.load() dq_i32_vec = reinterpret_cast( dq_f32_val_store, Float32, bk_num_cols_per_wg, Int32 ) + # store to TMEM first to reduce register usage + tcgen05_st_32x32b(bk_num_cols_per_wg, TMEM_DQ_SCALED_OFF + wg_idx * bk_num_cols_per_wg, dq_i32_vec) + cute.arch.fence_view_async_tmem_store() dq_base_addr = ( dq_gmem.iterator + (tok_offset + tile_idx * self.BT + row) * H * K @@ -2149,7 +2152,11 @@ def kernel( rQ[i * 8 + 7] = vals[7] else: rQ.fill(BFloat16(0.0)) - rDg.store(rQ.load().to(Float32) * dq_f32_val_store + rDg.load() - rKdk.load()) + dq_scaled_i32 = tcgen05_ld_32x32b(bk_num_cols_per_wg, TMEM_DQ_SCALED_OFF + wg_idx * bk_num_cols_per_wg) + cute.arch.fence_view_async_tmem_load() + dq_scaled_f32 = reinterpret_cast(dq_scaled_i32, Int32, bk_num_cols_per_wg, Float32) + dq_scaled_f32_val = TensorSSA(dq_scaled_f32, (bk_num_cols_per_wg,), Float32) + rDg.store(rQ.load().to(Float32) * dq_scaled_f32_val + rDg.load() - rKdk.load()) self.cuda_wg_sync_barrier.arrive_and_wait() # dg = dg2 + m_last * dgk, GMEM store dg From 84d5bea47e3196e9aff2929077b9cc435c25c5be Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Thu, 7 May 2026 15:36:51 +0800 Subject: [PATCH 10/26] add more nan/inf tests --- benchmarks/bench_kda_bwd_wy_dqkg_sm100.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py index e00c4ed..2a857d1 100644 --- a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py +++ b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py @@ -178,6 +178,18 @@ def check_determinism(H=4, total_T=2001, num_seqs=4, iters=1000, beta_dtype=DTYP assert torch.equal(dA_out, ref_dA), f"dA mismatch at iter {i}" # NOTE: for db, kernel uses atomic add which can cause non-determinism, so we use a looser check here torch.testing.assert_close(db_out, ref_db, rtol=1e-5, atol=1e-5), f"db mismatch at iter {i}" + assert torch.isnan(dq_out).sum() == 0, f"dq contains NaNs at iter {i}" + assert torch.isnan(dk_out).sum() == 0, f"dk contains NaNs at iter {i}" + assert torch.isnan(dv_out).sum() == 0, f"dv contains NaNs at iter {i}" + assert torch.isnan(db_out).sum() == 0, f"db contains NaNs at iter {i}" + assert torch.isnan(dg_out).sum() == 0, f"dg contains NaNs at iter {i}" + assert torch.isnan(dA_out).sum() == 0, f"dA contains NaNs at iter {i}" + assert torch.isfinite(dq_out).all(), f"dq contains infs at iter {i}" + assert torch.isfinite(dk_out).all(), f"dk contains infs at iter {i}" + assert torch.isfinite(dv_out).all(), f"dv contains infs at iter {i}" + assert torch.isfinite(db_out).all(), f"db contains infs at iter {i}" + assert torch.isfinite(dg_out).all(), f"dg contains infs at iter {i}" + assert torch.isfinite(dA_out).all(), f"dA contains infs at iter {i}" return True From 6fefaad2d563de38eb42e00b2ac8d025e0af4100 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Thu, 7 May 2026 17:45:21 +0800 Subject: [PATCH 11/26] remove TS ws mode MMA --- cula/ops/intrinsics_sm100.py | 40 ++++++ tests/test_ptx_umma_ws.py | 266 +---------------------------------- 2 files changed, 43 insertions(+), 263 deletions(-) diff --git a/cula/ops/intrinsics_sm100.py b/cula/ops/intrinsics_sm100.py index e1352f2..087e5c6 100644 --- a/cula/ops/intrinsics_sm100.py +++ b/cula/ops/intrinsics_sm100.py @@ -347,6 +347,46 @@ def _do(addr_val, desc_val, *, loc=None, ip=None): _do(Int32(taddr), smem_desc.desc_i64[0]) +@cute.jit +def tcgen05_cp_128x128b(taddr: int, smem_desc: Tcgen05SmemDescriptor): + """Async copy SMEM → TMEM with shape ``128x128b`` (``cta_group::1``). + + Issues ``tcgen05.cp.cta_group::1.128x128b [taddr], s-desc;`` + via the native ``nvvm.tcgen05.cp`` MLIR op. + + The instruction copies a 128-row × 128-bit tile from shared memory + (described by *smem_desc*) into Tensor Memory at *taddr*. The copy + is **asynchronous** — use ``tcgen05.commit`` + ``mbarrier.wait`` to + synchronize. + + PTX reference + ------------- + tcgen05.cp.cta_group::1.128x128b [taddr], s-desc; + + Parameters + ---------- + taddr : int + TMEM destination address (uint32, passed as ``!llvm.ptr<6>``). + smem_desc : Tcgen05SmemDescriptor + 64-bit SMEM matrix descriptor (same format as ``tcgen05.mma`` + descriptors — see ``Tcgen05SmemDescriptor``). + """ + + @dsl_user_op + def _do(addr_val, desc_val, *, loc=None, ip=None): + ptr6_ty = llvm.PointerType.get(address_space=6) + tmem_ptr = llvm.inttoptr(ptr6_ty, _to_ir(addr_val, loc, ip), loc=loc, ip=ip) + _nvvm.tcgen05_cp( + shape=_nvvm.Tcgen05CpShape.SHAPE_128x128b, + taddr=tmem_ptr, + smem_desc=_to_ir(desc_val, loc, ip), + cta_group=_nvvm.Tcgen05GroupKind.CTA_1, + loc=loc, + ip=ip, + ) + + _do(Int32(taddr), smem_desc.desc_i64[0]) + @cute.jit def tcgen05_fence_before(): """tcgen05.fence::before_thread_sync — non-blocking ordering fence.""" diff --git a/tests/test_ptx_umma_ws.py b/tests/test_ptx_umma_ws.py index ea9acf0..b130740 100644 --- a/tests/test_ptx_umma_ws.py +++ b/tests/test_ptx_umma_ws.py @@ -47,10 +47,7 @@ from cula.ops.intrinsics_sm100 import ( store_256b, subvec, - tcgen05_cp_128x256b, - reinterpret_cast, tcgen05_ld_32x32b, - tcgen05_st_32x32b, ) from cula.ops.ptx_umma_ext import ( CollectorBBuffer, @@ -58,7 +55,6 @@ Tcgen05SmemDescriptor, tcgen05mma_ws_ss_f16, tcgen05mma_ws_ss_tf32, - tcgen05mma_ws_ts_f16, ) M_DIM, N_DIM = 64, 64 @@ -234,32 +230,7 @@ def run(self, A_cpu, B_cpu): # ===================================================================== -# Test 2: tcgen05mma_ws_ts_tf32 (weight-stationary, TMEM A, SMEM B, tf32) -# ===================================================================== - -# TODO -class _WsTsTf32Kernel: - """Two-step test: - - Region 0 (tmem_a_region): populate A into TMEM via S2T copy - - Region 1 (tmem_c_region): result of WS TS MMA (tmem_a × B → C) - """ - - @cute.kernel - def kernel(self, A_in: cute.Tensor, B_in: cute.Tensor, C_out: cute.Tensor): - pass - - @cute.jit - def _launch(self, A: cute.Tensor, B: cute.Tensor, C: cute.Tensor, stream): - self.kernel(A, B, C).launch(grid=(1, 1, 1), block=(128, 1, 1), stream=stream) - - def run(self, A_cpu, B_cpu): - raise NotImplementedError( - "TODO: implement this test following the pattern of _WsSsTf32Kernel, but with the two-step approach described in the docstring" - ) - - -# ===================================================================== -# Test 3: tcgen05mma_ws_ss_f16 (weight-stationary, SMEM A, SMEM B, f16) +# Test 2: tcgen05mma_ws_ss_f16 (weight-stationary, SMEM A, SMEM B, f16) # ===================================================================== @@ -437,216 +408,7 @@ def run(self, A_cpu, B_cpu): # ===================================================================== -# Test 4: tcgen05mma_ws_ts_f16 (weight-stationary, TMEM A, SMEM B, f16) -# ===================================================================== - - -# TODO: support with first S2R then R2T with tcgen05.st -class _WsTsF16Kernel: - """Two-step test (same strategy as _WsTsTf32Kernel but with kind::f16): - - Region 0 (tmem_a_region): populate A into TMEM via S2T copy - - Region 1 (tmem_c_region): result of WS TS MMA (tmem_a × B → C) - """ - def __init__(self, M: int, N: int, K: int): - self.M = M - self.N = N - self.K = K - if N == 64: - self.idesc = IDESC_F16_M64_N64 - elif N == 128: - self.idesc = IDESC_F16_M64_N128 - else: - raise ValueError(f"Unsupported N={N} for F16 IDESC (expected 64 or 128)") - self.wg_sync_barrier = pipeline.NamedBarrier( - barrier_id=3, - num_threads=128, - ) - - @cute.kernel - def kernel(self, A_in: cute.Tensor, B_in: cute.Tensor, C_out: cute.Tensor): - M, N, K = self.M, self.N, self.K - idesc = self.idesc - ACC_NUM_COLS = N // 2 - # Layout E (M=64 + .ws mode) for A operand in TMEM: - # - 4 warps × 16 active lanes (lane<16) cover M=64 rows - # warp w → rows [w*16 .. w*16+15] - # - Each active lane stores its full K-row as bf16 pairs in 32-bit cells - # - Per K-atom (16 BF16): 8 TMEM cols (16 active lanes × 8 dwords = 128 BF16 = 64 rows × 2 BF16-pairs) - # wait: 16 lanes × 8 cols × 2 bf16/col = 256 BF16 per warp per K-atom; 4 warps × 256 = 1024 BF16 = 64×16 ✓ - # - Total cols for A^T[64, K]: K/2 cols - # ref: chunk_kda_bwd_sm100_tmem.cuh::pack_AT_to_tmem_bf16, FlashMLA SM100_UTCCP_128dp256bit - OPA_NUM_COLS = K // 2 # 8 cols per K-atom of 16 BF16 - TMEM_OPA = 0 - TMEM_ACC = TMEM_OPA + OPA_NUM_COLS - tidx, _, _ = cute.arch.thread_idx() - warp_idx = cute.arch.warp_idx() - warp_idx = cute.arch.make_warp_uniform(warp_idx) - lane_idx = tidx % 32 - - smem = utils.SmemAllocator() - tmem_hold_ptr = smem.allocate(Int32) - mbar_ptr = smem.allocate(Int64, byte_alignment=8) - - # Create MMA SMEM Layouts - # NOTE: we use non-ws mode TiledMMA for creating smem layout in a easy way, - # because smem layouts of ws mode and non-ws mode are the same - mma_tiler = (M, N, K) - non_ws_tiled_mma = sm100_utils.make_trivial_tiled_mma( - BFloat16, - tcgen05.OperandMajorMode.K, - tcgen05.OperandMajorMode.MN, - Float32, - tcgen05.CtaGroup.ONE, - (M, N), - ) - a_smem_layout = sm100_utils.make_smem_layout_a(non_ws_tiled_mma, mma_tiler, BFloat16, 1) - b_smem_layout = sm100_utils.make_smem_layout_b(non_ws_tiled_mma, mma_tiler, BFloat16, 1) - bufferA = smem.allocate_tensor( - element_type=BFloat16, - layout=a_smem_layout.outer, - byte_alignment=128, - swizzle=a_smem_layout.inner, - ) - - bufferB = smem.allocate_tensor( - element_type=BFloat16, - layout=b_smem_layout.outer, - byte_alignment=128, - swizzle=b_smem_layout.inner, - ) - - bufA_s0 = bufferA[(None, None, None, 0)] - bufB_s0 = bufferB[(None, None, None, 0)] - - if tidx == cutlass.Int32(0): - mbarrier_init(mbar_ptr, 1) - mbarrier_init_fence() - - # Load A (row-major input → K-major swizzled SMEM) - gA_flat = cute.make_tensor(A_in.iterator, cute.make_layout(M * K)) - gB_flat = cute.make_tensor(B_in.iterator, cute.make_layout(K * N)) - - for step in cutlass.range(M * K // 128, unroll_full=False): - smem_idx = tidx + step * 128 - m = smem_idx % M - k = smem_idx // M - bufA_s0[smem_idx] = gA_flat[m * K + k] - for step in cutlass.range(K * N // 128, unroll_full=False): - idx = tidx + step * 128 - bufB_s0[idx] = gB_flat[idx] - sync_threads() - - # --- TMEM allocation --- - alloc_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=128) - tmem = utils.TmemAllocator( - tmem_hold_ptr, - barrier_for_retrieve=alloc_bar, - allocator_warp_id=0, - ) - tmem.allocate(512) - tmem.wait_for_alloc() - tmem_ptr_f32 = tmem.retrieve_ptr(Float32) - - rA = cute.make_rmem_tensor((self.K, ), BFloat16) - # Layout E A-operand pack: warp w lane l (l<16) → A^T row = w*16 + l, full K-row. - # Inactive lanes (l>=16) store zeros (warp-collective tcgen05.st requires all 32 lanes). - row_active = warp_idx * 16 + lane_idx # only valid when lane_idx < 16 - is_active = lane_idx < cutlass.Int32(16) - for j in cutlass.range_constexpr(self.K): - if is_active: - rA[j] = bufA_s0[row_active * self.K + j] - else: - rA[j] = BFloat16(0.0) - - rA_val = rA.load() - rA_i32_val = reinterpret_cast(rA_val, BFloat16, self.K, Int32) - # Store K/2 dwords per active lane → K/2 TMEM cols per warp subpart - tcgen05_st_32x32b(self.K // 2, TMEM_OPA, rA_i32_val) - cute.arch.fence_view_async_tmem_store() - - self.wg_sync_barrier.arrive_and_wait() - - # Build SMEM descriptors (rank-2 vec_mode layout required) - desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(bufB_s0.iterator, bufB_s0.layout, "mn")) - desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) - tmem_a_base = TMEM_OPA - - # Per-K-atom byte offsets are derived from the (unswizzled) outer layout - # so we transparently handle every K size: - # K∈{16,32,64} → A k_iter is single-mode, uniform stride - # K≥128 → A k_iter is hierarchical e.g. (4,K/64):(16,4096) - # B is always uniform stride=1024 elem - # Coord ((0,0), 0, ks, 0) into outer layout gives the linear elem offset - # of the ks-th MMA-K atom; * sizeof(elem) → byte offset to add to desc. - ELEM_BYTES_F16 = BFloat16.width // 8 - b_outer = b_smem_layout.outer - - # Issue WS SS MMA (scale_out=0 → D = A*B, not accumulate) - if warp_idx == cutlass.Int32(0): - with elect_one(): - for ks in cutlass.range_constexpr(K // 16): - scale = 0 if ks == 0 else 1 - b_off = cute.crd2idx(((0, 0), 0, ks, 0), b_outer) * ELEM_BYTES_F16 - desc_b = desc_b_base + b_off - tmem_a = tmem_a_base + ks * 8 # 8 TMEM cols per K-atom of 16 BF16 - tcgen05mma_ws_ts_f16(tmem_a, desc_b, TMEM_ACC, idesc, scale) - tcgen05.commit(mbar_ptr, cta_group=tcgen05.CtaGroup.ONE) - mbarrier_wait(mbar_ptr, 0) - sync_threads() - - # T2R - # Layout E (M=64, ws mode): 128 lanes, 32 columns - # ref: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e - # .32x32b.x32 loads all 32 columns → 32 FP32 regs per thread - # Layout: warp0->(M0,N0), warp1->(M0,N1), warp2->(M1,N0), warp3->(M1,N1) - # for 64x64 Acc, each warp process 32x32, with 128 lanes in TMEM all used - - vec_i32 = tcgen05_ld_32x32b(ACC_NUM_COLS, TMEM_ACC) - cute.arch.fence_view_async_tmem_load() - - # =======DEBUG======== - # # 1. reinterpret_cast to f32 (zero-cost bitcast) - # vec_f32 = reinterpret_cast(vec_i32, Int32, ACC_NUM_COLS, Float32) - - # # 2. TensorSSA wrap → .to(BFloat16) (real CUDA core CVT) - # regs = TensorSSA(vec_f32, (ACC_NUM_COLS,), Float32) - - # # Debug print: thread 0, first 4 register values - # if tidx == cutlass.Int32(0): - # cute.printf("[T2R] tid=0, regs[0..3] = %f, %f, %f, %f", - # regs[0], regs[1], regs[2], regs[3]) - - # R2G via store_256b (4 × 256-bit stores per thread) - # Layout E C-out (column-major warp order): - # warp0->(M0,N0), warp1->(M1,N0), warp2->(M0,N1), warp3->(M1,N1) - col_base = (warp_idx // 2) * ACC_NUM_COLS # N0 or N1 - out_row = (warp_idx % 2) * (M // 2) + lane_idx # M0 or M1, full 32 lanes used for C - # 32 regs = 4 chunks of 8 × 32-bit each (256 bits) - base_addr = (C_out.iterator + out_row * N + col_base).toint() - for chunk in cutlass.range_constexpr(ACC_NUM_COLS // 8): - store_256b(base_addr + chunk * 32, subvec(vec_i32, chunk * 8, 8)) - - sync_threads() - tmem.relinquish_alloc_permit() - tmem.free(tmem_ptr_f32, 512) - - @cute.jit - def _launch(self, A: cute.Tensor, B: cute.Tensor, C: cute.Tensor, stream): - self.kernel(A, B, C).launch(grid=(1, 1, 1), block=(128, 1, 1), stream=stream) - - def run(self, A_cpu, B_cpu): - M, N = self.M, self.N - A_gpu = A_cpu.cuda().to(torch.bfloat16).contiguous() - B_gpu = B_cpu.cuda().to(torch.bfloat16).contiguous() - C_gpu = torch.zeros(M, N, dtype=torch.float32, device="cuda") - stream = cutlass_torch.default_stream() - self._launch(from_dlpack(A_gpu), from_dlpack(B_gpu), from_dlpack(C_gpu), stream) - torch.cuda.synchronize() - return C_gpu.cpu() - - -# ===================================================================== -# Test 5: tcgen05mma_ws_ss_tf32 with explicit collector_b_buffer/collector_op +# Test 3: tcgen05mma_ws_ss_tf32 with explicit collector_b_buffer/collector_op # ===================================================================== @@ -800,30 +562,9 @@ def test_ws_ss_f16(): assert rel < 0.02, f"FAIL N={N}, K={K}: rel={rel:.4f}" print(f" PASSED (N={N}, K={K})") -def test_ws_ts_f16(): - print("\n=== Test 3: tcgen05mma_ws_ts_f16 (weight-stationary, TMEM A × SMEM B, f16) ===") - torch.manual_seed(42) - for N in [64, 128]: - for K in [64, 128]: - print(f" --- N={N}, K={K} ---") - A = torch.randn(M_DIM, K) - B = torch.randn(K, N) - ref = torch.mm(A, B) - got = _WsTsF16Kernel(M_DIM, N, K).run(A, B) - err = (got - ref).abs() - rel = err.max().item() / (ref.abs().max().item() + 1e-8) - max_idx = err.argmax().item() - mi, mj = max_idx // N, max_idx % N - print(f" got[0,:4]={got[0, :4].tolist()}") - print(f" ref[0,:4]={ref[0, :4].tolist()}") - print(f" max_rel_err={rel:.4f} at ({mi},{mj}): got={got[mi, mj]:.6f} ref={ref[mi, mj]:.6f}") - assert rel < 0.02, f"FAIL N={N}, K={K}: rel={rel:.4f}" - print(f" PASSED (N={N}, K={K})") - - def test_ws_ss_tf32_collector(): """Explicit collector_b_buffer=B0, collector_op=DISCARD should match default.""" - print("\n=== Test 5: tcgen05mma_ws_ss_tf32 + collector (B0::DISCARD) ===") + print("\n=== Test 2: tcgen05mma_ws_ss_tf32 + collector (B0::DISCARD) ===") torch.manual_seed(42) A = torch.randn(M_DIM, K_DIM_TF32) B = torch.randn(K_DIM_TF32, N_DIM) @@ -844,5 +585,4 @@ def test_ws_ss_tf32_collector(): test_ws_ss_tf32() test_ws_ss_tf32_collector() test_ws_ss_f16() - # test_ws_ts_f16() print("\n=== All tests passed! ===") From c740668b0120c2867192f1c78128a809072b0a90 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Fri, 8 May 2026 22:37:18 +0800 Subject: [PATCH 12/26] support GVA for wy_dqkg --- benchmarks/bench_kda_bwd_wy_dqkg_sm100.py | 32 ++- benchmarks/utils.py | 33 ++- cula/ops/chunk_wy_dqkg_sm100.py | 285 +++++++++++----------- 3 files changed, 193 insertions(+), 157 deletions(-) diff --git a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py index 2a857d1..b61c47a 100644 --- a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py +++ b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py @@ -151,8 +151,10 @@ def run_cutedsl(inputs: dict): chunk_indices=inputs["chunk_indices"], ) -def check_determinism(H=4, total_T=2001, num_seqs=4, iters=1000, beta_dtype=DTYPE): +def check_determinism(H=4, HV=None, total_T=2001, num_seqs=4, iters=1000, beta_dtype=DTYPE): """Verify deterministic outputs across repeated runs.""" + if HV is None: + HV = H torch.manual_seed(42) seq_lens = generate_balanced_seqlens(total_T, num_seqs) cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=DEVICE) @@ -162,6 +164,7 @@ def check_determinism(H=4, total_T=2001, num_seqs=4, iters=1000, beta_dtype=DTYP H=H, K=K, V=V, + HV=HV, chunk_size=BT, device=DEVICE, seed=SEED, @@ -196,9 +199,11 @@ def check_determinism(H=4, total_T=2001, num_seqs=4, iters=1000, beta_dtype=DTYP # ============================================================ # Fixed-length benchmark # ============================================================ -def bench_fixed(configs, H: int): +def bench_fixed(configs, H: int, HV: int | None = None): + if HV is None: + HV = H print("\n" + "=" * 120) - print(f" Fixed-Length Benchmark: cuLA CuTe DSL vs FLA Triton (H={H}, K={K}, V={V}, BT={BT})") + print(f" Fixed-Length Benchmark: cuLA CuTe DSL vs FLA Triton (H={H}, HV={HV}, K={K}, V={V}, BT={BT})") print("=" * 120) results = [] @@ -215,6 +220,7 @@ def bench_fixed(configs, H: int): H=H, K=K, V=V, + HV=HV, chunk_size=BT, device=DEVICE, seed=SEED, @@ -256,9 +262,11 @@ def bench_fixed(configs, H: int): # ============================================================ # Varlen benchmark # ============================================================ -def bench_varlen(configs, H: int): +def bench_varlen(configs, H: int, HV: int | None = None): + if HV is None: + HV = H print("\n" + "=" * 120) - print(f" Varlen Benchmark: cuLA CuTe DSL vs FLA Triton (H={H}, K={K}, V={V}, BT={BT})") + print(f" Varlen Benchmark: cuLA CuTe DSL vs FLA Triton (H={H}, HV={HV}, K={K}, V={V}, BT={BT})") print("=" * 120) results = [] @@ -275,6 +283,7 @@ def bench_varlen(configs, H: int): H=H, K=K, V=V, + HV=HV, chunk_size=BT, device=DEVICE, seed=SEED, @@ -396,6 +405,12 @@ def main(): default=[H_DEFAULT], help=f"Head counts to benchmark (default: [{H_DEFAULT}])", ) + parser.add_argument( + "--hv", + type=int, + default=None, + help="Number of value heads HV (default: same as H, i.e. no GVA). Must be a multiple of H.", + ) parser.add_argument("--ncu", action="store_true", help="NCU profiling mode: warmup=1, iters=1") args = parser.parse_args() @@ -430,15 +445,16 @@ def main(): ) for H in args.heads: - check_determinism(H=H, iters=10000) + HV = args.hv if args.hv is not None else H + check_determinism(H=H, HV=HV, iters=10000) fixed_res, varlen_res = [], [] if args.mode in ("fixed", "both"): - fixed_res = bench_fixed(fixed_configs, H) + fixed_res = bench_fixed(fixed_configs, H, HV) if args.mode in ("varlen", "both"): - varlen_res = bench_varlen(varlen_configs, H) + varlen_res = bench_varlen(varlen_configs, H, HV) print_report(fixed_res, varlen_res, H) diff --git a/benchmarks/utils.py b/benchmarks/utils.py index b726364..2e5b8b5 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -352,6 +352,7 @@ def prepare_bwd_wy_dqkg_fused_inputs( H: int, K: int, V: int, + HV: int | None = None, chunk_size: int = CHUNK_SIZE, device: torch.device | str = "cuda", seed: int = SEED, @@ -364,9 +365,14 @@ def prepare_bwd_wy_dqkg_fused_inputs( chunk_kda_bwd_wy_dqkg_fused kernels. Follows the same flattening convention used in other prepare_* helpers (B=1 with cu_seqlens for varlen mode). + HV: number of value heads (default: H). Set HV > H for GVA (grouped value attention). + q/k always have H heads; all other tensors use HV heads. + Returns a dict with keys used directly by ``run_fla_triton`` and ``run_cutedsl`` in ``bench_bwd_wy_dqkg_fused.py``. """ + if HV is None: + HV = H BT = chunk_size scale = K**-0.5 @@ -375,22 +381,22 @@ def prepare_bwd_wy_dqkg_fused_inputs( # ---- primary token-indexed tensors ---- q = torch.randn(B, T, H, K, dtype=dtype, device=device) k = torch.randn(B, T, H, K, dtype=dtype, device=device) - v = torch.randn(B, T, H, V, dtype=dtype, device=device) - g_raw = torch.randn(B, T, H, K, dtype=dtype, device=device) - beta = torch.randn(B, T, H, dtype=torch.float, device=device).sigmoid() + v = torch.randn(B, T, HV, V, dtype=dtype, device=device) + g_raw = torch.randn(B, T, HV, K, dtype=dtype, device=device) + beta = torch.randn(B, T, HV, dtype=torch.float, device=device).sigmoid() # l2norm q, k q, _ = l2norm_fwd(q) k, _ = l2norm_fwd(k) # gate preprocessing - A_log = torch.randn(H, dtype=torch.float, device=device) - dt_bias = torch.randn(H * K, dtype=torch.float, device=device) + A_log = torch.randn(HV, dtype=torch.float, device=device) + dt_bias = torch.randn(HV * K, dtype=torch.float, device=device) - v_new = torch.randn(B, T, H, V, dtype=dtype, device=device) - do = torch.randn(B, T, H, V, dtype=dtype, device=device) - dv = torch.randn(B, T, H, V, dtype=dtype, device=device) - A = torch.randn(B, T, H, BT, dtype=dtype, device=device) * 0.1 + v_new = torch.randn(B, T, HV, V, dtype=dtype, device=device) + do = torch.randn(B, T, HV, V, dtype=dtype, device=device) + dv = torch.randn(B, T, HV, V, dtype=dtype, device=device) + A = torch.randn(B, T, HV, BT, dtype=dtype, device=device) * 0.1 # ---- chunk-indexed state tensors ---- if cu_seqlens is not None: @@ -401,13 +407,14 @@ def prepare_bwd_wy_dqkg_fused_inputs( NT = (B * T + BT - 1) // BT chunk_indices = None - # h/dh: both FLA Triton and CuTe DSL use bf16 [B, NT, H, K, V] - h = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 - dh = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 + # h/dh: both FLA Triton and CuTe DSL use bf16 [B, NT, HV, K, V] + h = torch.randn(B, NT, HV, K, V, dtype=dtype, device=device) * 0.01 + dh = torch.randn(B, NT, HV, K, V, dtype=dtype, device=device) * 0.01 # flatten to batch_size=1 for cu_seqlens compatibility if B != 1: - q, k, v, g_raw, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g_raw, beta)) + q, k = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k)) + v, g_raw, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (v, g_raw, beta)) v_new, do, dv, A = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (v_new, do, dv, A)) h, dh = map(lambda x: rearrange(x, "b nt ... -> 1 (b nt) ..."), (h, dh)) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index 1e6ccd3..6d36916 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -504,14 +504,14 @@ def __init__( hardware_info = cutlass.utils.HardwareInfo() self.num_sm = hardware_info.get_device_multiprocessor_count() - def _compute_grid(self, B, T, H, total_nt=None): + def _compute_grid(self, B, T, HV, total_nt=None): """Compute grid dimensions for persistent kernel launch. Grid: (min(num_sm * min_occupancy, total_tiles), 1, 1) Each CTA handles multiple tiles via stride-by-gridDim.x loop. """ assert total_nt is not None - total_tiles = total_nt * H + total_tiles = total_nt * HV grid_x = cutlass.min(Int32(self.num_sm * self.min_occupancy), total_tiles) return (grid_x, Int32(1), Int32(1)) @@ -521,26 +521,26 @@ def __call__( # ── Inputs ── q_in: cute.Tensor, # [B, T, H, K] bf16 k_in: cute.Tensor, # [B, T, H, K] bf16 - v_in: cute.Tensor, # [B, T, H, V] bf16 - v_new_in: cute.Tensor, # [B, T, H, V] bf16 - g_in: cute.Tensor, # [B, T, H, K] fp32 - beta_in: cute.Tensor, # [B, T, H] fp32 - A_in: cute.Tensor, # [B, T, H, BT] bf16 - h_in: cute.Tensor, # [B, NT, H, K, V] bf16 - do_in: cute.Tensor, # [B, T, H, V] bf16 - dh_in: cute.Tensor, # [B, NT, H, K, V] bf16 - dv_in: cute.Tensor, # [B, T, H, V] bf16 + v_in: cute.Tensor, # [B, T, HV, V] bf16 + v_new_in: cute.Tensor, # [B, T, HV, V] bf16 + g_in: cute.Tensor, # [B, T, HV, K] fp32 + beta_in: cute.Tensor, # [B, T, HV] fp32 + A_in: cute.Tensor, # [B, T, HV, BT] bf16 + h_in: cute.Tensor, # [B, NT, HV, K, V] bf16 + do_in: cute.Tensor, # [B, T, HV, V] bf16 + dh_in: cute.Tensor, # [B, NT, HV, K, V] bf16 + dv_in: cute.Tensor, # [B, T, HV, V] bf16 # ── Outputs ── - dq_in: cute.Tensor, # [B, T, H, K] fp32 - dk_in: cute.Tensor, # [B, T, H, K] fp32 - dv2_in: cute.Tensor, # [B, T, H, V] bf16 - dg_in: cute.Tensor, # [B, T, H, K] fp32 - db_in: cute.Tensor, # [B, T, H] fp32 - dA_in: cute.Tensor, # [B, T, H, BT] fp32 + dq_in: cute.Tensor, # [B, T, HV, K] fp32 + dk_in: cute.Tensor, # [B, T, HV, K] fp32 + dv2_in: cute.Tensor, # [B, T, HV, V] bf16 + dg_in: cute.Tensor, # [B, T, HV, K] fp32 + db_in: cute.Tensor, # [B, T, HV] fp32 + dA_in: cute.Tensor, # [B, T, HV, BT] fp32 # ── Metadata ── cu_seqlens_in: cute.Tensor, # [N+1] int32 chunk_indices_in: cute.Tensor, # [NT, 2] int32 - problem_size: tuple[Int32, Int32, Int32, Int32, Int32], + problem_size: tuple[Int32, Int32, Int32, Int32, Int32, Int32], # (B, T, H, HV, K, V) total_nt: Int32, stream, ): @@ -565,7 +565,7 @@ def __call__( cu_seqlens_ptr = cu_seqlens_in.iterator chunk_indices_ptr = chunk_indices_in.iterator - B, T, H, K, V = problem_size + B, T, H, HV, K, V = problem_size BT = self.BT BK = self.BK BV = self.BV @@ -583,10 +583,10 @@ def __call__( q = cute.make_tensor(q_ptr, qk_layout) k = cute.make_tensor(k_ptr, qk_layout) - # v, v_new, do, dv, dv2: (T, V, (H, data_B)) bf16 + # v, v_new, do, dv, dv2: (T, V, (HV, data_B)) bf16 tv_layout = cute.make_layout( - (T, V, (H, data_B)), - stride=(H * V, 1, (V, T * H * V)), + (T, V, (HV, data_B)), + stride=(HV * V, 1, (V, T * HV * V)), ) v = cute.make_tensor(v_ptr, tv_layout) v_new = cute.make_tensor(v_new_ptr, tv_layout) @@ -594,68 +594,68 @@ def __call__( dv = cute.make_tensor(dv_ptr, tv_layout) dv2 = cute.make_tensor(dv2_ptr, tv_layout) - # g: (T, K, (H, data_B)) fp32 + # g: (T, K, (HV, data_B)) fp32 g_layout = cute.make_layout( - (T, K, (H, data_B)), - stride=(H * K, 1, (K, T * H * K)), + (T, K, (HV, data_B)), + stride=(HV * K, 1, (K, T * HV * K)), ) g = cute.make_tensor(g_ptr, g_layout) - # beta: (T, (H, data_B)) fp32 + # beta: (T, (HV, data_B)) fp32 beta_layout = cute.make_layout( - (T, (H, data_B)), - stride=(H, (1, T * H)), + (T, (HV, data_B)), + stride=(HV, (1, T * HV)), ) beta = cute.make_tensor(beta_ptr, beta_layout) - # A: (T, BT, (H, data_B)) bf16 + # A: (T, BT, (HV, data_B)) bf16 a_layout = cute.make_layout( - (T, BT, (H, data_B)), - stride=(H * BT, 1, (BT, T * H * BT)), + (T, BT, (HV, data_B)), + stride=(HV * BT, 1, (BT, T * HV * BT)), ) A = cute.make_tensor(A_ptr, a_layout) # NOTE: for A as operand A, A is loaded as transposed view to do MMA a_t_layout = cute.make_layout( - (BT, T, (H, data_B)), - stride=(1, H * BT, (BT, T * H * BT)), + (BT, T, (HV, data_B)), + stride=(1, HV * BT, (BT, T * HV * BT)), ) A_T = cute.make_tensor(A_ptr, a_t_layout) - # dq, dk: (T, K, (H, data_B)) fp32 + # dq, dk: (T, K, (HV, data_B)) fp32 dqk_layout = cute.make_layout( - (T, K, (H, data_B)), - stride=(H * K, 1, (K, T * H * K)), + (T, K, (HV, data_B)), + stride=(HV * K, 1, (K, T * HV * K)), ) dq = cute.make_tensor(dq_ptr, dqk_layout) dk = cute.make_tensor(dk_ptr, dqk_layout) - # dg: (T, K, (H, data_B)) fp32 + # dg: (T, K, (HV, data_B)) fp32 dg = cute.make_tensor(dg_ptr, dqk_layout) - # db: (T, (H, data_B)) fp32 + # db: (T, (HV, data_B)) fp32 db = cute.make_tensor(db_ptr, beta_layout) - # dA: (T, BT, (H, data_B)) fp32 + # dA: (T, BT, (HV, data_B)) fp32 dA_layout = cute.make_layout( - (T, BT, (H, data_B)), - stride=(H * BT, 1, (BT, T * H * BT)), + (T, BT, (HV, data_B)), + stride=(HV * BT, 1, (BT, T * HV * BT)), ) dA_out = cute.make_tensor(dA_ptr, dA_layout) h_nt_total = NT - # h row-major: (K, V, (h_nt_total, H)) as operand B + # h row-major: (K, V, (h_nt_total, HV)) as operand B h_layout = cute.make_layout( - (K, V, (h_nt_total, H)), - stride=(V, 1, (H * K * V, K * V)), + (K, V, (h_nt_total, HV)), + stride=(V, 1, (HV * K * V, K * V)), ) h = cute.make_tensor(h_ptr, h_layout) dh = cute.make_tensor(dh_ptr, h_layout) # Transposed views for V-loop TMA (data loaded as MMA B-operands): vt_layout = cute.make_layout( - (V, T, (data_B, H)), - stride=(1, H * V, (T * H * V, V)), + (V, T, (data_B, HV)), + stride=(1, HV * V, (T * HV * V, V)), ) v_T = cute.make_tensor(v_ptr, vt_layout) @@ -1092,7 +1092,7 @@ class SharedStorage: chunk_indices = cute.make_tensor(chunk_indices_ptr, cute.make_layout((total_nt, 2), stride=(2, 1))) # ===================== Grid ===================== - grid = self._compute_grid(B, T, H, total_nt=total_nt) + grid = self._compute_grid(B, T, HV, total_nt=total_nt) # ===================== Launch kernel ===================== self.kernel( @@ -1217,9 +1217,9 @@ def kernel( # Metadata cu_seqlens: cute.Tensor, chunk_indices: cute.Tensor, - problem_size: tuple[Int32, Int32, Int32, Int32, Int32], + problem_size: tuple[Int32, Int32, Int32, Int32, Int32, Int32], # (B, T, H, HV, K, V) ): - B, T, H, K, V = problem_size + B, T, H, HV, K, V = problem_size BT = self.BT BK, BV = self.BK, self.BV @@ -1230,7 +1230,7 @@ def kernel( thread_idx = cute.arch.thread_idx()[0] lane_idx = thread_idx % 32 - total_work_units = chunk_indices.layout.shape[0] * H + total_work_units = chunk_indices.layout.shape[0] * HV num_iters = (total_work_units - block_idx_x + grid_dim_x - 1) // grid_dim_x num_cuda_warps = len(self.cuda_warp_ids) @@ -1787,8 +1787,10 @@ def kernel( vloop_stage_idx = 0 for wu_iter in cutlass.range(0, num_iters, unroll=0): work_idx = block_idx_x + wu_iter * grid_dim_x - i_t = work_idx // H # chunk index (global) - head_idx = work_idx % H # head index + G = HV // H + i_t = work_idx // HV # chunk index (global) + i_hv = work_idx % HV # value-head index + i_h = i_hv // G # q/k head index # Decode chunk_indices batch_idx = chunk_indices[(i_t, 0)] tile_idx = chunk_indices[(i_t, 1)] @@ -1901,8 +1903,8 @@ def kernel( base_addr = ( dv2_gmem.iterator - + (tok_offset + tile_idx * self.BT + row) * H * V - + head_idx * V + + (tok_offset + tile_idx * self.BT + row) * HV * V + + i_hv * V + v_iter * self.BV + bv_col_base ).toint() @@ -1958,8 +1960,8 @@ def kernel( cute.arch.fence_view_async_tmem_store() dq_base_addr = ( dq_gmem.iterator - + (tok_offset + tile_idx * self.BT + row) * H * K - + head_idx * K + + (tok_offset + tile_idx * self.BT + row) * HV * K + + i_hv * K + bk_col_base ).toint() if row < sub_seq_len: @@ -2062,7 +2064,7 @@ def kernel( self.cuda_wg_sync_barrier.arrive_and_wait() # store db to GMEM if local_tidx < sub_seq_len: - db_gmem[(tok_offset + tile_idx * self.BT + local_tidx, (head_idx, Int32(0)))] = sDb[(local_tidx,)] + db_gmem[(tok_offset + tile_idx * self.BT + local_tidx, (i_hv, Int32(0)))] = sDb[(local_tidx,)] # dk = dk * exp2(gn[None, :] - g) pipeline_mma_dk.consumer_wait(mma_dk_consumer_state) @@ -2103,8 +2105,8 @@ def kernel( # 8 fp32 store each time for store_256b dk_base_addr = ( dk_gmem.iterator - + (tok_offset + tile_idx * self.BT + row) * H * K - + head_idx * K + + (tok_offset + tile_idx * self.BT + row) * HV * K + + i_hv * K + bk_col_base ).toint() if row < sub_seq_len: @@ -2283,8 +2285,8 @@ def kernel( num_stores_dA = bt_num_cols_per_wg // 8 dA_base_addr = ( dA_gmem.iterator - + (tok_offset + tile_idx * self.BT + row) * H * BT - + head_idx * BT + + (tok_offset + tile_idx * self.BT + row) * HV * BT + + i_hv * BT + bt_col_base ).toint() if row < sub_seq_len: @@ -2330,8 +2332,10 @@ def kernel( vloop_stage_idx = 0 for wu_iter in cutlass.range(0, num_iters, unroll=0): work_idx = block_idx_x + wu_iter * grid_dim_x - i_t = work_idx // H # chunk index (global) - head_idx = work_idx % H # head index + G = HV // H + i_t = work_idx // HV # chunk index (global) + i_hv = work_idx % HV # value-head index + i_h = i_hv // G # q/k head index # Decode chunk_indices batch_idx = chunk_indices[(i_t, 0)] @@ -2349,7 +2353,7 @@ def kernel( self.dvb_tiler, # [BT, BV, BT] dvb_tiled_mma, Int32(0), - head_idx, + i_hv, ) pipeline_load_A.producer_acquire(load_A_producer_state) cute.copy( @@ -2369,7 +2373,7 @@ def kernel( sH, self.vloop_gemm_tiler, # [BT, BK, BV] vloop_tiled_mma, - head_idx, i_t + i_hv, i_t ) pipeline_load_h.producer_acquire(load_h_producer_state) cute.copy( @@ -2387,7 +2391,7 @@ def kernel( sDh, self.vloop_gemm_tiler, # [BT, BK, BV] vloop_tiled_mma, - head_idx, i_t + i_hv, i_t ) pipeline_load_dh.producer_acquire(load_dh_producer_state) cute.copy( @@ -2405,7 +2409,7 @@ def kernel( sDo, self.vloop_gemm_tiler, # [BT, BK, BV] vloop_tiled_mma, - Int32(0), head_idx, + Int32(0), i_hv, ) pipeline_load_do.producer_acquire(load_do_producer_state) cute.copy( @@ -2423,7 +2427,7 @@ def kernel( sDv, self.vloop_gemm_tiler, # [BT, BK, BV] vloop_tiled_mma, - Int32(0), head_idx, + Int32(0), i_hv, ) pipeline_load_dv.producer_acquire(load_dv_producer_state) cute.copy( @@ -2441,7 +2445,7 @@ def kernel( sV, self.dA_vloop_tiler, # [BT, BT, BV] dA_vloop_tiled_mma, - Int32(0), head_idx, + Int32(0), i_hv, ) pipeline_load_v.producer_acquire(load_v_producer_state) cute.copy( @@ -2460,7 +2464,7 @@ def kernel( sVnew, self.vloop_gemm_tiler, # [BT, BK, BV] vloop_tiled_mma, - Int32(0), head_idx, + Int32(0), i_hv, ) pipeline_load_vnew.producer_acquire(load_vnew_producer_state) cute.copy( @@ -2477,7 +2481,7 @@ def kernel( tma_g_v = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_g) tGsG, tGgG = self._epilog_partition_varlen( tma_atom_g, - tma_g_v[None, None, (head_idx, Int32(0))], + tma_g_v[None, None, (i_hv, Int32(0))], (self.BT, self.BK), sG_raw, ) @@ -2494,7 +2498,7 @@ def kernel( tma_k_v = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_k) tKsK, tKgK = self._epilog_partition_varlen( tma_atom_k, - tma_k_v[None, None, (head_idx, Int32(0))], + tma_k_v[None, None, (i_h, Int32(0))], (self.BT, self.BK), sK_raw, ) @@ -2510,7 +2514,7 @@ def kernel( tma_q_v = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_q) tQsQ, tQgQ = self._epilog_partition_varlen( tma_atom_q, - tma_q_v[None, None, (head_idx, Int32(0))], + tma_q_v[None, None, (i_h, Int32(0))], (self.BT, self.BK), sQ_raw, ) @@ -2591,8 +2595,10 @@ def kernel( mma_vloop_phase = 0 for wu_iter in cutlass.range(0, num_iters, unroll=0): work_idx = block_idx_x + wu_iter * grid_dim_x - i_t = work_idx // H # chunk index (global) - head_idx = work_idx % H # head index + G = HV // H + i_t = work_idx // HV # chunk index (global) + i_hv = work_idx % HV # value-head index (unused in MMA warp) + i_h = i_hv // G # q/k head index (unused in MMA warp) # Decode chunk_indices batch_idx = chunk_indices[(i_t, 0)] @@ -2875,8 +2881,10 @@ def kernel( for wu_iter in cutlass.range(0, num_iters, unroll=0): work_idx = block_idx_x + wu_iter * grid_dim_x - i_t = work_idx // H # chunk index (global) - head_idx = work_idx % H # head index + G = HV // H + i_t = work_idx // HV # chunk index (global) + i_hv = work_idx % HV # value-head index + i_h = i_hv // G # q/k head index (unused in aux warp) # Decode chunk_indices batch_idx = chunk_indices[(i_t, 0)] @@ -2888,7 +2896,7 @@ def kernel( pipeline_load_beta.producer_acquire(load_beta_producer_state) beta_f32 = Float32(0.0) if tidx < sub_seq_len: - beta_f32 = Float32(beta_gmem[(tok_offset + tile_idx * self.BT + tidx, (head_idx, Int32(0)))]) + beta_f32 = Float32(beta_gmem[(tok_offset + tile_idx * self.BT + tidx, (i_hv, Int32(0)))]) sBeta[(tidx, )] = beta_f32 cute.arch.fence_proxy("async.shared", space="cta") @@ -2901,7 +2909,7 @@ def kernel( tma_dg_v = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_dg) tDGsDG, tDGgDG = self._epilog_partition_varlen( tma_atom_dg, - tma_dg_v[None, None, (head_idx, Int32(0))], + tma_dg_v[None, None, (i_hv, Int32(0))], (self.BT, self.BK), sG_raw, ) @@ -2928,8 +2936,8 @@ def kernel( ) dg_base_addr = ( dg_gmem.iterator - + (tok_offset + tile_idx * self.BT + store_row) * H * K - + head_idx * K + + (tok_offset + tile_idx * self.BT + store_row) * HV * K + + i_hv * K + store_col_base ).toint() store_256b(dg_base_addr, dg_store_i32_vec) @@ -3036,7 +3044,7 @@ def _epilog_partition_varlen(self, atom, gC_2d, epi_tile, sC): _bwd_wy_kernel_cache: dict = {} -def _compile_bwd_wy_variant(H, K, V, scale, chunk_size, beta_dtype, use_fast_math): +def _compile_bwd_wy_variant(H, HV, K, V, scale, chunk_size, beta_dtype, use_fast_math): """Compile one ChunkKdaBwdWyDqkgFused kernel variant. Uses make_fake_compact_tensor and make_fake_stream for compilation with @@ -3063,23 +3071,23 @@ def _compile_bwd_wy_variant(H, K, V, scale, chunk_size, beta_dtype, use_fast_mat # varlen: data tensors are [1, T_total, H, ...] q_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) k_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) - v_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, V), stride_order=(3, 2, 1, 0), assumed_align=128) - vnew_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, V), stride_order=(3, 2, 1, 0), assumed_align=128) - g_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) - beta_fake = make_fake_compact_tensor(beta_dtype, (1, sym_b, H), stride_order=(2, 1, 0), assumed_align=128) - A_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, BT), stride_order=(3, 2, 1, 0), assumed_align=128) - do_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, V), stride_order=(3, 2, 1, 0), assumed_align=128) - dv_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, V), stride_order=(3, 2, 1, 0), assumed_align=128) - - dq_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) - dk_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) - dv2_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, H, V), stride_order=(3, 2, 1, 0), assumed_align=128) - dg_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H, K), stride_order=(3, 2, 1, 0), assumed_align=128) - db_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H), stride_order=(2, 1, 0), assumed_align=128) - dA_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, H, BT), stride_order=(3, 2, 1, 0), assumed_align=128) - - h_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_nt, H, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128) - dh_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_nt, H, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128) + v_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128) + vnew_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128) + g_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128) + beta_fake = make_fake_compact_tensor(beta_dtype, (1, sym_b, HV), stride_order=(2, 1, 0), assumed_align=128) + A_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, HV, BT), stride_order=(3, 2, 1, 0), assumed_align=128) + do_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128) + dv_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128) + + dq_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128) + dk_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128) + dv2_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128) + dg_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128) + db_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, HV), stride_order=(2, 1, 0), assumed_align=128) + dA_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, HV, BT), stride_order=(3, 2, 1, 0), assumed_align=128) + + h_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_nt, HV, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128) + dh_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_nt, HV, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128) cu_fake = make_fake_compact_tensor(cutlass.Int32, (sym_cu,), assumed_align=128) ci_fake = make_fake_compact_tensor(cutlass.Int32, (sym_ci, 2), stride_order=(1, 0), assumed_align=128) @@ -3109,7 +3117,7 @@ def _compile_bwd_wy_variant(H, K, V, scale, chunk_size, beta_dtype, use_fast_mat # Metadata cu_fake, ci_fake, - (Int32(1), Int32(1), Int32(H), Int32(K), Int32(V)), + (Int32(1), Int32(1), Int32(H), Int32(HV), Int32(K), Int32(V)), Int32(1), # total_nt dummy stream_fake, options=COMPILE_OPTIONS, @@ -3117,15 +3125,16 @@ def _compile_bwd_wy_variant(H, K, V, scale, chunk_size, beta_dtype, use_fast_mat return compiled_fn -def _get_compiled_bwd_wy(H, K, V, scale, chunk_size, beta_dtype): +def _get_compiled_bwd_wy(H, HV, K, V, scale, chunk_size, beta_dtype): """Get a compiled ChunkKdaBwdWyDqkgFused kernel with on-demand (lazy) compilation. - Cache key: (H, K, V, scale, chunk_size, beta_dtype, USE_FAST_MATH) + Cache key: (H, HV, K, V, scale, chunk_size, beta_dtype, USE_FAST_MATH) """ - key = (H, K, V, scale, chunk_size, beta_dtype, USE_FAST_MATH) + key = (H, HV, K, V, scale, chunk_size, beta_dtype, USE_FAST_MATH) if key not in _bwd_wy_kernel_cache: _bwd_wy_kernel_cache[key] = _compile_bwd_wy_variant( H, + HV, K, V, scale, @@ -3172,6 +3181,7 @@ def chunk_kda_bwd_wy_dqkg_fused( """ B, T, H, K = q.shape V = v.shape[3] + HV = v.shape[2] BT = chunk_size beta_dtype = beta.dtype device = q.device @@ -3190,18 +3200,19 @@ def chunk_kda_bwd_wy_dqkg_fused( T_total = B * T num_seqs = cu_seqlens.shape[0] - 1 total_nt_val = chunk_indices.shape[0] - ps = (Int32(num_seqs), Int32(T_total), Int32(H), Int32(K), Int32(V)) + ps = (Int32(num_seqs), Int32(T_total), Int32(H), Int32(HV), Int32(K), Int32(V)) # Allocate output tensors - dq = torch.empty(1, T_total, H, K, dtype=torch.float32, device=device) - dk = torch.empty(1, T_total, H, K, dtype=torch.float32, device=device) - dv2 = torch.empty(1, T_total, H, V, dtype=torch.bfloat16, device=device) - dg = torch.empty(1, T_total, H, K, dtype=torch.float32, device=device) - db = torch.empty(1, T_total, H, dtype=torch.float32, device=device) - dA = torch.empty(1, T_total, H, BT, dtype=torch.float32, device=device) + dq = torch.empty(1, T_total, HV, K, dtype=torch.float32, device=device) + dk = torch.empty(1, T_total, HV, K, dtype=torch.float32, device=device) + dv2 = torch.empty(1, T_total, HV, V, dtype=torch.bfloat16, device=device) + dg = torch.empty(1, T_total, HV, K, dtype=torch.float32, device=device) + db = torch.empty(1, T_total, HV, dtype=torch.float32, device=device) + dA = torch.empty(1, T_total, HV, BT, dtype=torch.float32, device=device) compiled_fn = _get_compiled_bwd_wy( H, + HV, K, V, scale, @@ -3212,15 +3223,15 @@ def chunk_kda_bwd_wy_dqkg_fused( if B != 1: q = q.reshape(1, T_total, H, K) k = k.reshape(1, T_total, H, K) - v = v.reshape(1, T_total, H, V) - v_new = v_new.reshape(1, T_total, H, V) - g = g.reshape(1, T_total, H, K) - beta = beta.reshape(1, T_total, H) - A = A.reshape(1, T_total, H, BT) - h = h.reshape(1, total_nt_val, H, K, V) - do = do.reshape(1, T_total, H, V) - dh = dh.reshape(1, total_nt_val, H, K, V) - dv = dv.reshape(1, T_total, H, V) + v = v.reshape(1, T_total, HV, V) + v_new = v_new.reshape(1, T_total, HV, V) + g = g.reshape(1, T_total, HV, K) + beta = beta.reshape(1, T_total, HV) + A = A.reshape(1, T_total, HV, BT) + h = h.reshape(1, total_nt_val, HV, K, V) + do = do.reshape(1, T_total, HV, V) + dh = dh.reshape(1, total_nt_val, HV, K, V) + dv = dv.reshape(1, T_total, HV, V) # TVM-FFI call compiled_fn( @@ -3252,12 +3263,12 @@ def chunk_kda_bwd_wy_dqkg_fused( # rearrange back if B != 1: - dq = dq.reshape(B, T, H, K) - dk = dk.reshape(B, T, H, K) - dv2 = dv2.reshape(B, T, H, V) - dg = dg.reshape(B, T, H, K) - db = db.reshape(B, T, H) - dA = dA.reshape(B, T, H, BT) + dq = dq.reshape(B, T, HV, K) + dk = dk.reshape(B, T, HV, K) + dv2 = dv2.reshape(B, T, HV, V) + dg = dg.reshape(B, T, HV, K) + db = db.reshape(B, T, HV) + dA = dA.reshape(B, T, HV, BT) return dq, dk, dv2, db, dg, dA @@ -3272,6 +3283,7 @@ def main(): parser.add_argument("--B", type=int, default=1) parser.add_argument("--T", type=int, default=64) parser.add_argument("--H", type=int, default=1) + parser.add_argument("--HV", type=int, default=None, help="Number of value heads (default: H, i.e. no GVA)") parser.add_argument("--K", type=int, default=128) parser.add_argument("--V", type=int, default=128) parser.add_argument("--scale", type=float, default=None) @@ -3281,6 +3293,7 @@ def main(): if args.scale is None: args.scale = args.K**-0.5 B, T, H, K, V = args.B, args.T, args.H, args.K, args.V + HV = args.HV if args.HV is not None else H BT = args.chunk_size seq_lens = [63, 63, 63] seq_lens = [64] @@ -3291,23 +3304,23 @@ def main(): dtype, device = torch.bfloat16, "cuda" cu_seqlens = torch.tensor(_exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - print(f"Config: B={B}, T={T}, H={H}, K={K}, V={V}, BT={BT}, scale={scale:.4f}") + print(f"Config: B={B}, T={T}, H={H}, HV={HV}, K={K}, V={V}, BT={BT}, scale={scale:.4f}") print(f" Chunks per seq: {NT}, Total chunks: {B * NT}") print(f" BK={64}, BV={64}, NK={K // 64}, NV={V // 64}") - # Generate test data + # Generate test data (q/k use H heads; all others use HV heads) torch.manual_seed(42) q = torch.randn(B, T, H, K, dtype=dtype, device=device) k = torch.randn(B, T, H, K, dtype=dtype, device=device) - v = torch.randn(B, T, H, V, dtype=dtype, device=device) - v_new = torch.randn(B, T, H, V, dtype=dtype, device=device) - g = torch.randn(B, T, H, K, dtype=torch.float32, device=device) * 0.1 - beta = torch.randn(B, T, H, dtype=torch.bfloat16, device=device) - A = torch.randn(B, T, H, BT, dtype=dtype, device=device) * 0.1 - h = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 - do_t = torch.randn(B, T, H, V, dtype=dtype, device=device) - dh = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 - dv = torch.randn(B, T, H, V, dtype=dtype, device=device) + v = torch.randn(B, T, HV, V, dtype=dtype, device=device) + v_new = torch.randn(B, T, HV, V, dtype=dtype, device=device) + g = torch.randn(B, T, HV, K, dtype=torch.float32, device=device) * 0.1 + beta = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + A = torch.randn(B, T, HV, BT, dtype=dtype, device=device) * 0.1 + h = torch.randn(B, NT, HV, K, V, dtype=dtype, device=device) * 0.01 + do_t = torch.randn(B, T, HV, V, dtype=dtype, device=device) + dh = torch.randn(B, NT, HV, K, V, dtype=dtype, device=device) * 0.01 + dv = torch.randn(B, T, HV, V, dtype=dtype, device=device) print("\n=== Compilation Test ===") try: From 86675ff2143fed6a12f0c7611e1f44583d628a64 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Sat, 9 May 2026 20:05:14 +0800 Subject: [PATCH 13/26] use cute.arch.atomic_add and update check --- benchmarks/bench_kda_bwd_wy_dqkg_sm100.py | 2 +- cula/ops/chunk_wy_dqkg_sm100.py | 13 +------------ 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py index b61c47a..04d9bd0 100644 --- a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py +++ b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py @@ -180,7 +180,7 @@ def check_determinism(H=4, HV=None, total_T=2001, num_seqs=4, iters=1000, beta_d assert torch.equal(dg_out, ref_dg), f"dg mismatch at iter {i}" assert torch.equal(dA_out, ref_dA), f"dA mismatch at iter {i}" # NOTE: for db, kernel uses atomic add which can cause non-determinism, so we use a looser check here - torch.testing.assert_close(db_out, ref_db, rtol=1e-5, atol=1e-5), f"db mismatch at iter {i}" + torch.testing.assert_close(db_out, ref_db, rtol=1e-5, atol=1e-5, msg=f"db mismatch at iter {i}") assert torch.isnan(dq_out).sum() == 0, f"dq contains NaNs at iter {i}" assert torch.isnan(dk_out).sum() == 0, f"dk contains NaNs at iter {i}" assert torch.isnan(dv_out).sum() == 0, f"dv contains NaNs at iter {i}" diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index 6d36916..fa191fc 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -131,17 +131,6 @@ def _exclusive_cumsum(a: list[int]): def _ir(val, loc=None, ip=None): return val.ir_value(loc=loc, ip=ip) if hasattr(val, "ir_value") else val -@dsl_user_op -def atomicAdd(dst_ptr: cute.Pointer, val: Int32 | Float32, *, loc=None, ip=None) -> Int32 | Float32: - return cute.arch.atomic_add( - ptr=dst_ptr.llvm_ptr, - val=val, - sem="relaxed", - scope="sys", - loc=loc, - ip=ip, - ) - @dsl_user_op def bf16_to_f32(val, *, loc=None, ip=None): """Convert a BFloat16 value to Float32 using arith.extf (no inline asm).""" @@ -2060,7 +2049,7 @@ def kernel( # atomic add for each row of db if row < sub_seq_len: sDb_row_ptr = cute.make_ptr(Float32, (sDb.iterator + row).toint(), cute.AddressSpace.smem, assumed_align=4) - atomicAdd(sDb_row_ptr, db_val) + cute.arch.atomic_add(sDb_row_ptr, db_val) self.cuda_wg_sync_barrier.arrive_and_wait() # store db to GMEM if local_tidx < sub_seq_len: From b1410301e600d83f732719fb1b11009e81de15ba Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Sat, 9 May 2026 20:52:33 +0800 Subject: [PATCH 14/26] fix db atomic add and store --- benchmarks/bench_kda_bwd_wy_dqkg_sm100.py | 4 ++-- cula/ops/chunk_wy_dqkg_sm100.py | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py index 04d9bd0..6428679 100644 --- a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py +++ b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py @@ -180,7 +180,7 @@ def check_determinism(H=4, HV=None, total_T=2001, num_seqs=4, iters=1000, beta_d assert torch.equal(dg_out, ref_dg), f"dg mismatch at iter {i}" assert torch.equal(dA_out, ref_dA), f"dA mismatch at iter {i}" # NOTE: for db, kernel uses atomic add which can cause non-determinism, so we use a looser check here - torch.testing.assert_close(db_out, ref_db, rtol=1e-5, atol=1e-5, msg=f"db mismatch at iter {i}") + torch.testing.assert_close(db_out, ref_db, rtol=1e-5, atol=1e-5) assert torch.isnan(dq_out).sum() == 0, f"dq contains NaNs at iter {i}" assert torch.isnan(dk_out).sum() == 0, f"dk contains NaNs at iter {i}" assert torch.isnan(dv_out).sum() == 0, f"dv contains NaNs at iter {i}" @@ -446,7 +446,7 @@ def main(): for H in args.heads: HV = args.hv if args.hv is not None else H - check_determinism(H=H, HV=HV, iters=10000) + check_determinism(H=H, HV=HV, iters=100000) fixed_res, varlen_res = [], [] diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index fa191fc..b68559a 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -2047,12 +2047,23 @@ def kernel( for i in cutlass.range_constexpr(bk_num_cols_per_wg): db_val += rKgb_kg[i] # atomic add for each row of db + # NOTE: must pass `.llvm_ptr` (LLVM ptr addrspace(3)) and set + # explicit sem/scope. Passing the raw cute._Pointer makes + # _normalize_ptr fall through (no `to_llvm_ptr` method on + # _Pointer), losing the SMEM address-space tag, which causes + # rare large mismatches in `db` (one partition's db_val gets + # silently dropped/corrupted). if row < sub_seq_len: sDb_row_ptr = cute.make_ptr(Float32, (sDb.iterator + row).toint(), cute.AddressSpace.smem, assumed_align=4) - cute.arch.atomic_add(sDb_row_ptr, db_val) + cute.arch.atomic_add( + ptr=sDb_row_ptr.llvm_ptr, + val=db_val, + sem="relaxed", + scope="cta", + ) self.cuda_wg_sync_barrier.arrive_and_wait() # store db to GMEM - if local_tidx < sub_seq_len: + if wg_idx == 0 and local_tidx < sub_seq_len: db_gmem[(tok_offset + tile_idx * self.BT + local_tidx, (i_hv, Int32(0)))] = sDb[(local_tidx,)] # dk = dk * exp2(gn[None, :] - g) From fb90057cf155f2dfef2b0498f6582ee7d67c7a82 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Mon, 11 May 2026 15:44:58 +0800 Subject: [PATCH 15/26] change to deterministic db reduce --- benchmarks/bench_kda_bwd_wy_dqkg_sm100.py | 13 ++-- cula/ops/chunk_wy_dqkg_sm100.py | 92 ++++++++--------------- 2 files changed, 37 insertions(+), 68 deletions(-) diff --git a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py index 6428679..0cb70c2 100644 --- a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py +++ b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py @@ -174,13 +174,6 @@ def check_determinism(H=4, HV=None, total_T=2001, num_seqs=4, iters=1000, beta_d ref_dq, ref_dk, ref_dv, ref_db, ref_dg, ref_dA = run_cutedsl(inputs) for i in range(iters): dq_out, dk_out, dv_out, db_out, dg_out, dA_out = run_cutedsl(inputs) - assert torch.equal(dq_out, ref_dq), f"dq mismatch at iter {i}" - assert torch.equal(dk_out, ref_dk), f"dk mismatch at iter {i}" - assert torch.equal(dv_out, ref_dv), f"dv mismatch at iter {i}" - assert torch.equal(dg_out, ref_dg), f"dg mismatch at iter {i}" - assert torch.equal(dA_out, ref_dA), f"dA mismatch at iter {i}" - # NOTE: for db, kernel uses atomic add which can cause non-determinism, so we use a looser check here - torch.testing.assert_close(db_out, ref_db, rtol=1e-5, atol=1e-5) assert torch.isnan(dq_out).sum() == 0, f"dq contains NaNs at iter {i}" assert torch.isnan(dk_out).sum() == 0, f"dk contains NaNs at iter {i}" assert torch.isnan(dv_out).sum() == 0, f"dv contains NaNs at iter {i}" @@ -193,6 +186,12 @@ def check_determinism(H=4, HV=None, total_T=2001, num_seqs=4, iters=1000, beta_d assert torch.isfinite(db_out).all(), f"db contains infs at iter {i}" assert torch.isfinite(dg_out).all(), f"dg contains infs at iter {i}" assert torch.isfinite(dA_out).all(), f"dA contains infs at iter {i}" + assert torch.equal(dq_out, ref_dq), f"dq mismatch at iter {i}" + assert torch.equal(dk_out, ref_dk), f"dk mismatch at iter {i}" + assert torch.equal(dv_out, ref_dv), f"dv mismatch at iter {i}" + assert torch.equal(dg_out, ref_dg), f"dg mismatch at iter {i}" + assert torch.equal(dA_out, ref_dA), f"dA mismatch at iter {i}" + assert torch.equal(db_out, ref_db), f"db mismatch at iter {i}" return True diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index b68559a..7d0534e 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -74,7 +74,7 @@ def _exclusive_cumsum(a: list[int]): TMEM_DK_ACC_OFF = 96 # [96,160) 64 cols dk fp32 acc TMEM_DW_ACC_OFF = 160 # [160,224] 64 cols dw fp32 acc TMEM_FLEX_OFF = 224 # [224,256) 32 cols dvb time-shared -TMEM_A_BF16_OFF = 256 # [256,272) 16 cols A_bf16 TS opA (persistent) +TMEM_A_BF16_OFF = 256 # [256,272) 16 cols A_bf16 TS opA (persistent) (not used currently) TMEM_DKGB_ACC_OFF = 272 # [272,336) 64 cols, dkgb fp32 acc TMEM_DA2_ACC_OFF = 336 # [336,368) 32 cols dA fp32 acc, used for dA=dA@A and dA=A@dA TMEM_DQ_SCALED_OFF = 368 # [368,432) 64 cols dq_scaled (stored for dg) @@ -252,22 +252,6 @@ def smem_store_f32x4_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32, d smem_t = cute.make_tensor(smem_ptr, cute.make_layout((4,), stride=(1,))) cute.autovec_copy(data, smem_t) -# SMEM B: MN-major -@cute.jit -def mma_ws_ts_m64n128_call( - tmem_a_base: Int32, - b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, - tmem_c: Int32, K: Int32 -): - with elect_one(): - b_outer = b_smem_layout.outer - for ks in cutlass.range_constexpr(K // 16): - scale = 0 if ks == 0 else 1 - b_off = cute.crd2idx(((0, 0), 0, ks, 0), b_outer) * ELEM_BYTES_BF16 - desc_b = desc_b_base + b_off - tmem_a = tmem_a_base + Int32(ks * 4) - tcgen05mma_ws_ts_f16(tmem_a, desc_b, tmem_c, IDESC_F16_M64_N128_K_MN, scale) - @cute.jit def mma_ws_ss_m64n128_call( a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, @@ -322,22 +306,6 @@ def mma_ws_ss_m64n128_mn_mn_call( tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, IDESC_F16_M64_N128_MN_MN, scale) scale = 1 -@cute.jit -def mma_ws_ts_m64n64_call( - tmem_a_base: Int32, - b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, - tmem_c: Int32, K: Int32 -): - pass - -@cute.jit -def mma_ws_ss_m64n64_call( - a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, - b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, - tmem_c: Int32, K: Int32 -): - pass - @cute.jit def mma_ws_ss_m64n64_k_k_call( a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, @@ -994,7 +962,6 @@ class SharedStorage: bar_mma_dA: cute.struct.MemRange[Int64, self.mma_stage * 2] bar_mma_dA2: cute.struct.MemRange[Int64, self.mma_stage * 2] bar_mma_dA3: cute.struct.MemRange[Int64, self.mma_stage * 2] - bar_mma_done_dA: cute.struct.MemRange[Int64, self.mma_stage] bar_mma_done_vloop: cute.struct.MemRange[Int64, self.mma_stage] bar_prologue_dw: cute.struct.MemRange[Int64, self.kloop_stage * 2] bar_prologue_kg: cute.struct.MemRange[Int64, self.kloop_stage * 2] @@ -1003,7 +970,7 @@ class SharedStorage: bar_store_dg: cute.struct.MemRange[Int64, self.kloop_stage * 2] # TMEM holding buffer tmem_holding_buf: Int32 - # A, stage=1, [BT,BT], 8KB + # A, stage=2, [BT,BT], 16KB buf_A: cute.struct.Align[ cute.struct.MemRange[self.io_dtype, cute.cosize(A_mn_opA_smem)], self.buffer_align_bytes, @@ -1061,8 +1028,10 @@ class SharedStorage: cute.struct.MemRange[cutlass.Float32, self.BT], 128, ] + # 2 slots per row, one per warpgroup, for deterministic db reduction + # (avoids cross-wg fp32 atomicAdd on shared memory). s_db: cute.struct.Align[ - cute.struct.MemRange[cutlass.Float32, self.BT], + cute.struct.MemRange[cutlass.Float32, self.BT * 2], 128, ] s_gn: cute.struct.Align[ @@ -1666,9 +1635,11 @@ def kernel( cute.make_ptr(Float32, storage.s_beta.data_ptr().toint(), cute.AddressSpace.smem), cute.make_layout((self.BT, ), stride=(1, )), ) + # sDb layout: (BT, 2). Inner dim = wg_idx slot. Stride (1, BT) so each + # wg's column is contiguous (better for the reduce in Phase 3). sDb = cute.make_tensor( cute.make_ptr(Float32, storage.s_db.data_ptr().toint(), cute.AddressSpace.smem), - cute.make_layout((self.BT, ), stride=(1, )), + cute.make_layout((self.BT, 2), stride=(1, self.BT)), ) sDgk = cute.make_tensor( cute.make_ptr(Float32, storage.s_dgk.data_ptr().toint(), cute.AddressSpace.smem), @@ -1793,9 +1764,10 @@ def kernel( # reading sDgk[col] above. This was the source of the # non-deterministic dg accuracy bug. self.cuda_wg_sync_barrier.arrive_and_wait() - # fill db, dgk to 0 + # fill db, dgk to 0. Each wg zeroes its own sDb column. if local_tidx < self.BT: - sDb[local_tidx] = Float32(0.0) + sDb[local_tidx, 0] = Float32(0.0) + sDb[local_tidx, 1] = Float32(0.0) if local_tidx < self.BK: sDgk[local_tidx] = Float32(0.0) self.cuda_wg_sync_barrier.arrive_and_wait() @@ -2046,25 +2018,28 @@ def kernel( if row < sub_seq_len: for i in cutlass.range_constexpr(bk_num_cols_per_wg): db_val += rKgb_kg[i] - # atomic add for each row of db - # NOTE: must pass `.llvm_ptr` (LLVM ptr addrspace(3)) and set - # explicit sem/scope. Passing the raw cute._Pointer makes - # _normalize_ptr fall through (no `to_llvm_ptr` method on - # _Pointer), losing the SMEM address-space tag, which causes - # rare large mismatches in `db` (one partition's db_val gets - # silently dropped/corrupted). - if row < sub_seq_len: - sDb_row_ptr = cute.make_ptr(Float32, (sDb.iterator + row).toint(), cute.AddressSpace.smem, assumed_align=4) - cute.arch.atomic_add( - ptr=sDb_row_ptr.llvm_ptr, - val=db_val, - sem="relaxed", - scope="cta", - ) + + # Deterministic db reduction without atomicAdd. + # 4 partitions per row come from 4 warps (warp_row_tile in {0,1}, + # warp_col_tile in {0,1}) x 2 wgs. Reduce in a fixed order so + # the result is bitwise reproducible across launches: + # Phase 1: warp_col_tile==0 writes its db_val into + # sDb[row, wg_idx] (single writer per slot) + # Phase 2: warp_col_tile==1 RMW-adds its db_val into the + # same slot (still single writer per slot) + # Phase 3: WG0 sums the 2 wg-slots in fixed order and stores + # to GMEM. + # No race, no atomic, no fp ordering nondeterminism. + if warp_col_tile == 0 and row < sub_seq_len: + sDb[row, wg_idx] = db_val + self.cuda_wg_sync_barrier.arrive_and_wait() + if warp_col_tile == 1 and row < sub_seq_len: + sDb[row, wg_idx] = sDb[row, wg_idx] + db_val self.cuda_wg_sync_barrier.arrive_and_wait() - # store db to GMEM + # store db to GMEM (WG0 only). Sum order is fixed (slot 0 + slot 1). if wg_idx == 0 and local_tidx < sub_seq_len: - db_gmem[(tok_offset + tile_idx * self.BT + local_tidx, (i_hv, Int32(0)))] = sDb[(local_tidx,)] + db_sum = sDb[(local_tidx, 0)] + sDb[(local_tidx, 1)] + db_gmem[(tok_offset + tile_idx * self.BT + local_tidx, (i_hv, Int32(0)))] = db_sum # dk = dk * exp2(gn[None, :] - g) pipeline_mma_dk.consumer_wait(mma_dk_consumer_state) @@ -3341,11 +3316,6 @@ def main(): chunk_size=BT, ) torch.cuda.synchronize() - # do_slice = do_t[0, :, 1, :].to(torch.float32) - # h_slice = h[0, 0, 1, :, :].to(torch.float32) - # dq_ref = do_slice @ h_slice.T - import pdb;pdb.set_trace() - # torch.testing.assert_close(dq_ref, dq[0,:,1,:], rtol=1e-2, atol=1e-2) print(f" dq shape: {dq.shape}, dtype: {dq.dtype}") print(f" dk shape: {dk.shape}, dtype: {dk.dtype}") print(f" dv2 shape: {dv2.shape}, dtype: {dv2.dtype}") From e0592171f546d296f789446771caf2bde04a7de9 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Mon, 11 May 2026 16:29:55 +0800 Subject: [PATCH 16/26] change iters --- benchmarks/bench_kda_bwd_wy_dqkg_sm100.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py index 0cb70c2..296a775 100644 --- a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py +++ b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py @@ -445,7 +445,7 @@ def main(): for H in args.heads: HV = args.hv if args.hv is not None else H - check_determinism(H=H, HV=HV, iters=100000) + check_determinism(H=H, HV=HV, iters=10000) fixed_res, varlen_res = [], [] From 6d194ec4f765263e1042b3e224f6d97b14a1e268 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Tue, 12 May 2026 00:01:34 +0800 Subject: [PATCH 17/26] change to umma pipeline --- cula/ops/chunk_wy_dqkg_sm100.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index 7d0534e..3987d7e 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -1222,7 +1222,7 @@ def kernel( mbarrier_init_fence() # ====== Pipeline Definition ====== - pipeline_load_A = pipeline.PipelineTmaAsync.create( + pipeline_load_A = pipeline.PipelineTmaUmma.create( barrier_storage=storage.bar_load_A.data_ptr(), num_stages=self.a_stage, producer_group=make_thread_cooperative_group(len([self.load_warp_id])), @@ -1250,7 +1250,6 @@ def kernel( consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) + num_cuda_warps_total), tx_count=self.tma_bytes_dh, ) - # NOTE: UMMA as consumer to call tcgen05.commit pipeline_load_do = pipeline.PipelineTmaUmma.create( barrier_storage=storage.bar_load_do.data_ptr(), num_stages=self.vloop_stage, @@ -1299,22 +1298,22 @@ def kernel( producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), ) - pipeline_mma_dq = pipeline.PipelineAsync.create( + pipeline_mma_dq = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.bar_mma_dq.data_ptr(), num_stages=self.mma_stage, - producer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), ) - pipeline_mma_dk = pipeline.PipelineAsync.create( + pipeline_mma_dk = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.bar_mma_dk.data_ptr(), num_stages=self.mma_stage, - producer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), ) - pipeline_mma_dw = pipeline.PipelineAsync.create( + pipeline_mma_dw = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.bar_mma_dw.data_ptr(), num_stages=self.mma_stage, - producer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), ) pipeline_mma_dA = pipeline.PipelineUmmaAsync.create( @@ -1335,29 +1334,29 @@ def kernel( producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), consumer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), ) - pipeline_prologue_dw = pipeline.PipelineAsync.create( + pipeline_prologue_dw = pipeline.PipelineAsyncUmma.create( barrier_storage=storage.bar_prologue_dw.data_ptr(), num_stages=self.kloop_stage, producer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), - consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), ) - pipeline_prologue_kg = pipeline.PipelineAsync.create( + pipeline_prologue_kg = pipeline.PipelineAsyncUmma.create( barrier_storage=storage.bar_prologue_kg.data_ptr(), num_stages=self.kloop_stage, producer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), - consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), ) - pipeline_prologue_dA2 = pipeline.PipelineAsync.create( + pipeline_prologue_dA2 = pipeline.PipelineAsyncUmma.create( barrier_storage=storage.bar_prologue_dA2.data_ptr(), num_stages=self.mma_stage, producer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), - consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), ) - pipeline_prologue_dA3 = pipeline.PipelineAsync.create( + pipeline_prologue_dA3 = pipeline.PipelineAsyncUmma.create( barrier_storage=storage.bar_prologue_dA3.data_ptr(), num_stages=self.mma_stage, producer_group=make_thread_cooperative_group(num_cuda_warps_total * 32), - consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) * 32), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), ) pipeline_mma_dkgb = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.bar_mma_dkgb.data_ptr(), From a763ddafa685de94d1087f6da5b1fa3fbf176af3 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Tue, 12 May 2026 00:49:04 +0800 Subject: [PATCH 18/26] change h, dh and v pipelines to different consumers --- cula/ops/chunk_wy_dqkg_sm100.py | 132 +++++++++++++------------------- 1 file changed, 52 insertions(+), 80 deletions(-) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index 3987d7e..e81974e 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -16,6 +16,7 @@ mbarrier_init, mbarrier_init_fence, mbarrier_wait, + mbarrier_arrive_and_expect_tx, mbarrier_arrive, sync_threads, ) @@ -947,11 +948,14 @@ class SharedStorage: bar_load_dv: cute.struct.MemRange[Int64, self.vloop_stage * 2] bar_mma_dvb: cute.struct.MemRange[Int64, self.mma_stage * 2] bar_load_beta: cute.struct.MemRange[Int64, 1 * 2] - bar_load_h: cute.struct.MemRange[Int64, self.vloop_stage * 2] - bar_load_dh: cute.struct.MemRange[Int64, self.vloop_stage * 2] + bar_tma_h: cute.struct.MemRange[Int64, self.vloop_stage] + bar_mma_cuda_h: cute.struct.MemRange[Int64, self.vloop_stage] + bar_tma_dh: cute.struct.MemRange[Int64, self.vloop_stage] + bar_mma_cuda_dh: cute.struct.MemRange[Int64, self.vloop_stage] + bar_tma_v: cute.struct.MemRange[Int64, self.vloop_stage] + bar_mma_cuda_v: cute.struct.MemRange[Int64, self.vloop_stage] bar_load_do: cute.struct.MemRange[Int64, self.vloop_stage * 2] bar_load_g: cute.struct.MemRange[Int64, self.kloop_stage * 2] - bar_load_v: cute.struct.MemRange[Int64, self.vloop_stage * 2] bar_load_vnew: cute.struct.MemRange[Int64, self.vloop_stage * 2] bar_load_q: cute.struct.MemRange[Int64, self.kloop_stage * 2] bar_load_k: cute.struct.MemRange[Int64, self.kloop_stage * 2] @@ -1215,10 +1219,24 @@ def kernel( # Barrier Initialization bar_mma_done_vloop_ptr = storage.bar_mma_done_vloop.data_ptr() + # NOTE: for h, dh and v, consumer contains both MMA and CUDA Core, so we use original mbarrier declaration instead of pipeline utils + bar_tma_h_ptr = storage.bar_tma_h.data_ptr() + bar_mma_cuda_h_ptr = storage.bar_mma_cuda_h.data_ptr() + bar_tma_dh_ptr = storage.bar_tma_dh.data_ptr() + bar_mma_cuda_dh_ptr = storage.bar_mma_cuda_dh.data_ptr() + bar_tma_v_ptr = storage.bar_tma_v.data_ptr() + bar_mma_cuda_v_ptr = storage.bar_mma_cuda_v.data_ptr() if warp_idx == 0: with elect_one(): - for i in cutlass.range(self.mma_stage): + for i in cutlass.range(self.mma_stage, unroll_full=True): mbarrier_init(bar_mma_done_vloop_ptr + i, 1) + for i in cutlass.range(self.vloop_stage, unroll_full=True): + mbarrier_init(bar_tma_h_ptr + i, 1) + mbarrier_init(bar_mma_cuda_h_ptr + i, num_cuda_warps_total * 32 + 1) + mbarrier_init(bar_tma_dh_ptr + i, 1) + mbarrier_init(bar_mma_cuda_dh_ptr + i, num_cuda_warps_total * 32 + 1) + mbarrier_init(bar_tma_v_ptr + i, 1) + mbarrier_init(bar_mma_cuda_v_ptr + i, num_cuda_warps_total * 32 + 1) mbarrier_init_fence() # ====== Pipeline Definition ====== @@ -1236,20 +1254,6 @@ def kernel( consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), tx_count=self.tma_bytes_dv, ) - pipeline_load_h = pipeline.PipelineTmaAsync.create( - barrier_storage=storage.bar_load_h.data_ptr(), - num_stages=self.vloop_stage, - producer_group=make_thread_cooperative_group(len([self.load_warp_id])), - consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) + num_cuda_warps_total), - tx_count=self.tma_bytes_h, - ) - pipeline_load_dh = pipeline.PipelineTmaAsync.create( - barrier_storage=storage.bar_load_dh.data_ptr(), - num_stages=self.vloop_stage, - producer_group=make_thread_cooperative_group(len([self.load_warp_id])), - consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) + num_cuda_warps_total), - tx_count=self.tma_bytes_dh, - ) pipeline_load_do = pipeline.PipelineTmaUmma.create( barrier_storage=storage.bar_load_do.data_ptr(), num_stages=self.vloop_stage, @@ -1271,13 +1275,6 @@ def kernel( consumer_group=make_thread_cooperative_group(num_cuda_warps_total + len(self.aux_warp_ids)), tx_count=self.tma_bytes_g, ) - pipeline_load_v = pipeline.PipelineTmaAsync.create( - barrier_storage=storage.bar_load_v.data_ptr(), - num_stages=self.vloop_stage, - producer_group=make_thread_cooperative_group(len([self.load_warp_id])), - consumer_group=make_thread_cooperative_group(len([self.mma_warp_id]) + num_cuda_warps_total), - tx_count=self.tma_bytes_v, - ) pipeline_load_k = pipeline.PipelineTmaAsync.create( barrier_storage=storage.bar_load_k.data_ptr(), num_stages=self.kloop_stage, @@ -1666,18 +1663,9 @@ def kernel( load_beta_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, 1 ) - load_h_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.vloop_stage - ) - load_dh_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.vloop_stage - ) load_g_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.kloop_stage ) - load_v_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.vloop_stage - ) mma_dvb_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.mma_stage ) @@ -1744,6 +1732,7 @@ def kernel( num_stores_f32 = bk_num_cols_per_wg // 8 vloop_stage_idx = 0 + vloop_phase = 0 for wu_iter in cutlass.range(0, num_iters, unroll=0): work_idx = block_idx_x + wu_iter * grid_dim_x G = HV // H @@ -1778,8 +1767,8 @@ def kernel( db_val = Float32(0.0) for v_iter in cutlass.range(self.num_v_tiles): # dgk += sum(h * dh, axis=0) - pipeline_load_h.consumer_wait(load_h_consumer_state) - pipeline_load_dh.consumer_wait(load_dh_consumer_state) + mbarrier_wait(bar_tma_h_ptr + vloop_stage_idx, vloop_phase) + mbarrier_wait(bar_tma_dh_ptr + vloop_stage_idx, vloop_phase) sH_raw_ptr = cute.make_ptr( self.io_dtype, sH_ptr_base + vloop_stage_idx * vloop_opB_bytes_per_stage, cute.AddressSpace.smem @@ -1799,10 +1788,8 @@ def kernel( for j in cutlass.range_constexpr(8): sDgk[(local_tidx,)] += h_dh_vals[j] - pipeline_load_dh.consumer_release(load_dh_consumer_state) - load_dh_consumer_state.advance() - pipeline_load_h.consumer_release(load_h_consumer_state) - load_h_consumer_state.advance() + mbarrier_arrive(bar_mma_cuda_h_ptr + vloop_stage_idx) + mbarrier_arrive(bar_mma_cuda_dh_ptr + vloop_stage_idx) pipeline_mma_dvb.consumer_wait(mma_dvb_consumer_state) tcgen05_fence_after() @@ -1817,7 +1804,7 @@ def kernel( dvb_f32_val = TensorSSA(dvb_f32, (bv_num_cols_per_wg,), Float32) # db += sum(dvb * v, axis=1) - pipeline_load_v.consumer_wait(load_v_consumer_state) + mbarrier_wait(bar_tma_v_ptr + vloop_stage_idx, vloop_phase) rV_bf16 = cute.make_rmem_tensor((bv_num_cols_per_wg,), self.io_dtype) sV_raw_ptr_cur = cute.make_ptr( self.io_dtype, sV_ptr_base + vloop_stage_idx * v_opB_bytes_per_stage, cute.AddressSpace.smem @@ -1843,8 +1830,7 @@ def kernel( for i in cutlass.range_constexpr(bv_num_cols_per_wg): db_val += rV_fp32[i] - pipeline_load_v.consumer_release(load_v_consumer_state) - load_v_consumer_state.advance() + mbarrier_arrive(bar_mma_cuda_v_ptr + vloop_stage_idx) # ── dv2 epilogue: dv2 = dvb * beta, cast to bf16, store to gmem ── dvb_f32_rmem = cute.make_rmem_tensor((bv_num_cols_per_wg,), Float32) @@ -1874,6 +1860,7 @@ def kernel( store_256b(base_addr + s * 32, chunk) vloop_stage_idx = (vloop_stage_idx + 1) % self.vloop_stage + vloop_phase ^= 1 # gk_exp = exp2(g) pipeline_load_g.consumer_wait(load_g_consumer_state) @@ -2278,12 +2265,6 @@ def kernel( load_dv_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.vloop_stage ) - load_h_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.vloop_stage - ) - load_dh_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.vloop_stage - ) load_do_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.vloop_stage ) @@ -2293,9 +2274,6 @@ def kernel( load_g_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.kloop_stage ) - load_v_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.vloop_stage - ) load_k_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.kloop_stage ) @@ -2304,6 +2282,7 @@ def kernel( ) vloop_stage_idx = 0 + vloop_phase = 1 # init as 1 for producer for wu_iter in cutlass.range(0, num_iters, unroll=0): work_idx = block_idx_x + wu_iter * grid_dim_x G = HV // H @@ -2349,14 +2328,15 @@ def kernel( vloop_tiled_mma, i_hv, i_t ) - pipeline_load_h.producer_acquire(load_h_producer_state) + mbarrier_wait(bar_mma_cuda_h_ptr + vloop_stage_idx, vloop_phase) + with elect_one(): + mbarrier_arrive_and_expect_tx(bar_tma_h_ptr + vloop_stage_idx, self.tma_bytes_h) cute.copy( tma_atom_h, tHgH[(None, 0, 0)], tHsH[(None, vloop_stage_idx)], - tma_bar_ptr=pipeline_load_h.producer_get_barrier(load_h_producer_state), + tma_bar_ptr=bar_tma_h_ptr + vloop_stage_idx, ) - load_h_producer_state.advance() tma_dh_v = cute.domain_offset((0, v_iter * self.BV, (0, 0)), tma_tensor_dh) tDHsDH, tDHgDH = self._tma_partition_B( @@ -2367,14 +2347,15 @@ def kernel( vloop_tiled_mma, i_hv, i_t ) - pipeline_load_dh.producer_acquire(load_dh_producer_state) + mbarrier_wait(bar_mma_cuda_dh_ptr + vloop_stage_idx, vloop_phase) + with elect_one(): + mbarrier_arrive_and_expect_tx(bar_tma_dh_ptr + vloop_stage_idx, self.tma_bytes_dh) cute.copy( tma_atom_dh, tDHgDH[(None, 0, 0)], tDHsDH[(None, vloop_stage_idx)], - tma_bar_ptr=pipeline_load_dh.producer_get_barrier(load_dh_producer_state), + tma_bar_ptr=bar_tma_dh_ptr + vloop_stage_idx, ) - load_dh_producer_state.advance() tma_do_v = cute.domain_offset((tok_offset, v_iter * self.BV, (0, 0)), tma_tensor_do) tDOsDo, tDOgDo = self._tma_partition_A( @@ -2421,14 +2402,15 @@ def kernel( dA_vloop_tiled_mma, Int32(0), i_hv, ) - pipeline_load_v.producer_acquire(load_v_producer_state) + mbarrier_wait(bar_mma_cuda_v_ptr + vloop_stage_idx, vloop_phase) + with elect_one(): + mbarrier_arrive_and_expect_tx(bar_tma_v_ptr + vloop_stage_idx, self.tma_bytes_v) cute.copy( tma_atom_v, tVgV[(None, tile_idx, 0)], tVsV[(None, vloop_stage_idx)], - tma_bar_ptr=pipeline_load_v.producer_get_barrier(load_v_producer_state), + tma_bar_ptr=bar_tma_v_ptr + vloop_stage_idx, ) - load_v_producer_state.advance() # load v_new tma_vnew_v = cute.domain_offset((tok_offset, v_iter * self.BV, (0, 0)), tma_tensor_vnew) @@ -2450,6 +2432,7 @@ def kernel( load_vnew_producer_state.advance() vloop_stage_idx = (vloop_stage_idx + 1) % self.vloop_stage + vloop_phase ^= 1 # Load g tma_g_v = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_g) @@ -2515,12 +2498,6 @@ def kernel( mma_dvb_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.mma_stage ) - load_h_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.vloop_stage - ) - load_dh_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.vloop_stage - ) load_do_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.vloop_stage ) @@ -2533,9 +2510,6 @@ def kernel( mma_dk_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.mma_stage ) - load_v_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.vloop_stage - ) mma_dw_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.mma_stage ) @@ -2567,6 +2541,7 @@ def kernel( vloop_stage_idx = 0 a_stage_idx = 0 mma_vloop_phase = 0 + vloop_phase = 0 for wu_iter in cutlass.range(0, num_iters, unroll=0): work_idx = block_idx_x + wu_iter * grid_dim_x G = HV // H @@ -2603,7 +2578,7 @@ def kernel( for v_iter in cutlass.range(self.num_v_tiles): is_accum = False if v_iter == 0 else True - pipeline_load_h.consumer_wait(load_h_consumer_state) + mbarrier_wait(bar_tma_h_ptr + vloop_stage_idx, vloop_phase) pipeline_load_do.consumer_wait(load_do_consumer_state) sDo_raw_ptr = cute.make_ptr( self.io_dtype, @@ -2680,7 +2655,7 @@ def kernel( mma_ws_ss_m64n128_k_k_call(vloop_opA_smem, desc_a_base, vloop_opB_smem, desc_b_base, TMEM_DW_ACC_OFF, self.BV, is_accum) # dA += dv @ v^T - pipeline_load_v.consumer_wait(load_v_consumer_state) + mbarrier_wait(bar_tma_v_ptr + vloop_stage_idx, vloop_phase) sV_raw = cute.make_ptr( self.io_dtype, sV_ptr_base + vloop_stage_idx * v_opB_bytes_per_stage, cute.AddressSpace.smem ) @@ -2711,11 +2686,8 @@ def kernel( pipeline_mma_dw.producer_commit(mma_dw_producer_state) mma_dw_producer_state.advance() - pipeline_load_h.consumer_release(load_h_consumer_state) - load_h_consumer_state.advance() - - pipeline_load_v.consumer_release(load_v_consumer_state) - load_v_consumer_state.advance() + umma_arrive(bar_mma_cuda_h_ptr + vloop_stage_idx) + umma_arrive(bar_mma_cuda_v_ptr + vloop_stage_idx) # dk += v_new @ dh pipeline_load_vnew.consumer_wait(load_vnew_consumer_state) @@ -2733,7 +2705,7 @@ def kernel( smem_store_bf16x8_sw128(sDvnew_raw_ptr, row, col * 8, zeros8) cute.arch.fence_proxy("async.shared", space="cta") - pipeline_load_dh.consumer_wait(load_dh_consumer_state) + mbarrier_wait(bar_tma_dh_ptr + vloop_stage_idx, vloop_phase) if v_iter == 0: pipeline_mma_dk.producer_acquire(mma_dk_producer_state) @@ -2753,8 +2725,7 @@ def kernel( pipeline_mma_dk.producer_commit(mma_dk_producer_state) mma_dk_producer_state.advance() - pipeline_load_dh.consumer_release(load_dh_consumer_state) - load_dh_consumer_state.advance() + umma_arrive(bar_mma_cuda_dh_ptr + vloop_stage_idx) # add tcgen05.commit and mbar.wait to make sure dq/dk/dw MMA finished umma_arrive(bar_mma_done_vloop_ptr + 0) @@ -2762,6 +2733,7 @@ def kernel( mma_vloop_phase ^= 1 vloop_stage_idx = (vloop_stage_idx + 1) % self.vloop_stage + vloop_phase ^= 1 pipeline_prologue_dw.consumer_wait(prologue_dw_consumer_state) cute.arch.fence_proxy("async.shared", space="cta") From 5196f106d6986a19a4020182ded74e6f39f7689c Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Tue, 12 May 2026 16:02:59 +0800 Subject: [PATCH 19/26] add tma store desc prefetch --- cula/ops/chunk_wy_dqkg_sm100.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index e81974e..231bbf0 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -1212,6 +1212,7 @@ def kernel( cpasync.prefetch_descriptor(tma_atom_vnew) cpasync.prefetch_descriptor(tma_atom_k) cpasync.prefetch_descriptor(tma_atom_q) + cpasync.prefetch_descriptor(tma_atom_dg) # ===================== SMEM allocation ===================== smem = utils.SmemAllocator() From 73cd33c453f4952d8fa8240df16166a3e2a0ef5b Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Tue, 12 May 2026 23:40:09 +0800 Subject: [PATCH 20/26] skip deter check for ncu mode --- benchmarks/bench_kda_bwd_wy_dqkg_sm100.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py index 296a775..f792cd3 100644 --- a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py +++ b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py @@ -445,7 +445,8 @@ def main(): for H in args.heads: HV = args.hv if args.hv is not None else H - check_determinism(H=H, HV=HV, iters=10000) + if not args.ncu: + check_determinism(H=H, HV=HV, iters=10000) fixed_res, varlen_res = [], [] From 7379391990a3dd2d4eb7d6703476b122b5b9e23e Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Wed, 13 May 2026 10:14:25 +0800 Subject: [PATCH 21/26] modify deter check --- benchmarks/bench_kda_fwd_bwd_e2e.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/benchmarks/bench_kda_fwd_bwd_e2e.py b/benchmarks/bench_kda_fwd_bwd_e2e.py index 2275bce..91dd555 100644 --- a/benchmarks/bench_kda_fwd_bwd_e2e.py +++ b/benchmarks/bench_kda_fwd_bwd_e2e.py @@ -221,11 +221,8 @@ def check_determinism(num_seqs=5, T=512, iters=20): ref = run_kda_e2e_with_grads(**common, fn=cula_chunk_kda) for i in range(iters): out = run_kda_e2e_with_grads(**common, fn=cula_chunk_kda) - for name in ("o", "ht", "dq", "dk", "dv", "dg", "dh0"): + for name in ("o", "ht", "dq", "dk", "dv", "dg", "dbeta", "dh0"): assert torch.equal(out[name], ref[name]), f"[determinism] cuLA {name} mismatch at iter {i}" - for name in ("dbeta",): - # NOTE: for db, kernel uses atomic add which can cause non-determinism, so we use a looser check here - torch.testing.assert_close(out[name], ref[name], rtol=1e-5, atol=1e-5), f"db mismatch at iter {i}" return True @@ -562,11 +559,6 @@ def main(): action="store_true", help="Disable recompute in both FLA and cuLA (pre-compute QG)", ) - parser.add_argument( - "--check_determinism", - action="store_true", - help="Run determinism check: verify cuLA produces identical outputs across repeated runs", - ) args = parser.parse_args() global NCU_MODE, SANITIZER_MODE, DISABLE_RECOMPUTE, PHASE @@ -581,12 +573,12 @@ def main(): print("[Disable recompute] pre-compute QG in forward") PHASE = args.phase - if args.check_determinism: + if not (args.ncu or args.sanitizer): det_configs = [(5, 1024), (10, 4096), (10, 8192), (10, 16384)] print("\n[Determinism Check] cuLA chunk_kda E2E ...") for num_seqs, T in det_configs: - result = check_determinism(num_seqs=num_seqs, T=T, iters=20) - print(f" num_seqs={num_seqs} T={T:5d} iters=20 {'PASS' if result else 'FAIL'}") + result = check_determinism(num_seqs=num_seqs, T=T, iters=1000) + print(f" num_seqs={num_seqs} T={T:5d} {'PASS' if result else 'FAIL'}") print("[Determinism Check] All passed.\n") return From 878775d1427fed475eb2aaca886bbb0257dc76e3 Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Wed, 13 May 2026 10:15:15 +0800 Subject: [PATCH 22/26] fix --- benchmarks/bench_kda_fwd_bwd_e2e.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/bench_kda_fwd_bwd_e2e.py b/benchmarks/bench_kda_fwd_bwd_e2e.py index 91dd555..f8e193c 100644 --- a/benchmarks/bench_kda_fwd_bwd_e2e.py +++ b/benchmarks/bench_kda_fwd_bwd_e2e.py @@ -580,7 +580,6 @@ def main(): result = check_determinism(num_seqs=num_seqs, T=T, iters=1000) print(f" num_seqs={num_seqs} T={T:5d} {'PASS' if result else 'FAIL'}") print("[Determinism Check] All passed.\n") - return fixed_configs = [ # (B, T) From 8864442d825fdd6779c54806917925fc2ca9eb63 Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Wed, 13 May 2026 10:34:59 +0800 Subject: [PATCH 23/26] code lint --- benchmarks/bench_kda_bwd_wy_dqkg_sm100.py | 5 +- benchmarks/utils.py | 1 + cula/ops/chunk_wy_dqkg_sm100.py | 549 ++++++++++------------ cula/ops/intrinsics_sm100.py | 4 +- tests/test_ptx_umma_ws.py | 12 +- 5 files changed, 251 insertions(+), 320 deletions(-) diff --git a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py index f792cd3..a8dbfe3 100644 --- a/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py +++ b/benchmarks/bench_kda_bwd_wy_dqkg_sm100.py @@ -65,11 +65,13 @@ N_ITERS = 100 NCU_MODE = False + def generate_balanced_seqlens(total_tokens, num_seqs): base = total_tokens // num_seqs remainder = total_tokens % num_seqs return [base] * (num_seqs - 1) + [base + remainder] + # ============================================================ # Helpers # ============================================================ @@ -151,6 +153,7 @@ def run_cutedsl(inputs: dict): chunk_indices=inputs["chunk_indices"], ) + def check_determinism(H=4, HV=None, total_T=2001, num_seqs=4, iters=1000, beta_dtype=DTYPE): """Verify deterministic outputs across repeated runs.""" if HV is None: @@ -252,7 +255,6 @@ def bench_fixed(configs, H: int, HV: int | None = None): } results.append(r) - del inputs torch.cuda.empty_cache() return results @@ -322,7 +324,6 @@ def bench_varlen(configs, H: int, HV: int | None = None): } results.append(r) - del inputs torch.cuda.empty_cache() return results diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 2e5b8b5..91ccdbc 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -346,6 +346,7 @@ def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_siz return q, k, v, g, beta, scale, cu_seqlens, chunk_indices + def prepare_bwd_wy_dqkg_fused_inputs( B: int, T: int, diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index 231bbf0..eb568a9 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -6,44 +6,41 @@ import cutlass.utils as utils import cutlass.utils.blackwell_helpers as sm100_utils import torch -from cutlass._mlir.dialects import llvm, arith as _arith, nvvm as _nvvm from cutlass._mlir import ir -from cutlass.cutlass_dsl import dsl_user_op -from cutlass.cute.nvgpu import cpasync, tcgen05 -from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream +from cutlass._mlir.dialects import arith as _arith from cutlass.cute.arch import ( elect_one, + mbarrier_arrive, + mbarrier_arrive_and_expect_tx, mbarrier_init, mbarrier_init_fence, mbarrier_wait, - mbarrier_arrive_and_expect_tx, - mbarrier_arrive, - sync_threads, ) +from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass.cute.nvgpu.tcgen05 import ( make_umma_smem_desc, smem_descriptor_to_int, ) +from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream from cutlass.cute.tensor import TensorSSA -from cutlass.cute.typing import Float32, Int32, Int64, BFloat16 +from cutlass.cute.typing import BFloat16, Float32, Int32, Int64 +from cutlass.cutlass_dsl import dsl_user_op from fla.ops.utils import prepare_chunk_indices -from cula.utils import USE_FAST_MATH, assert_blackwell, prepare_uniform_cu_seqlens - from cula.ops.intrinsics_sm100 import ( - tcgen05_fence_before, + reinterpret_cast, + store_256b, + subvec, tcgen05_fence_after, + tcgen05_fence_before, tcgen05_ld_32x32b, tcgen05_st_32x32b, - reinterpret_cast, - subvec, - store_256b, ) from cula.ops.ptx_umma_ext import ( Tcgen05SmemDescriptor, - tcgen05mma_ws_ts_f16, tcgen05mma_ws_ss_f16, ) +from cula.utils import USE_FAST_MATH, assert_blackwell, prepare_uniform_cu_seqlens PRINT_DEBUG = False @@ -58,9 +55,11 @@ torch.float32: cutlass.Float32, } + def make_thread_cooperative_group(size: int): return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) + def _exclusive_cumsum(a: list[int]): r = [0] for v in a: @@ -73,12 +72,12 @@ def _exclusive_cumsum(a: list[int]): TMEM_DA_ACC_OFF = 0 # [0,32) 32 cols dA fp32 acc; Phase 3: [0,16) overwritten by dA_bf16 TMEM_DQ_ACC_OFF = 32 # [32,96) 64 cols dq fp32 acc; Phase 3: step2/step3 result [32,64) TMEM_DK_ACC_OFF = 96 # [96,160) 64 cols dk fp32 acc -TMEM_DW_ACC_OFF = 160 # [160,224] 64 cols dw fp32 acc +TMEM_DW_ACC_OFF = 160 # [160,224] 64 cols dw fp32 acc TMEM_FLEX_OFF = 224 # [224,256) 32 cols dvb time-shared TMEM_A_BF16_OFF = 256 # [256,272) 16 cols A_bf16 TS opA (persistent) (not used currently) -TMEM_DKGB_ACC_OFF = 272 # [272,336) 64 cols, dkgb fp32 acc +TMEM_DKGB_ACC_OFF = 272 # [272,336) 64 cols, dkgb fp32 acc TMEM_DA2_ACC_OFF = 336 # [336,368) 32 cols dA fp32 acc, used for dA=dA@A and dA=A@dA -TMEM_DQ_SCALED_OFF = 368 # [368,432) 64 cols dq_scaled (stored for dg) +TMEM_DQ_SCALED_OFF = 368 # [368,432) 64 cols dq_scaled (stored for dg) TMEM_TOTAL = 512 # Instruction descriptor for M=64, N=64, BF16, dense, TransposeB=1 @@ -102,26 +101,20 @@ def _exclusive_cumsum(a: list[int]): # Bits: M>>4=4 at [24:28], N>>3=16 at [17:22], # TransposeB at [16], TransposeA at [15], # btype=bf16(1) at [10:12], atype=bf16(1) at [7:9], dtype=f32(1) at [4:5] -IDESC_F16_M64_N128_MN_MN = ( - (4 << 24) | (16 << 17) | (1 << 16) | (1 << 15) | (1 << 10) | (1 << 7) | (1 << 4) -) +IDESC_F16_M64_N128_MN_MN = (4 << 24) | (16 << 17) | (1 << 16) | (1 << 15) | (1 << 10) | (1 << 7) | (1 << 4) assert IDESC_F16_M64_N128_MN_MN == 0x4218490 # Instruction descriptor for M=64, N=64, BF16, dense, TransposeA=1, TransposeB=1 # Bits: M>>4=4 at [24:28], N>>3=8 at [17:22], # TransposeB at [16], TransposeA at [15], # btype=bf16(1) at [10:12], atype=bf16(1) at [7:9], dtype=f32(1) at [4:5] -IDESC_F16_M64_N64_MN_MN = ( - (4 << 24) | (8 << 17) | (1 << 16) | (1 << 15) | (1 << 10) | (1 << 7) | (1 << 4) -) +IDESC_F16_M64_N64_MN_MN = (4 << 24) | (8 << 17) | (1 << 16) | (1 << 15) | (1 << 10) | (1 << 7) | (1 << 4) # Instruction descriptor for M=64, N=64, BF16, dense # Bits: M>>4=4 at [24:28], N>>3=8 at [17:22], # TransposeB at [16], TransposeA at [15], # btype=bf16(1) at [10:12], atype=bf16(1) at [7:9], dtype=f32(1) at [4:5] -IDESC_F16_M64_N64_K_K = ( - (4 << 24) | (8 << 17) | (1 << 10) | (1 << 7) | (1 << 4) -) +IDESC_F16_M64_N64_K_K = (4 << 24) | (8 << 17) | (1 << 10) | (1 << 7) | (1 << 4) ELEM_BYTES_BF16 = BFloat16.width // 8 @@ -129,9 +122,11 @@ def _exclusive_cumsum(a: list[int]): # Helpers: _ir, Float32 conversion # ============================================================ + def _ir(val, loc=None, ip=None): return val.ir_value(loc=loc, ip=ip) if hasattr(val, "ir_value") else val + @dsl_user_op def bf16_to_f32(val, *, loc=None, ip=None): """Convert a BFloat16 value to Float32 using arith.extf (no inline asm).""" @@ -147,6 +142,7 @@ def f32_to_bf16(val, *, loc=None, ip=None): bf16_ir = _arith.truncf(BFloat16.mlir_type, f32_ir, loc=loc, ip=ip) return BFloat16(bf16_ir) + @cute.jit def smem_load_bf16x8_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32): """ @@ -164,14 +160,17 @@ def smem_load_bf16x8_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32): swizzled = k_inner ^ ((row & Int32(7)) << Int32(3)) elem_off = half * Int32(4096) + row * Int32(64) + swizzled aligned_ptr = cute.make_ptr( - BFloat16, (raw_ptr + elem_off).toint(), - cute.AddressSpace.smem, assumed_align=16, + BFloat16, + (raw_ptr + elem_off).toint(), + cute.AddressSpace.smem, + assumed_align=16, ) smem_t = cute.make_tensor(aligned_ptr, cute.make_layout((8,), stride=(1,))) rmem_t = cute.make_fragment_like(smem_t) cute.autovec_copy(smem_t, rmem_t) return rmem_t + @cute.jit def smem_store_bf16x8_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32, data: cute.Tensor): """ @@ -193,22 +192,25 @@ def smem_store_bf16x8_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32, swizzled = k_inner ^ ((row & Int32(7)) << Int32(3)) elem_off = half * Int32(4096) + row * Int32(64) + swizzled smem_ptr = cute.make_ptr( - BFloat16, (raw_ptr + elem_off).toint(), - cute.AddressSpace.smem, assumed_align=16, + BFloat16, + (raw_ptr + elem_off).toint(), + cute.AddressSpace.smem, + assumed_align=16, ) smem_t = cute.make_tensor(smem_ptr, cute.make_layout((8,), stride=(1,))) cute.autovec_copy(data, smem_t) + @cute.jit def smem_load_f32x4_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32): """ Load 4 consecutive float32 from SMEM with K_SW128 layout. Logical layout: [BT=64, BK=128] ROW_MAJOR, tiled over a Float32 K_SW128 atom. - The atom provides a 32-element row stride. The 128-element column is broken + The atom provides a 32-element row stride. The 128-element column is broken into 4 blocks of 32 elements. PyCutlass tiles this such that outer blocks stride by 2048 elements: elem_idx = row * 32 + (col_base % 32) + (col_base / 32) * 2048 - + The TMA hardware performs a 128B Swizzle on physical byte addresses: byte_idx = elem_idx * 4 swizzled_byte = byte_idx ^ (((byte_idx >> 7) & 7) << 4) @@ -221,17 +223,20 @@ def smem_load_f32x4_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32): c_inner = col_base & Int32(31) c_outer = col_base >> Int32(5) swizzled_inner = c_inner ^ ((row & Int32(7)) << Int32(2)) - + elem_offset = row * Int32(32) + swizzled_inner + c_outer * Int32(2048) - + aligned_ptr = cute.make_ptr( - Float32, (raw_ptr + elem_offset).toint(), - cute.AddressSpace.smem, assumed_align=16, + Float32, + (raw_ptr + elem_offset).toint(), + cute.AddressSpace.smem, + assumed_align=16, ) t = cute.make_tensor(aligned_ptr, cute.make_layout((4,), stride=(1,))) vals = t.load() return (vals[0], vals[1], vals[2], vals[3]) + @cute.jit def smem_store_f32x4_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32, data: cute.Tensor): """ @@ -247,17 +252,24 @@ def smem_store_f32x4_sw128(raw_ptr: cute.Pointer, row: Int32, col_base: Int32, d swizzled_inner = c_inner ^ ((row & Int32(7)) << Int32(2)) elem_offset = row * Int32(32) + swizzled_inner + c_outer * Int32(2048) smem_ptr = cute.make_ptr( - Float32, (raw_ptr + elem_offset).toint(), - cute.AddressSpace.smem, assumed_align=16, + Float32, + (raw_ptr + elem_offset).toint(), + cute.AddressSpace.smem, + assumed_align=16, ) smem_t = cute.make_tensor(smem_ptr, cute.make_layout((4,), stride=(1,))) cute.autovec_copy(data, smem_t) + @cute.jit def mma_ws_ss_m64n128_call( - a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, - b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, - tmem_c: Int32, K: Int32, is_accum: bool = False, + a_smem_layout: cute.Layout, + desc_a_base: Tcgen05SmemDescriptor, + b_smem_layout: cute.Layout, + desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, + K: Int32, + is_accum: bool = False, ): with elect_one(): a_outer = a_smem_layout.outer @@ -271,11 +283,16 @@ def mma_ws_ss_m64n128_call( tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, IDESC_F16_M64_N128_K_MN, scale) scale = 1 + @cute.jit def mma_ws_ss_m64n128_k_k_call( - a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, - b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, - tmem_c: Int32, K: Int32, is_accum: bool = False, + a_smem_layout: cute.Layout, + desc_a_base: Tcgen05SmemDescriptor, + b_smem_layout: cute.Layout, + desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, + K: Int32, + is_accum: bool = False, ): with elect_one(): a_outer = a_smem_layout.outer @@ -289,11 +306,16 @@ def mma_ws_ss_m64n128_k_k_call( tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, IDESC_F16_M64_N128_K_K, scale) scale = 1 + @cute.jit def mma_ws_ss_m64n128_mn_mn_call( - a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, - b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, - tmem_c: Int32, K: Int32, is_accum: bool = False, + a_smem_layout: cute.Layout, + desc_a_base: Tcgen05SmemDescriptor, + b_smem_layout: cute.Layout, + desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, + K: Int32, + is_accum: bool = False, ): with elect_one(): a_outer = a_smem_layout.outer @@ -307,11 +329,16 @@ def mma_ws_ss_m64n128_mn_mn_call( tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, IDESC_F16_M64_N128_MN_MN, scale) scale = 1 + @cute.jit def mma_ws_ss_m64n64_k_k_call( - a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, - b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, - tmem_c: Int32, K: Int32, is_accum: bool = False, + a_smem_layout: cute.Layout, + desc_a_base: Tcgen05SmemDescriptor, + b_smem_layout: cute.Layout, + desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, + K: Int32, + is_accum: bool = False, ): with elect_one(): a_outer = a_smem_layout.outer @@ -325,11 +352,16 @@ def mma_ws_ss_m64n64_k_k_call( tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, IDESC_F16_M64_N64_K_K, scale) scale = 1 + @cute.jit def mma_ws_ss_m64n64_mn_mn_call( - a_smem_layout: cute.Layout, desc_a_base: Tcgen05SmemDescriptor, - b_smem_layout: cute.Layout, desc_b_base: Tcgen05SmemDescriptor, - tmem_c: Int32, K: Int32, is_accum: bool = False, + a_smem_layout: cute.Layout, + desc_a_base: Tcgen05SmemDescriptor, + b_smem_layout: cute.Layout, + desc_b_base: Tcgen05SmemDescriptor, + tmem_c: Int32, + K: Int32, + is_accum: bool = False, ): with elect_one(): a_outer = a_smem_layout.outer @@ -343,12 +375,14 @@ def mma_ws_ss_m64n64_mn_mn_call( tcgen05mma_ws_ss_f16(desc_a, desc_b, tmem_c, IDESC_F16_M64_N64_MN_MN, scale) scale = 1 + @cute.jit def umma_arrive(mbar_ptr: cute.Pointer): """tcgen05.commit.cta_group::1.mbarrier::arrive::one — signal MMA done.""" with elect_one(): tcgen05.commit(mbar_ptr, cta_group=tcgen05.CtaGroup.ONE) + class ChunkKdaBwdWyDqkgFused: """ CuTe DSL kernel for chunk_kda_bwd_kernel_wy_dqkg_fused. @@ -525,8 +559,6 @@ def __call__( B, T, H, HV, K, V = problem_size BT = self.BT - BK = self.BK - BV = self.BV data_B = Int32(1) NT = total_nt @@ -567,11 +599,6 @@ def __call__( beta = cute.make_tensor(beta_ptr, beta_layout) # A: (T, BT, (HV, data_B)) bf16 - a_layout = cute.make_layout( - (T, BT, (HV, data_B)), - stride=(HV * BT, 1, (BT, T * HV * BT)), - ) - A = cute.make_tensor(A_ptr, a_layout) # NOTE: for A as operand A, A is loaded as transposed view to do MMA a_t_layout = cute.make_layout( (BT, T, (HV, data_B)), @@ -610,13 +637,6 @@ def __call__( h = cute.make_tensor(h_ptr, h_layout) dh = cute.make_tensor(dh_ptr, h_layout) - # Transposed views for V-loop TMA (data loaded as MMA B-operands): - vt_layout = cute.make_layout( - (V, T, (data_B, HV)), - stride=(1, HV * V, (T * HV * V, V)), - ) - v_T = cute.make_tensor(v_ptr, vt_layout) - # ===================== MMA setup (4 objects) ===================== # All use tcgen05.mma.ws (Layout E, M=64, cta_group::1). # 1. vloop_tiled_mma: SS K,K (64,128) — dq, dk, dw @@ -663,7 +683,7 @@ def __call__( self.cta_group, self.kloop_dkgb_tiler[:2], # (64, 128) ) - + # dA_kloop_tiled_mma: SS K,K (64, 64) # dA += dw @ kg^T dA_kloop_tiled_mma = sm100_utils.make_trivial_tiled_mma( @@ -672,7 +692,7 @@ def __call__( tcgen05.OperandMajorMode.K, self.acc_dtype, self.cta_group, - self.kloop_dA_tiler[:2] # (64, 64) + self.kloop_dA_tiler[:2], # (64, 64) ) # dA2post_tiled_mma: SS K,K (64,64) @@ -1104,7 +1124,16 @@ class SharedStorage: k_epi_smem_layout, q_epi_smem_layout, # GMEM tensors - q, k, g, beta, dq, dk, dv2, dg, db, dA_out, + q, + k, + g, + beta, + dq, + dk, + dv2, + dg, + db, + dA_out, # Metadata cu_seqlens, chunk_indices, @@ -1116,7 +1145,7 @@ class SharedStorage: stream=stream, min_blocks_per_mp=self.min_occupancy, ) - + @cute.kernel def kernel( self, @@ -1183,19 +1212,17 @@ def kernel( ): B, T, H, HV, K, V = problem_size BT = self.BT - BK, BV = self.BK, self.BV # ===================== Persistent work decode ===================== # Grid: (min(num_sm * occ, total_tiles), 1, 1) — persistent block_idx_x = cute.arch.block_idx()[0] grid_dim_x = cute.arch.grid_dim()[0] - thread_idx = cute.arch.thread_idx()[0] + thread_idx = cute.arch.thread_idx()[0] lane_idx = thread_idx % 32 total_work_units = chunk_indices.layout.shape[0] * HV num_iters = (total_work_units - block_idx_x + grid_dim_x - 1) // grid_dim_x - - num_cuda_warps = len(self.cuda_warp_ids) + num_cuda_warps_total = len(self.cuda_warp_ids) + len(self.cuda2_warp_ids) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -1388,7 +1415,6 @@ def kernel( vloop_opA_smem_no_stage = cute.select(vloop_opA_smem, mode=[0, 1, 2]) vloop_opB_smem_no_stage = cute.select(vloop_opB_smem, mode=[0, 1, 2]) A_mn_opA_smem_no_stage = cute.select(A_mn_opA_smem, mode=[0, 1, 2]) - dv_mn_opB_smem_no_stage = cute.select(dv_mn_opB_smem, mode=[0, 1, 2]) v_opB_smem_no_stage = cute.select(v_opB_smem, mode=[0, 1, 2]) sA = storage.buf_A.get_tensor(A_mn_opA_smem.outer, swizzle=A_mn_opA_smem.inner) @@ -1399,9 +1425,6 @@ def kernel( sVnew = storage.buf_vnew.get_tensor(vloop_opA_smem.outer, swizzle=vloop_opA_smem.inner) sV = storage.buf_v.get_tensor(v_opB_smem.outer, swizzle=v_opB_smem.inner) - sA_raw = cute.make_ptr( - self.io_dtype, storage.buf_A.data_ptr().toint(), cute.AddressSpace.smem, - ) sDv_ptr_base = storage.buf_dv.data_ptr().toint() vloop_opA_bytes_per_stage = cute.size_in_bytes(self.io_dtype, vloop_opA_smem_no_stage) sDo_ptr_base = storage.buf_do.data_ptr().toint() @@ -1585,12 +1608,7 @@ def kernel( ), g_epi_smem_layout.outer, ) - sG_raw_ptr = cute.make_ptr( - self.g_dtype, storage.buf_g.data_ptr().toint(), cute.AddressSpace.smem - ) - sV_raw_ptr = cute.make_ptr( - self.io_dtype, storage.buf_v.data_ptr().toint(), cute.AddressSpace.smem - ) + sG_raw_ptr = cute.make_ptr(self.g_dtype, storage.buf_g.data_ptr().toint(), cute.AddressSpace.smem) sK_raw = cute.make_tensor( cute.recast_ptr( cute.make_ptr( @@ -1604,12 +1622,8 @@ def kernel( ), k_epi_smem_layout.outer, ) - sK_raw_ptr = cute.make_ptr( - self.io_dtype, storage.buf_k.data_ptr().toint(), cute.AddressSpace.smem - ) - sDw_raw_ptr = cute.make_ptr( - self.io_dtype, storage.buf_dw.data_ptr().toint(), cute.AddressSpace.smem - ) + sK_raw_ptr = cute.make_ptr(self.io_dtype, storage.buf_k.data_ptr().toint(), cute.AddressSpace.smem) + sDw_raw_ptr = cute.make_ptr(self.io_dtype, storage.buf_dw.data_ptr().toint(), cute.AddressSpace.smem) sQ_raw = cute.make_tensor( cute.recast_ptr( cute.make_ptr( @@ -1623,14 +1637,12 @@ def kernel( ), q_epi_smem_layout.outer, ) - sQ_raw_ptr = cute.make_ptr( - self.io_dtype, storage.buf_q.data_ptr().toint(), cute.AddressSpace.smem - ) + sQ_raw_ptr = cute.make_ptr(self.io_dtype, storage.buf_q.data_ptr().toint(), cute.AddressSpace.smem) # Scalar SMEM buffers (plain layouts, no swizzle) sBeta = cute.make_tensor( cute.make_ptr(Float32, storage.s_beta.data_ptr().toint(), cute.AddressSpace.smem), - cute.make_layout((self.BT, ), stride=(1, )), + cute.make_layout((self.BT,), stride=(1,)), ) # sDb layout: (BT, 2). Inner dim = wg_idx slot. Stride (1, BT) so each # wg's column is contiguous (better for the reduce in Phase 3). @@ -1640,11 +1652,11 @@ def kernel( ) sDgk = cute.make_tensor( cute.make_ptr(Float32, storage.s_dgk.data_ptr().toint(), cute.AddressSpace.smem), - cute.make_layout((self.BK, ), stride=(1, )), + cute.make_layout((self.BK,), stride=(1,)), ) sGn = cute.make_tensor( cute.make_ptr(Float32, storage.s_gn.data_ptr().toint(), cute.AddressSpace.smem), - cute.make_layout((self.BK, ), stride=(1, )), + cute.make_layout((self.BK,), stride=(1,)), ) # @@ -1661,71 +1673,36 @@ def kernel( if warp_idx in self.cuda_warp_ids or warp_idx in self.cuda2_warp_ids: cute.arch.setmaxregister_increase(self.num_regs_cuda) - load_beta_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, 1 - ) - load_g_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.kloop_stage - ) - mma_dvb_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_stage - ) - mma_dq_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_stage - ) - mma_dw_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_stage - ) - mma_dk_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_stage - ) - load_k_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.kloop_stage - ) - prologue_dw_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.kloop_stage - ) - prologue_kg_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.kloop_stage - ) - mma_dgkb_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_stage - ) - load_q_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.kloop_stage - ) - mma_dA_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_stage - ) - mma_dA2_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_stage - ) - mma_dA3_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_stage - ) - prologue_dA2_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_stage - ) - prologue_dA3_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_stage - ) - store_dg_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.kloop_stage - ) + load_beta_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1) + load_g_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.kloop_stage) + mma_dvb_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.mma_stage) + mma_dq_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.mma_stage) + mma_dw_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.mma_stage) + mma_dk_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.mma_stage) + load_k_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.kloop_stage) + prologue_dw_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.kloop_stage) + prologue_kg_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.kloop_stage) + mma_dgkb_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.mma_stage) + load_q_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.kloop_stage) + mma_dA_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.mma_stage) + mma_dA2_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.mma_stage) + mma_dA3_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.mma_stage) + prologue_dA2_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.mma_stage) + prologue_dA3_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.mma_stage) + store_dg_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.kloop_stage) wg_idx = tidx // 128 local_tidx = tidx % 128 - sub_wg_idx = tidx // 64 warp_id = local_tidx // 32 warp_row_tile = warp_id % 2 warp_col_tile = warp_id // 2 - row = warp_row_tile * 32 + lane_idx # BT1 + row = warp_row_tile * 32 + lane_idx # BT1 bk_num_cols = self.BK // 2 bv_num_cols = self.BV // 2 bk_num_cols_per_wg = bk_num_cols // 2 bv_num_cols_per_wg = bv_num_cols // 2 bt_num_cols_per_wg = self.BT // 4 - # ref: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e + # ref: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e bv_col_base = warp_col_tile * (self.BV // 2) + wg_idx * bv_num_cols_per_wg bk_col_base = warp_col_tile * (self.BK // 2) + wg_idx * bk_num_cols_per_wg bt_col_base = warp_col_tile * (self.BT // 2) + wg_idx * bt_num_cols_per_wg @@ -1739,7 +1716,7 @@ def kernel( G = HV // H i_t = work_idx // HV # chunk index (global) i_hv = work_idx % HV # value-head index - i_h = i_hv // G # q/k head index + i_h = i_hv // G # q/k head index # Decode chunk_indices batch_idx = chunk_indices[(i_t, 0)] tile_idx = chunk_indices[(i_t, 1)] @@ -1797,7 +1774,7 @@ def kernel( dvb_i32 = tcgen05_ld_32x32b(bv_num_cols_per_wg, TMEM_FLEX_OFF + wg_idx * bv_num_cols_per_wg) tcgen05_fence_before() cute.arch.fence_view_async_tmem_load() - + pipeline_mma_dvb.consumer_release(mma_dvb_consumer_state) mma_dvb_consumer_state.advance() @@ -1842,9 +1819,7 @@ def kernel( # bf16 vector → i32 vector for store_256b (8 i32 = 16 bf16 = 32 bytes per store). dvb_bf16_val = dvb_bf16_rmem.load() - dvb_i32_vec = reinterpret_cast( - dvb_bf16_val, self.io_dtype, bv_num_cols_per_wg, Int32 - ) + dvb_i32_vec = reinterpret_cast(dvb_bf16_val, self.io_dtype, bv_num_cols_per_wg, Int32) # bv_num_cols bf16 = bv_num_cols // 16 stores of 256b each. num_stores_per_row = bv_num_cols_per_wg // 16 # = 4 for BV=128 @@ -1869,7 +1844,7 @@ def kernel( sGn[local_tidx] = sG_raw[(sub_seq_len - 1, local_tidx, 0)] # row-major load, match TMEM layout - rG = cute.make_rmem_tensor((self.BK // 4, ), self.g_dtype) + rG = cute.make_rmem_tensor((self.BK // 4,), self.g_dtype) if row < sub_seq_len: for i in cutlass.range_constexpr(self.BK // 4 // 4): col_base = bk_col_base + i * 4 @@ -1900,17 +1875,12 @@ def kernel( rDq.store(dq_f32_val * rG_exp_val * Float32(self.scale)) dq_f32_val_store = rDq.load() - dq_i32_vec = reinterpret_cast( - dq_f32_val_store, Float32, bk_num_cols_per_wg, Int32 - ) + dq_i32_vec = reinterpret_cast(dq_f32_val_store, Float32, bk_num_cols_per_wg, Int32) # store to TMEM first to reduce register usage tcgen05_st_32x32b(bk_num_cols_per_wg, TMEM_DQ_SCALED_OFF + wg_idx * bk_num_cols_per_wg, dq_i32_vec) cute.arch.fence_view_async_tmem_store() dq_base_addr = ( - dq_gmem.iterator - + (tok_offset + tile_idx * self.BT + row) * HV * K - + i_hv * K - + bk_col_base + dq_gmem.iterator + (tok_offset + tile_idx * self.BT + row) * HV * K + i_hv * K + bk_col_base ).toint() if row < sub_seq_len: for s in cutlass.range_constexpr(num_stores_f32): @@ -1951,7 +1921,7 @@ def kernel( pipeline_load_k.consumer_wait(load_k_consumer_state) # compute kg = k * gk_exp - rK = cute.make_rmem_tensor((self.BK // 4, ), self.io_dtype) + rK = cute.make_rmem_tensor((self.BK // 4,), self.io_dtype) if row < sub_seq_len: for i in cutlass.range_constexpr(self.BK // 4 // 8): col_base = bk_col_base + i * 8 @@ -1966,14 +1936,14 @@ def kernel( rK[i * 8 + 7] = vals[7] else: rK.fill(BFloat16(0.0)) - rK_fp32 = cute.make_rmem_tensor((self.BK // 4, ), Float32) + rK_fp32 = cute.make_rmem_tensor((self.BK // 4,), Float32) rK_fp32.store(rK.load().to(Float32)) rK_fp32_val = rK_fp32.load() rKG_val = rK_fp32_val * rG_exp_val - # write kg to K smem, + # write kg to K smem, # notify dA += dw @ kg^T - rKG_bf16 = cute.make_rmem_tensor((self.BK // 4, ), BFloat16) + rKG_bf16 = cute.make_rmem_tensor((self.BK // 4,), BFloat16) rKG_bf16.store(rKG_val.to(BFloat16)) pipeline_prologue_kg.producer_acquire(prologue_kg_producer_state) @@ -2005,7 +1975,7 @@ def kernel( if row < sub_seq_len: for i in cutlass.range_constexpr(bk_num_cols_per_wg): db_val += rKgb_kg[i] - + # Deterministic db reduction without atomicAdd. # 4 partitions per row come from 4 warps (warp_row_tile in {0,1}, # warp_col_tile in {0,1}) x 2 wgs. Reduce in a fixed order so @@ -2054,22 +2024,17 @@ def kernel( rKdk.store(rK_fp32.load() * rDk.load()) # gb = gk_exp * beta[:, None] - rGb = cute.make_rmem_tensor((bk_num_cols_per_wg, ), Float32) + rGb = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) rGb.store(rG_exp_val * beta_val) # dk = dk + dkgb * gb rDk.store(rDk.load() + dkgb_f32_val * rGb.load()) rDk_val = rDk.load() - dk_i32_vec = reinterpret_cast( - rDk_val, Float32, bk_num_cols_per_wg, Int32 - ) + dk_i32_vec = reinterpret_cast(rDk_val, Float32, bk_num_cols_per_wg, Int32) # GMEM store dk # 8 fp32 store each time for store_256b dk_base_addr = ( - dk_gmem.iterator - + (tok_offset + tile_idx * self.BT + row) * HV * K - + i_hv * K - + bk_col_base + dk_gmem.iterator + (tok_offset + tile_idx * self.BT + row) * HV * K + i_hv * K + bk_col_base ).toint() if row < sub_seq_len: for s in cutlass.range_constexpr(num_stores_f32): @@ -2086,14 +2051,14 @@ def kernel( # dgk *= exp2(gn) if wg_idx == 0: - sDgk[(local_tidx, )] *= cute.exp2(sGn[(local_tidx,)], fastmath=self.use_fast_math) + sDgk[(local_tidx,)] *= cute.exp2(sGn[(local_tidx,)], fastmath=self.use_fast_math) self.cuda_wg_sync_barrier.arrive_and_wait() if wg_idx == 0: sum = Float32(0.0) for r in cutlass.range(self.BT, unroll_full=True): sum += sG_raw[(r, local_tidx, 0)] - sDgk[(local_tidx, )] += sum + sDgk[(local_tidx,)] += sum # dg1 = kg * dkgb * beta[:, None], can reuse kg RMEM rDg = cute.make_rmem_tensor((bk_num_cols_per_wg,), Float32) @@ -2145,7 +2110,7 @@ def kernel( pipeline_load_g.consumer_release(load_g_consumer_state) load_g_consumer_state.advance() - + pipeline_mma_dA.consumer_wait(mma_dA_consumer_state) tcgen05_fence_after() dA_i32 = tcgen05_ld_32x32b(bt_num_cols_per_wg, TMEM_DA_ACC_OFF + wg_idx * bt_num_cols_per_wg) @@ -2163,7 +2128,7 @@ def kernel( # and keeps only `row > col`. dA_f32 = reinterpret_cast(dA_i32, Int32, bt_num_cols_per_wg, Float32) dA_f32_val = TensorSSA(dA_f32, (bt_num_cols_per_wg,), Float32) - rDA = cute.make_rmem_tensor((bt_num_cols_per_wg, ), BFloat16) + rDA = cute.make_rmem_tensor((bt_num_cols_per_wg,), BFloat16) for i in cutlass.range_constexpr(bt_num_cols_per_wg): col = bt_col_base + i beta_col = sBeta[(col,)] @@ -2200,7 +2165,7 @@ def kernel( # write dA2 to smem notify dA2 = A @ dA2 dA2_f32 = reinterpret_cast(dA2_i32, Int32, bt_num_cols_per_wg, Float32) dA2_f32_val = TensorSSA(dA2_f32, (bt_num_cols_per_wg,), Float32) - rDA2 = cute.make_rmem_tensor((bt_num_cols_per_wg, ), BFloat16) + rDA2 = cute.make_rmem_tensor((bt_num_cols_per_wg,), BFloat16) if row < sub_seq_len: rDA2.store(dA2_f32_val.to(BFloat16)) else: @@ -2233,23 +2198,18 @@ def kernel( # dA = -dA, apply strict lower-triangular mask dA3_f32 = reinterpret_cast(dA3_i32, Int32, bt_num_cols_per_wg, Float32) dA3_f32_val = TensorSSA(dA3_f32, (bt_num_cols_per_wg,), Float32) - rDA3 = cute.make_rmem_tensor((bt_num_cols_per_wg, ), Float32) + rDA3 = cute.make_rmem_tensor((bt_num_cols_per_wg,), Float32) rDA3.store(-dA3_f32_val) for i in cutlass.range_constexpr(bt_num_cols_per_wg): col = bt_col_base + i if col >= row: rDA3[i] = Float32(0.0) rDA3_val = rDA3.load() - dA3_i32_vec = reinterpret_cast( - rDA3_val, Float32, bt_num_cols_per_wg, Int32 - ) + dA3_i32_vec = reinterpret_cast(rDA3_val, Float32, bt_num_cols_per_wg, Int32) # GMEM store dA num_stores_dA = bt_num_cols_per_wg // 8 dA_base_addr = ( - dA_gmem.iterator - + (tok_offset + tile_idx * self.BT + row) * HV * BT - + i_hv * BT - + bt_col_base + dA_gmem.iterator + (tok_offset + tile_idx * self.BT + row) * HV * BT + i_hv * BT + bt_col_base ).toint() if row < sub_seq_len: for s in cutlass.range_constexpr(num_stores_dA): @@ -2260,36 +2220,22 @@ def kernel( elif warp_idx == self.load_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_others) - load_A_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.a_stage - ) - load_dv_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.vloop_stage - ) - load_do_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.vloop_stage - ) - load_vnew_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.vloop_stage - ) - load_g_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.kloop_stage - ) - load_k_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.kloop_stage - ) - load_q_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.kloop_stage - ) + load_A_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.a_stage) + load_dv_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.vloop_stage) + load_do_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.vloop_stage) + load_vnew_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.vloop_stage) + load_g_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.kloop_stage) + load_k_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.kloop_stage) + load_q_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.kloop_stage) vloop_stage_idx = 0 - vloop_phase = 1 # init as 1 for producer + vloop_phase = 1 # init as 1 for producer for wu_iter in cutlass.range(0, num_iters, unroll=0): work_idx = block_idx_x + wu_iter * grid_dim_x G = HV // H i_t = work_idx // HV # chunk index (global) i_hv = work_idx % HV # value-head index - i_h = i_hv // G # q/k head index + i_h = i_hv // G # q/k head index # Decode chunk_indices batch_idx = chunk_indices[(i_t, 0)] @@ -2304,7 +2250,7 @@ def kernel( tma_atom_A, tma_A_v, sA, - self.dvb_tiler, # [BT, BV, BT] + self.dvb_tiler, # [BT, BV, BT] dvb_tiled_mma, Int32(0), i_hv, @@ -2325,9 +2271,10 @@ def kernel( tma_atom_h, tma_h_v, sH, - self.vloop_gemm_tiler, # [BT, BK, BV] + self.vloop_gemm_tiler, # [BT, BK, BV] vloop_tiled_mma, - i_hv, i_t + i_hv, + i_t, ) mbarrier_wait(bar_mma_cuda_h_ptr + vloop_stage_idx, vloop_phase) with elect_one(): @@ -2344,9 +2291,10 @@ def kernel( tma_atom_dh, tma_dh_v, sDh, - self.vloop_gemm_tiler, # [BT, BK, BV] + self.vloop_gemm_tiler, # [BT, BK, BV] vloop_tiled_mma, - i_hv, i_t + i_hv, + i_t, ) mbarrier_wait(bar_mma_cuda_dh_ptr + vloop_stage_idx, vloop_phase) with elect_one(): @@ -2363,9 +2311,10 @@ def kernel( tma_atom_do, tma_do_v, sDo, - self.vloop_gemm_tiler, # [BT, BK, BV] + self.vloop_gemm_tiler, # [BT, BK, BV] vloop_tiled_mma, - Int32(0), i_hv, + Int32(0), + i_hv, ) pipeline_load_do.producer_acquire(load_do_producer_state) cute.copy( @@ -2381,9 +2330,10 @@ def kernel( tma_atom_dv, tma_dv_v, sDv, - self.vloop_gemm_tiler, # [BT, BK, BV] + self.vloop_gemm_tiler, # [BT, BK, BV] vloop_tiled_mma, - Int32(0), i_hv, + Int32(0), + i_hv, ) pipeline_load_dv.producer_acquire(load_dv_producer_state) cute.copy( @@ -2399,9 +2349,10 @@ def kernel( tma_atom_v, tma_v_v, sV, - self.dA_vloop_tiler, # [BT, BT, BV] + self.dA_vloop_tiler, # [BT, BT, BV] dA_vloop_tiled_mma, - Int32(0), i_hv, + Int32(0), + i_hv, ) mbarrier_wait(bar_mma_cuda_v_ptr + vloop_stage_idx, vloop_phase) with elect_one(): @@ -2419,9 +2370,10 @@ def kernel( tma_atom_vnew, tma_vnew_v, sVnew, - self.vloop_gemm_tiler, # [BT, BK, BV] + self.vloop_gemm_tiler, # [BT, BK, BV] vloop_tiled_mma, - Int32(0), i_hv, + Int32(0), + i_hv, ) pipeline_load_vnew.producer_acquire(load_vnew_producer_state) cute.copy( @@ -2434,7 +2386,7 @@ def kernel( vloop_stage_idx = (vloop_stage_idx + 1) % self.vloop_stage vloop_phase ^= 1 - + # Load g tma_g_v = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_g) tGsG, tGgG = self._epilog_partition_varlen( @@ -2447,7 +2399,7 @@ def kernel( cute.copy( tma_atom_g, tGgG[(None, tile_idx, 0)], - tGsG[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 + tGsG[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 tma_bar_ptr=pipeline_load_g.producer_get_barrier(load_g_producer_state), ) load_g_producer_state.advance() @@ -2464,7 +2416,7 @@ def kernel( cute.copy( tma_atom_k, tKgK[(None, tile_idx, 0)], - tKsK[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 + tKsK[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 tma_bar_ptr=pipeline_load_k.producer_get_barrier(load_k_producer_state), ) load_k_producer_state.advance() @@ -2480,64 +2432,31 @@ def kernel( cute.copy( tma_atom_q, tQgQ[(None, tile_idx, 0)], - tQsQ[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 + tQsQ[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 tma_bar_ptr=pipeline_load_q.producer_get_barrier(load_q_producer_state), ) load_q_producer_state.advance() - # MMA loop body elif warp_idx == self.mma_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_others) - load_A_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.a_stage - ) - load_dv_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.vloop_stage - ) - mma_dvb_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_stage - ) - load_do_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.vloop_stage - ) - load_vnew_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.vloop_stage - ) - mma_dq_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_stage - ) - mma_dk_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_stage - ) - mma_dw_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_stage - ) - prologue_dw_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.kloop_stage - ) - prologue_kg_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.kloop_stage - ) - mma_dgkb_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_stage - ) - mma_dA_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_stage - ) - mma_dA2_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_stage - ) - mma_dA3_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_stage - ) - prologue_dA2_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_stage - ) - prologue_dA3_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_stage - ) + load_A_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.a_stage) + load_dv_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.vloop_stage) + mma_dvb_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.mma_stage) + load_do_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.vloop_stage) + load_vnew_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.vloop_stage) + mma_dq_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.mma_stage) + mma_dk_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.mma_stage) + mma_dw_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.mma_stage) + prologue_dw_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.kloop_stage) + prologue_kg_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.kloop_stage) + mma_dgkb_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.mma_stage) + mma_dA_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.mma_stage) + mma_dA2_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.mma_stage) + mma_dA3_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.mma_stage) + prologue_dA2_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.mma_stage) + prologue_dA3_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.mma_stage) vloop_stage_idx = 0 a_stage_idx = 0 @@ -2548,7 +2467,7 @@ def kernel( G = HV // H i_t = work_idx // HV # chunk index (global) i_hv = work_idx % HV # value-head index (unused in MMA warp) - i_h = i_hv // G # q/k head index (unused in MMA warp) + i_h = i_hv // G # q/k head index (unused in MMA warp) # Decode chunk_indices batch_idx = chunk_indices[(i_t, 0)] @@ -2557,7 +2476,7 @@ def kernel( seq_len = cu_seqlens[(batch_idx + 1,)] - tok_offset sub_seq_len = min(self.BT, seq_len - tile_idx * self.BT) - zeros8 = cute.make_rmem_tensor((8, ), dtype=self.io_dtype) + zeros8 = cute.make_rmem_tensor((8,), dtype=self.io_dtype) zeros8.fill(BFloat16(0.0)) pipeline_load_A.consumer_wait(load_A_consumer_state) @@ -2594,7 +2513,7 @@ def kernel( # dv tile uses the same Swizzle<3,4,3> physical mapping. smem_store_bf16x8_sw128(sDo_raw_ptr, row, col * 8, zeros8) cute.arch.fence_proxy("async.shared", space="cta") - + if v_iter == 0: pipeline_mma_dq.producer_acquire(mma_dq_producer_state) @@ -2605,7 +2524,9 @@ def kernel( desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sH_k_cur.iterator, sH_k_cur.layout, "k")) desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) - mma_ws_ss_m64n128_k_k_call(vloop_opA_smem, desc_a_base, vloop_opB_smem, desc_b_base, TMEM_DQ_ACC_OFF, self.BV, is_accum) + mma_ws_ss_m64n128_k_k_call( + vloop_opA_smem, desc_a_base, vloop_opB_smem, desc_b_base, TMEM_DQ_ACC_OFF, self.BV, is_accum + ) pipeline_load_do.consumer_release(load_do_consumer_state) load_do_consumer_state.advance() @@ -2628,7 +2549,7 @@ def kernel( # dv tile uses the same Swizzle<3,4,3> physical mapping. smem_store_bf16x8_sw128(sDv_raw, row, col * 8, zeros8) cute.arch.fence_proxy("async.shared", space="cta") - + # if lane_idx == 0: # cute.printf("V_iter", v_iter) # cute.print_tensor(sDv[None, None, None, vloop_stage_idx]) @@ -2639,7 +2560,9 @@ def kernel( desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDv_mn_cur.iterator, sDv_mn_cur.layout, "mn")) desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) - mma_ws_ss_m64n64_mn_mn_call(A_mn_opA_smem, desc_a_base, dv_mn_opB_smem, desc_b_base, TMEM_FLEX_OFF, self.BT) + mma_ws_ss_m64n64_mn_mn_call( + A_mn_opA_smem, desc_a_base, dv_mn_opB_smem, desc_b_base, TMEM_FLEX_OFF, self.BT + ) pipeline_mma_dvb.producer_commit(mma_dvb_producer_state) mma_dvb_producer_state.advance() @@ -2647,13 +2570,15 @@ def kernel( # dw += dv @ h if v_iter == 0: pipeline_mma_dw.producer_acquire(mma_dw_producer_state) - + sDv_k_cur = sDv_k[(None, None, None, vloop_stage_idx)] desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDv_k_cur.iterator, sDv_k_cur.layout, "k")) desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sH_k_cur.iterator, sH_k_cur.layout, "k")) desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) - mma_ws_ss_m64n128_k_k_call(vloop_opA_smem, desc_a_base, vloop_opB_smem, desc_b_base, TMEM_DW_ACC_OFF, self.BV, is_accum) + mma_ws_ss_m64n128_k_k_call( + vloop_opA_smem, desc_a_base, vloop_opB_smem, desc_b_base, TMEM_DW_ACC_OFF, self.BV, is_accum + ) # dA += dv @ v^T mbarrier_wait(bar_tma_v_ptr + vloop_stage_idx, vloop_phase) @@ -2677,7 +2602,9 @@ def kernel( desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sV_k_cur.iterator, sV_k_cur.layout, "k")) desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) - mma_ws_ss_m64n64_k_k_call(vloop_opA_smem, desc_a_base, v_opB_smem, desc_b_base, TMEM_DA_ACC_OFF, self.BV, is_accum) + mma_ws_ss_m64n64_k_k_call( + vloop_opA_smem, desc_a_base, v_opB_smem, desc_b_base, TMEM_DA_ACC_OFF, self.BV, is_accum + ) # dv pipeline calls tcgen05.commit for dv@h and dv@v^T pipeline_load_dv.consumer_release(load_dv_consumer_state) @@ -2709,14 +2636,16 @@ def kernel( mbarrier_wait(bar_tma_dh_ptr + vloop_stage_idx, vloop_phase) if v_iter == 0: pipeline_mma_dk.producer_acquire(mma_dk_producer_state) - + sVnew_k_cur = sVnew_k[(None, None, None, vloop_stage_idx)] sDh_k_cur = sDh_k[(None, None, None, vloop_stage_idx)] desc_a_i64 = smem_descriptor_to_int(make_umma_smem_desc(sVnew_k_cur.iterator, sVnew_k_cur.layout, "k")) desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDh_k_cur.iterator, sDh_k_cur.layout, "k")) desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) - mma_ws_ss_m64n128_k_k_call(vloop_opA_smem, desc_a_base, vloop_opB_smem, desc_b_base, TMEM_DK_ACC_OFF, self.BV, is_accum) + mma_ws_ss_m64n128_k_k_call( + vloop_opA_smem, desc_a_base, vloop_opB_smem, desc_b_base, TMEM_DK_ACC_OFF, self.BV, is_accum + ) # vnew pipeline calls tcgen05.commit pipeline_load_vnew.consumer_release(load_vnew_consumer_state) @@ -2735,7 +2664,7 @@ def kernel( vloop_stage_idx = (vloop_stage_idx + 1) % self.vloop_stage vloop_phase ^= 1 - + pipeline_prologue_dw.consumer_wait(prologue_dw_consumer_state) cute.arch.fence_proxy("async.shared", space="cta") # dkgb = A @ dw @@ -2746,7 +2675,9 @@ def kernel( desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sDw_mn_cur.iterator, sDw_mn_cur.layout, "mn")) desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) - mma_ws_ss_m64n128_mn_mn_call(A_mn_opA_smem, desc_a_base, dw_mn_opB_smem, desc_b_base, TMEM_DKGB_ACC_OFF, self.BT) + mma_ws_ss_m64n128_mn_mn_call( + A_mn_opA_smem, desc_a_base, dw_mn_opB_smem, desc_b_base, TMEM_DKGB_ACC_OFF, self.BT + ) pipeline_mma_dkgb.producer_commit(mma_dgkb_producer_state) mma_dgkb_producer_state.advance() @@ -2760,7 +2691,9 @@ def kernel( desc_b_i64 = smem_descriptor_to_int(make_umma_smem_desc(sKG_k_cur.iterator, sKG_k_cur.layout, "k")) desc_a_base = Tcgen05SmemDescriptor(desc_a_i64) desc_b_base = Tcgen05SmemDescriptor(desc_b_i64) - mma_ws_ss_m64n64_k_k_call(dw_k_opA_smem, desc_a_base, kg_k_opB_smem, desc_b_base, TMEM_DA_ACC_OFF, self.BK, True) + mma_ws_ss_m64n64_k_k_call( + dw_k_opA_smem, desc_a_base, kg_k_opB_smem, desc_b_base, TMEM_DA_ACC_OFF, self.BK, True + ) pipeline_mma_dA.producer_commit(mma_dA_producer_state) mma_dA_producer_state.advance() @@ -2816,22 +2749,16 @@ def kernel( cute.arch.setmaxregister_decrease(self.num_regs_others) tidx = thread_idx - (self.threads_per_cta - 64) - load_beta_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, 1 - ) - load_g_store_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.kloop_stage - ) - store_dg_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.kloop_stage - ) + load_beta_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1) + load_g_store_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.kloop_stage) + store_dg_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.kloop_stage) for wu_iter in cutlass.range(0, num_iters, unroll=0): work_idx = block_idx_x + wu_iter * grid_dim_x G = HV // H i_t = work_idx // HV # chunk index (global) i_hv = work_idx % HV # value-head index - i_h = i_hv // G # q/k head index (unused in aux warp) + i_h = i_hv // G # q/k head index (unused in aux warp) # Decode chunk_indices batch_idx = chunk_indices[(i_t, 0)] @@ -2844,7 +2771,7 @@ def kernel( beta_f32 = Float32(0.0) if tidx < sub_seq_len: beta_f32 = Float32(beta_gmem[(tok_offset + tile_idx * self.BT + tidx, (i_hv, Int32(0)))]) - sBeta[(tidx, )] = beta_f32 + sBeta[(tidx,)] = beta_f32 cute.arch.fence_proxy("async.shared", space="cta") pipeline_load_beta.producer_commit(load_beta_producer_state) @@ -2862,7 +2789,7 @@ def kernel( ) if sub_seq_len < self.BT: # Tail chunk, direct store - store_lane_row = tidx >> Int32(4) # 0..3 + store_lane_row = tidx >> Int32(4) # 0..3 store_col_base = (tidx & Int32(15)) * Int32(8) # 0,8,...,120 for row_quad in cutlass.range_constexpr(self.BT // 4): store_row = row_quad * 4 + store_lane_row @@ -2878,9 +2805,7 @@ def kernel( dg_store_rmem[5] = vals1[1] dg_store_rmem[6] = vals1[2] dg_store_rmem[7] = vals1[3] - dg_store_i32_vec = reinterpret_cast( - dg_store_rmem.load(), Float32, 8, Int32 - ) + dg_store_i32_vec = reinterpret_cast(dg_store_rmem.load(), Float32, 8, Int32) dg_base_addr = ( dg_gmem.iterator + (tok_offset + tile_idx * self.BT + store_row) * HV * K @@ -2893,7 +2818,7 @@ def kernel( cute.arch.fence_proxy("async.shared", space="cta") cute.copy( tma_atom_dg, - tDGsDG[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 + tDGsDG[(None, 0)], # hardcode stage to 0 because kloop_stage is 1 tDGgDG[(None, tile_idx, 0)], ) cute.arch.cp_async_bulk_commit_group() @@ -3034,7 +2959,9 @@ def _compile_bwd_wy_variant(H, HV, K, V, scale, chunk_size, beta_dtype, use_fast dA_fake = make_fake_compact_tensor(cutlass.Float32, (1, sym_b, HV, BT), stride_order=(3, 2, 1, 0), assumed_align=128) h_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_nt, HV, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128) - dh_fake = make_fake_compact_tensor(cutlass.BFloat16, (1, sym_nt, HV, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128) + dh_fake = make_fake_compact_tensor( + cutlass.BFloat16, (1, sym_nt, HV, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128 + ) cu_fake = make_fake_compact_tensor(cutlass.Int32, (sym_cu,), assumed_align=128) ci_fake = make_fake_compact_tensor(cutlass.Int32, (sym_ci, 2), stride_order=(1, 0), assumed_align=128) diff --git a/cula/ops/intrinsics_sm100.py b/cula/ops/intrinsics_sm100.py index 087e5c6..cb22290 100644 --- a/cula/ops/intrinsics_sm100.py +++ b/cula/ops/intrinsics_sm100.py @@ -347,6 +347,7 @@ def _do(addr_val, desc_val, *, loc=None, ip=None): _do(Int32(taddr), smem_desc.desc_i64[0]) + @cute.jit def tcgen05_cp_128x128b(taddr: int, smem_desc: Tcgen05SmemDescriptor): """Async copy SMEM → TMEM with shape ``128x128b`` (``cta_group::1``). @@ -387,6 +388,7 @@ def _do(addr_val, desc_val, *, loc=None, ip=None): _do(Int32(taddr), smem_desc.desc_i64[0]) + @cute.jit def tcgen05_fence_before(): """tcgen05.fence::before_thread_sync — non-blocking ordering fence.""" @@ -396,4 +398,4 @@ def tcgen05_fence_before(): @cute.jit def tcgen05_fence_after(): """tcgen05.fence::after_thread_sync — non-blocking ordering fence.""" - _nvvm.tcgen05_fence(kind=_nvvm.Tcgen05FenceKind.AFTER_THREAD_SYNC) \ No newline at end of file + _nvvm.tcgen05_fence(kind=_nvvm.Tcgen05FenceKind.AFTER_THREAD_SYNC) diff --git a/tests/test_ptx_umma_ws.py b/tests/test_ptx_umma_ws.py index b130740..24e2edd 100644 --- a/tests/test_ptx_umma_ws.py +++ b/tests/test_ptx_umma_ws.py @@ -41,8 +41,7 @@ smem_descriptor_to_int, ) from cutlass.cute.runtime import from_dlpack -from cutlass.cute.tensor import TensorSSA -from cutlass.cute.typing import Float16, Float32, Int32, Int64, TFloat32, BFloat16 +from cutlass.cute.typing import BFloat16, Float32, Int32, Int64, TFloat32 from cula.ops.intrinsics_sm100 import ( store_256b, @@ -60,8 +59,8 @@ M_DIM, N_DIM = 64, 64 # TODO: support arbitrary K K_DIM_TF32 = 8 # kind::tf32 → K>=8, tile size -A_K_STEP_BYTES_TF32 = M_DIM * 8 * 4 # smem offset for each K-atom in operand A -B_K_STEP_BYTES_TF32 = N_DIM * 8 * 4 # smem offset for each K-atom in operand B +A_K_STEP_BYTES_TF32 = M_DIM * 8 * 4 # smem offset for each K-atom in operand A +B_K_STEP_BYTES_TF32 = N_DIM * 8 * 4 # smem offset for each K-atom in operand B K_DIM_F16 = 128 # default after sweep # NOTE: per-K-atom byte offsets are derived from the SMEM layout at runtime # (see _WsSsF16Kernel) so K_DIM_F16 can be any multiple of 16. The layout's @@ -188,7 +187,7 @@ def kernel(self, A_in: cute.Tensor, B_in: cute.Tensor, C_out: cute.Tensor): # T2R → R2G: tcgen05_ld directly into store_256b (type-agnostic, like C++ reinterpret_cast) vec_i32 = tcgen05_ld_32x32b(ACC_NUM_COLS, tmem_col) cute.arch.fence_view_async_tmem_load() - + # 1. reinterpret_cast to f32 (zero-cost bitcast) # vec_f32 = reinterpret_cast(vec_i32, Int32, ACC_NUM_COLS, Float32) @@ -417,7 +416,7 @@ class _WsSsTf32CollectorKernel: @cute.kernel def kernel(self, A_in: cute.Tensor, B_in: cute.Tensor, C_out: cute.Tensor): - M, N, K = M_DIM, N_DIM, 8 # default K with 8 + M, N, K = M_DIM, N_DIM, 8 # default K with 8 ACC_NUM_COLS = N // 2 NUM_COLS = ACC_NUM_COLS tidx, _, _ = cute.arch.thread_idx() @@ -562,6 +561,7 @@ def test_ws_ss_f16(): assert rel < 0.02, f"FAIL N={N}, K={K}: rel={rel:.4f}" print(f" PASSED (N={N}, K={K})") + def test_ws_ss_tf32_collector(): """Explicit collector_b_buffer=B0, collector_op=DISCARD should match default.""" print("\n=== Test 2: tcgen05mma_ws_ss_tf32 + collector (B0::DISCARD) ===") From 2b3f1f599072f4721b0f8e4f639eea439f2a6aa1 Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Wed, 13 May 2026 11:36:43 +0800 Subject: [PATCH 24/26] add nan and inf check --- benchmarks/bench_kda_fwd_bwd_e2e.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/bench_kda_fwd_bwd_e2e.py b/benchmarks/bench_kda_fwd_bwd_e2e.py index f8e193c..1614011 100644 --- a/benchmarks/bench_kda_fwd_bwd_e2e.py +++ b/benchmarks/bench_kda_fwd_bwd_e2e.py @@ -222,6 +222,8 @@ def check_determinism(num_seqs=5, T=512, iters=20): for i in range(iters): out = run_kda_e2e_with_grads(**common, fn=cula_chunk_kda) for name in ("o", "ht", "dq", "dk", "dv", "dg", "dbeta", "dh0"): + assert torch.isnan(out[name]).sum() == 0, f"[determinism] cuLA {name} has NaNs at iter {i}" + assert torch.isfinite(out[name]).all(), f"[determinism] cuLA {name} has infs at iter {i}" assert torch.equal(out[name], ref[name]), f"[determinism] cuLA {name} mismatch at iter {i}" return True @@ -577,7 +579,7 @@ def main(): det_configs = [(5, 1024), (10, 4096), (10, 8192), (10, 16384)] print("\n[Determinism Check] cuLA chunk_kda E2E ...") for num_seqs, T in det_configs: - result = check_determinism(num_seqs=num_seqs, T=T, iters=1000) + result = check_determinism(num_seqs=num_seqs, T=T, iters=2000) print(f" num_seqs={num_seqs} T={T:5d} {'PASS' if result else 'FAIL'}") print("[Determinism Check] All passed.\n") From c56acfcd47ebbb9b2e112f70beb4d18b1bea0ff4 Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Wed, 13 May 2026 11:37:36 +0800 Subject: [PATCH 25/26] fix --- benchmarks/bench_kda_fwd_bwd_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench_kda_fwd_bwd_e2e.py b/benchmarks/bench_kda_fwd_bwd_e2e.py index 1614011..24c6721 100644 --- a/benchmarks/bench_kda_fwd_bwd_e2e.py +++ b/benchmarks/bench_kda_fwd_bwd_e2e.py @@ -579,7 +579,7 @@ def main(): det_configs = [(5, 1024), (10, 4096), (10, 8192), (10, 16384)] print("\n[Determinism Check] cuLA chunk_kda E2E ...") for num_seqs, T in det_configs: - result = check_determinism(num_seqs=num_seqs, T=T, iters=2000) + result = check_determinism(num_seqs=num_seqs, T=T, iters=1000) print(f" num_seqs={num_seqs} T={T:5d} {'PASS' if result else 'FAIL'}") print("[Determinism Check] All passed.\n") From d17632795c62fe09be489ddc214baf108dad4fc7 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Fri, 22 May 2026 15:09:33 +0800 Subject: [PATCH 26/26] add copyright --- cula/ops/ptx_umma_ext.py | 3 +++ tests/test_ptx_umma_masked.py | 3 +++ tests/test_ptx_umma_ws.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/cula/ops/ptx_umma_ext.py b/cula/ops/ptx_umma_ext.py index 4426361..2e8caee 100644 --- a/cula/ops/ptx_umma_ext.py +++ b/cula/ops/ptx_umma_ext.py @@ -1,3 +1,6 @@ +# Copyright (c) 2025 ANTGROUP. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + """CuteDSL UMMA extension wrappers for SM100 (Blackwell) ``tcgen05.mma``. CuteDSL's high-level ``cute.gemm()`` / ``make_tiled_mma()`` API does not diff --git a/tests/test_ptx_umma_masked.py b/tests/test_ptx_umma_masked.py index 356575e..e7bc319 100644 --- a/tests/test_ptx_umma_masked.py +++ b/tests/test_ptx_umma_masked.py @@ -1,3 +1,6 @@ +# Copyright (c) 2025 ANTGROUP. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + """ Standalone CuteDSL test for ptx_umma_masked.py inline PTX MMA wrappers. diff --git a/tests/test_ptx_umma_ws.py b/tests/test_ptx_umma_ws.py index 24e2edd..da211a9 100644 --- a/tests/test_ptx_umma_ws.py +++ b/tests/test_ptx_umma_ws.py @@ -1,3 +1,6 @@ +# Copyright (c) 2025 ANTGROUP. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + """ Standalone CuteDSL test for tcgen05.mma.ws (weight-stationary) inline PTX wrappers.