From 90f02f1b6f6c455427104a2756a2b0685fbcf32a Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Mon, 11 May 2026 18:17:58 +0800 Subject: [PATCH 01/28] format --- benchmarks/bench_chunk_delta_h_bwd_sm90.py | 136 +++++++ cula/ops/chunk_delta_h_bwd.py | 404 +++++++++++++++++++++ tests/test_chunk_delta_h_bwd_sm90.py | 109 ++++++ 3 files changed, 649 insertions(+) create mode 100644 benchmarks/bench_chunk_delta_h_bwd_sm90.py create mode 100644 cula/ops/chunk_delta_h_bwd.py create mode 100644 tests/test_chunk_delta_h_bwd_sm90.py diff --git a/benchmarks/bench_chunk_delta_h_bwd_sm90.py b/benchmarks/bench_chunk_delta_h_bwd_sm90.py new file mode 100644 index 0000000..29eda45 --- /dev/null +++ b/benchmarks/bench_chunk_delta_h_bwd_sm90.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +""" +Benchmark the SM90 CuTe DSL bwd_dhu prototype against FLA Triton. + +Current kernel scope: + - non-varlen + - K in {64, 128, 256}, BT=64 + - state layout [B, NT, H, K, V] + - optional gk/dht/h0 + +Example: + python benchmarks/bench_chunk_delta_h_bwd_sm90.py --B 1 --T 1024 --H 8 --K 128 --V 64 --gk --dht +""" + +import argparse +import pathlib +import sys + +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) + +import torch +from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu as fla_bwd_dhu + +from cula.ops.chunk_delta_h_bwd import chunk_gated_delta_rule_bwd_dhu_sm90 + + +def time_kernel(fn, warmup: int, iters: int) -> float: + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / iters + + +def make_inputs(args): + torch.manual_seed(args.seed) + B, T, H, K, V = args.B, args.T, args.H, args.K, args.V + q = torch.randn(B, T, H, K, 
dtype=torch.bfloat16, device="cuda") * 0.1 + k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 + w = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 + do = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 + dv = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 + gk = None + if args.gk: + gk = -torch.abs(torch.randn(B, T, H, K, dtype=torch.float32, device="cuda") * 0.01).cumsum(dim=1) + dht = None + if args.dht: + dht = torch.randn(B, H, K, V, dtype=torch.float32, device="cuda") * 0.01 + h0 = torch.empty(B, H, K, V, dtype=torch.float32, device="cuda") if args.h0 else None + return q, k, w, do, dv, gk, dht, h0 + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--B", type=int, default=1) + parser.add_argument("--T", type=int, default=1024) + parser.add_argument("--H", type=int, default=8) + parser.add_argument("--K", type=int, default=128, choices=[64, 128, 256]) + parser.add_argument("--V", type=int, default=64) + parser.add_argument("--gk", action="store_true") + parser.add_argument("--dht", action="store_true") + parser.add_argument("--h0", action="store_true") + parser.add_argument("--use-exp2", action="store_true") + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--iters", type=int, default=100) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] != 9: + raise RuntimeError("This benchmark requires an SM90/Hopper GPU.") + if args.T % 64 != 0: + raise ValueError("Use T as a multiple of 64 for this prototype benchmark.") + + q, k, w, do, dv, gk, dht, h0 = make_inputs(args) + scale = args.K**-0.5 + + def run_fla(): + return fla_bwd_dhu( + q=q, + k=k, + w=w, + do=do, + dv=dv, + gk=gk, + h0=h0, + dht=dht, + scale=scale, + chunk_size=64, + use_exp2=args.use_exp2, + ) + + def run_cute(): + return chunk_gated_delta_rule_bwd_dhu_sm90( + 
q=q, + k=k, + w=w, + do=do, + dv=dv, + gk=gk, + h0=h0, + dht=dht, + scale=scale, + chunk_size=64, + use_exp2=args.use_exp2, + ) + + ref = run_fla() + got = run_cute() + torch.cuda.synchronize() + max_dh = (ref[0].float() - got[0].float()).abs().max().item() + max_dv = (ref[2].float() - got[2].float()).abs().max().item() + + fla_ms = time_kernel(run_fla, args.warmup, args.iters) + cute_ms = time_kernel(run_cute, args.warmup, args.iters) + + print( + f"bwd_dhu SM90 B={args.B} T={args.T} H={args.H} K={args.K} V={args.V} " + f"gk={args.gk} dht={args.dht} h0={args.h0} exp2={args.use_exp2}" + ) + print(f"max_diff dh={max_dh:.6f} dv2={max_dv:.6f}") + print(f"FLA Triton: {fla_ms:.4f} ms") + print(f"CuTe DSL : {cute_ms:.4f} ms") + print(f"speedup : {fla_ms / cute_ms:.3f}x") + + +if __name__ == "__main__": + main() diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py new file mode 100644 index 0000000..533a0a9 --- /dev/null +++ b/cula/ops/chunk_delta_h_bwd.py @@ -0,0 +1,404 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SM90 CuTe DSL prototype for chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64. 
+ +This is intentionally a small, non-persistent kernel for the first Hopper path: + - fixed chunk size BT=64 + - K in {64, 128, 256} + - non-varlen tensors [B, T, H, D] + - non-transposed state layout [B, NT, H, K, V] + - optional gk final-state decay + +It mirrors the Triton bwd_dhu recurrence in FLA's common/chunk_delta_h.py. +The implementation favors clarity and testability over throughput; later +iterations can replace the shared-memory matrix products with WGMMA/TMA tiles. +""" + +from __future__ import annotations + +import functools +import math + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass.cute.runtime import from_dlpack + +from cula.utils import USE_FAST_MATH, assert_hopper + +BT = 64 +BV = 32 +NUM_THREADS = 256 + + +class ChunkDeltaRuleBwdDHUSm90: + def __init__( + self, + batch_size: int, + seq_len: int, + num_heads: int, + head_dim_k: int, + head_dim_v: int, + use_gk: bool, + use_dht: bool, + use_dh0: bool, + use_exp2: bool, + scale: float, + use_fast_math: bool = True, + ): + assert head_dim_k in (64, 128, 256), f"prototype only supports K in {{64, 128, 256}}, got K={head_dim_k}" + self.B = batch_size + self.T = seq_len + self.H = num_heads + self.K = head_dim_k + self.V = head_dim_v + self.use_gk = use_gk + self.use_dht = use_dht + self.use_dh0 = use_dh0 + self.use_exp2 = use_exp2 + self.scale = scale + self.use_fast_math = use_fast_math + self.BT = BT + self.BK = head_dim_k + self.BV = BV + self.num_threads = NUM_THREADS + + @cute.jit + def __call__( + self, + q_in: cute.Tensor, + k_in: cute.Tensor, + w_in: cute.Tensor, + gk_in: cute.Tensor, + dht_in: cute.Tensor, + dh0_in: cute.Tensor, + do_in: cute.Tensor, + dh_in: cute.Tensor, + dv_in: cute.Tensor, + dv2_in: cute.Tensor, + stream: cuda.CUstream, + ): + q_ptr = q_in.iterator + k_ptr = k_in.iterator + w_ptr = w_in.iterator + gk_ptr = gk_in.iterator + dht_ptr = dht_in.iterator + dh0_ptr = dh0_in.iterator + do_ptr = do_in.iterator + 
dh_ptr = dh_in.iterator + dv_ptr = dv_in.iterator + dv2_ptr = dv2_in.iterator + + NT = (self.T + self.BT - 1) // self.BT + + q_layout = cute.make_layout( + (self.B, self.T, self.H, self.BK), + stride=(self.T * self.H * self.BK, self.H * self.BK, self.BK, 1), + ) + q = cute.make_tensor(q_ptr, q_layout) + k = cute.make_tensor(k_ptr, q_layout) + w = cute.make_tensor(w_ptr, q_layout) + + v_layout = cute.make_layout( + (self.B, self.T, self.H, self.V), + stride=(self.T * self.H * self.V, self.H * self.V, self.V, 1), + ) + do = cute.make_tensor(do_ptr, v_layout) + dv = cute.make_tensor(dv_ptr, v_layout) + dv2 = cute.make_tensor(dv2_ptr, v_layout) + + gk_layout = cute.make_layout( + (self.B, self.T, self.H, self.BK), + stride=(self.T * self.H * self.BK, self.H * self.BK, self.BK, 1), + ) + gk = cute.make_tensor(gk_ptr, gk_layout) + + state_layout = cute.make_layout( + (self.B, NT, self.H, self.BK, self.V), + stride=( + NT * self.H * self.BK * self.V, + self.H * self.BK * self.V, + self.BK * self.V, + self.V, + 1, + ), + ) + dh = cute.make_tensor(dh_ptr, state_layout) + + final_layout = cute.make_layout( + (self.B, self.H, self.BK, self.V), + stride=(self.H * self.BK * self.V, self.BK * self.V, self.V, 1), + ) + dht = cute.make_tensor(dht_ptr, final_layout) + dh0 = cute.make_tensor(dh0_ptr, final_layout) + + self.kernel(q, k, w, gk, dht, dh0, do, dh, dv, dv2).launch( + grid=[cute.ceil_div(self.V, self.BV), self.B * self.H, 1], + block=[self.num_threads, 1, 1], + smem=(self.BK * self.BV + self.BT * self.BV) * 4 + 512, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + q: cute.Tensor, + k: cute.Tensor, + w: cute.Tensor, + gk: cute.Tensor, + dht: cute.Tensor, + dh0: cute.Tensor, + do: cute.Tensor, + dh: cute.Tensor, + dv: cute.Tensor, + dv2: cute.Tensor, + ): + tidx, _, _ = cute.arch.thread_idx() + i_v_tile, i_bh, _ = cute.arch.block_idx() + i_b = i_bh // self.H + i_h = i_bh - i_b * self.H + v_base = i_v_tile * self.BV + NT = (self.T + self.BT - 1) // self.BT + + 
smem = cutlass.utils.SmemAllocator() + s_dh = smem.allocate_tensor( + cutlass.Float32, + cute.make_layout((self.BK, self.BV), stride=(self.BV, 1)), + 16, + ) + s_dv = smem.allocate_tensor( + cutlass.Float32, + cute.make_layout((self.BT, self.BV), stride=(self.BV, 1)), + 16, + ) + + linear = tidx + while linear < self.BK * self.BV: + k_idx = linear // self.BV + v_rel = linear - k_idx * self.BV + v_idx = v_base + v_rel + init = cutlass.Float32(0.0) + if cutlass.const_expr(self.use_dht): + if v_idx < self.V: + init = cutlass.Float32(dht[i_b, i_h, k_idx, v_idx]) + s_dh[k_idx, v_rel] = init + linear += self.num_threads + cute.arch.barrier() + + for chunk_rev in cutlass.range_constexpr(NT): + i_t = NT - 1 - chunk_rev + chunk_start = i_t * self.BT + chunk_end = cutlass.min(chunk_start + self.BT, self.T) + last_idx = chunk_end - 1 + + linear = tidx + while linear < self.BK * self.BV: + k_idx = linear // self.BV + v_rel = linear - k_idx * self.BV + v_idx = v_base + v_rel + if v_idx < self.V: + dh[i_b, i_t, i_h, k_idx, v_idx] = s_dh[k_idx, v_rel].to(dh.element_type) + linear += self.num_threads + cute.arch.barrier() + + linear = tidx + while linear < self.BT * self.BV: + t_rel = linear // self.BV + v_rel = linear - t_rel * self.BV + t_idx = chunk_start + t_rel + v_idx = v_base + v_rel + acc = cutlass.Float32(0.0) + if t_idx < self.T and v_idx < self.V: + acc = cutlass.Float32(dv[i_b, t_idx, i_h, v_idx]) + for k_idx in cutlass.range(self.BK, unroll_full=True): + acc += cutlass.Float32(k[i_b, t_idx, i_h, k_idx]) * s_dh[k_idx, v_rel] + dv2[i_b, t_idx, i_h, v_idx] = acc.to(dv2.element_type) + s_dv[t_rel, v_rel] = acc + linear += self.num_threads + cute.arch.barrier() + + linear = tidx + while linear < self.BK * self.BV: + k_idx = linear // self.BV + v_rel = linear - k_idx * self.BV + v_idx = v_base + v_rel + acc = s_dh[k_idx, v_rel] + if v_idx < self.V: + if cutlass.const_expr(self.use_gk): + gk_last = cutlass.Float32(gk[i_b, last_idx, i_h, k_idx]) + if 
cutlass.const_expr(self.use_exp2): + acc *= cute.exp2(gk_last, fastmath=self.use_fast_math) + else: + acc *= cute.exp(gk_last, fastmath=self.use_fast_math) + for t_rel in cutlass.range(self.BT, unroll_full=True): + t_idx = chunk_start + t_rel + if t_idx < self.T: + q_term = cutlass.Float32(q[i_b, t_idx, i_h, k_idx]) + do_term = cutlass.Float32(do[i_b, t_idx, i_h, v_idx]) + w_term = cutlass.Float32(w[i_b, t_idx, i_h, k_idx]) + acc += q_term * do_term * self.scale - w_term * s_dv[t_rel, v_rel] + s_dh[k_idx, v_rel] = acc + linear += self.num_threads + cute.arch.barrier() + + if cutlass.const_expr(self.use_dh0): + linear = tidx + while linear < self.BK * self.BV: + k_idx = linear // self.BV + v_rel = linear - k_idx * self.BV + v_idx = v_base + v_rel + if v_idx < self.V: + dh0[i_b, i_h, k_idx, v_idx] = s_dh[k_idx, v_rel] + linear += self.num_threads + + +def _as_cute(tensor: torch.Tensor): + return from_dlpack(tensor, assumed_align=16) + + +@functools.lru_cache(maxsize=64) +def _compile_bwd_dhu_sm90( + B: int, + T: int, + H: int, + K: int, + V: int, + use_gk: bool, + use_dht: bool, + use_dh0: bool, + use_exp2: bool, + scale: float, +): + kernel = ChunkDeltaRuleBwdDHUSm90( + batch_size=B, + seq_len=T, + num_heads=H, + head_dim_k=K, + head_dim_v=V, + use_gk=use_gk, + use_dht=use_dht, + use_dh0=use_dh0, + use_exp2=use_exp2, + scale=scale, + use_fast_math=USE_FAST_MATH, + ) + + q_fake = torch.empty(B, T, H, K, device="cuda", dtype=torch.bfloat16) + k_fake = torch.empty_like(q_fake) + w_fake = torch.empty_like(q_fake) + do_fake = torch.empty(B, T, H, V, device="cuda", dtype=torch.bfloat16) + dv_fake = torch.empty_like(do_fake) + dv2_fake = torch.empty_like(do_fake) + gk_fake = torch.empty(B, T, H, K, device="cuda", dtype=torch.float32) + dht_fake = torch.empty(B, H, K, V, device="cuda", dtype=torch.float32) + dh0_fake = torch.empty_like(dht_fake) + dh_fake = torch.empty(B, math.ceil(T / BT), H, K, V, device="cuda", dtype=torch.bfloat16) + stream = 
cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + return cute.compile( + kernel, + _as_cute(q_fake), + _as_cute(k_fake), + _as_cute(w_fake), + _as_cute(gk_fake), + _as_cute(dht_fake), + _as_cute(dh0_fake), + _as_cute(do_fake), + _as_cute(dh_fake), + _as_cute(dv_fake), + _as_cute(dv2_fake), + stream=stream, + options="--enable-tvm-ffi", + ) + + +def chunk_gated_delta_rule_bwd_dhu_sm90( + q: torch.Tensor, + k: torch.Tensor, + w: torch.Tensor, + do: torch.Tensor, + dv: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + h0: torch.Tensor | None = None, + dht: torch.Tensor | None = None, + scale: float | None = None, + cu_seqlens: torch.Tensor | None = None, + chunk_size: int = BT, + chunk_indices: torch.Tensor | None = None, + use_exp2: bool = False, + transpose_state_layout: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: + """FLA-compatible wrapper for the current SM90 bwd_dhu prototype.""" + del chunk_indices + assert_hopper(q.device) + if cu_seqlens is not None: + raise NotImplementedError("SM90 bwd_dhu prototype only supports non-varlen tensors.") + if transpose_state_layout: + raise NotImplementedError("SM90 bwd_dhu prototype only supports [B, NT, H, K, V] state layout.") + if g is not None: + raise NotImplementedError("SM90 bwd_dhu prototype supports gk gating, not scalar g gating yet.") + if chunk_size != BT: + raise NotImplementedError(f"SM90 bwd_dhu prototype only supports chunk_size={BT}.") + + B, T, H, K = q.shape + V = do.shape[-1] + if K not in (64, 128, 256): + raise NotImplementedError(f"SM90 bwd_dhu prototype only supports K in {{64, 128, 256}}, got K={K}.") + if q.dtype != torch.bfloat16 or k.dtype != torch.bfloat16 or w.dtype != torch.bfloat16: + raise TypeError("q, k, and w must be bfloat16 for the SM90 bwd_dhu prototype.") + if do.dtype != torch.bfloat16 or dv.dtype != torch.bfloat16: + raise TypeError("do and dv must be bfloat16 for the SM90 bwd_dhu prototype.") + if not 
q.is_contiguous() or not k.is_contiguous() or not w.is_contiguous(): + raise ValueError("q, k, and w must be contiguous.") + if not do.is_contiguous() or not dv.is_contiguous(): + raise ValueError("do and dv must be contiguous.") + + NT = math.ceil(T / BT) + scale_value = 1.0 if scale is None else float(scale) + + dh = q.new_empty(B, NT, H, K, V) + dh0 = torch.empty(B, H, K, V, device=q.device, dtype=torch.float32) if h0 is not None else None + dv2 = torch.empty_like(dv) + + gk_arg = gk if gk is not None else torch.empty(B, T, H, K, device=q.device, dtype=torch.float32) + dht_arg = dht if dht is not None else torch.empty(B, H, K, V, device=q.device, dtype=torch.float32) + dh0_arg = dh0 if dh0 is not None else torch.empty(B, H, K, V, device=q.device, dtype=torch.float32) + if gk is not None and (gk.dtype != torch.float32 or not gk.is_contiguous()): + raise ValueError("gk must be contiguous float32.") + if dht is not None and (dht.dtype != torch.float32 or not dht.is_contiguous()): + raise ValueError("dht must be contiguous float32.") + + compiled = _compile_bwd_dhu_sm90( + B, + T, + H, + K, + V, + gk is not None, + dht is not None, + h0 is not None, + use_exp2, + scale_value, + ) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + compiled(q, k, w, gk_arg, dht_arg, dh0_arg, do, dh, dv, dv2, stream) + return dh, dh0, dv2 + + +# Shorter alias for users who import this module directly. +chunk_gated_delta_rule_bwd_dhu = chunk_gated_delta_rule_bwd_dhu_sm90 diff --git a/tests/test_chunk_delta_h_bwd_sm90.py b/tests/test_chunk_delta_h_bwd_sm90.py new file mode 100644 index 0000000..6e0a0ba --- /dev/null +++ b/tests/test_chunk_delta_h_bwd_sm90.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Correctness tests for the SM90 CuTe DSL bwd_dhu prototype.""" + +import os +import sys + +import pytest +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu as fla_bwd_dhu + +from cula.ops.chunk_delta_h_bwd import chunk_gated_delta_rule_bwd_dhu_sm90 + + +def _is_sm90() -> bool: + return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 9 + + +pytestmark = [ + pytest.mark.sm90_only, + pytest.mark.skipif(not _is_sm90(), reason="SM90/Hopper GPU is required"), +] + + +def _make_inputs(B, T, H, K, V, use_gk=False, use_dht=False, use_h0=False, seed=42): + torch.manual_seed(seed) + q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 + k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 + w = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 + do = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 + dv = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 + + gk = None + if use_gk: + gk = -torch.abs(torch.randn(B, T, H, K, dtype=torch.float32, device="cuda") * 0.01).cumsum(dim=1) + + dht = None + if use_dht: + dht = torch.randn(B, H, K, V, dtype=torch.float32, device="cuda") * 0.01 + + h0 = None + if use_h0: + h0 = torch.empty(B, H, K, V, dtype=torch.float32, device="cuda") + + return q, k, w, do, dv, gk, dht, h0 + + +def _run_case(B, T, H, K, V, use_gk=False, use_dht=False, use_h0=False, use_exp2=False): + q, k, w, do, dv, gk, dht, h0 = _make_inputs(B, T, H, K, V, use_gk, use_dht, use_h0) + scale = K**-0.5 + + ref_dh, ref_dh0, ref_dv2 = fla_bwd_dhu( + q=q, + k=k, + w=w, + do=do, + dv=dv, + gk=gk, + h0=h0, + dht=dht, + scale=scale, + chunk_size=64, + use_exp2=use_exp2, + ) + + got_dh, got_dh0, got_dv2 = chunk_gated_delta_rule_bwd_dhu_sm90( + q=q, + k=k, + w=w, + do=do, + dv=dv, + gk=gk, + h0=h0, + dht=dht, + 
scale=scale, + chunk_size=64, + use_exp2=use_exp2, + ) + + torch.testing.assert_close(got_dh.float(), ref_dh.float(), atol=3e-2, rtol=3e-2) + torch.testing.assert_close(got_dv2.float(), ref_dv2.float(), atol=3e-2, rtol=3e-2) + if use_h0: + assert got_dh0 is not None + torch.testing.assert_close(got_dh0, ref_dh0, atol=3e-2, rtol=3e-2) + else: + assert got_dh0 is None + + +@pytest.mark.parametrize("T", [64, 128]) +@pytest.mark.parametrize("V", [32, 64]) +def test_bwd_dhu_no_gating(T, V): + _run_case(B=1, T=T, H=1, K=64, V=V) + + +def test_bwd_dhu_with_gk_exp2_and_dht(): + _run_case(B=1, T=128, H=2, K=64, V=64, use_gk=True, use_dht=True, use_exp2=True) + + +def test_bwd_dhu_k128_with_gk_exp2_and_dht(): + _run_case(B=1, T=128, H=1, K=128, V=64, use_gk=True, use_dht=True, use_exp2=True) + + +def test_bwd_dhu_returns_dh0(): + _run_case(B=2, T=128, H=1, K=64, V=64, use_h0=True) From 80c1d78ba9e8561d4fa2564f13a5abb2c6c91b4d Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Tue, 12 May 2026 19:42:41 +0800 Subject: [PATCH 02/28] format --- benchmarks/bench_chunk_delta_h_bwd_sm90.py | 8 +- cula/ops/chunk_delta_h_bwd.py | 560 +++++++++++++++++---- tests/test_chunk_delta_h_bwd_sm90.py | 4 +- 3 files changed, 466 insertions(+), 106 deletions(-) diff --git a/benchmarks/bench_chunk_delta_h_bwd_sm90.py b/benchmarks/bench_chunk_delta_h_bwd_sm90.py index 29eda45..ddf8799 100644 --- a/benchmarks/bench_chunk_delta_h_bwd_sm90.py +++ b/benchmarks/bench_chunk_delta_h_bwd_sm90.py @@ -3,11 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 """ -Benchmark the SM90 CuTe DSL bwd_dhu prototype against FLA Triton. +Benchmark the SM90 CuTe DSL WGMMA bwd_dhu path against FLA Triton. 
Current kernel scope: - non-varlen - - K in {64, 128, 256}, BT=64 + - K in {64, 128, 256}, BT=64, BV=64 - state layout [B, NT, H, K, V] - optional gk/dht/h0 @@ -78,7 +78,9 @@ def main(): if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] != 9: raise RuntimeError("This benchmark requires an SM90/Hopper GPU.") if args.T % 64 != 0: - raise ValueError("Use T as a multiple of 64 for this prototype benchmark.") + raise ValueError("Use T as a multiple of 64 for this benchmark.") + if args.V % 64 != 0: + raise ValueError("Use V as a multiple of 64 for the SM90 WGMMA path.") q, k, w, do, dv, gk, dht, h0 = make_inputs(args) scale = args.K**-0.5 diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 533a0a9..55cb7b5 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -13,18 +13,22 @@ # limitations under the License. """ -SM90 CuTe DSL prototype for chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64. +SM90 CuTe DSL implementation for chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64. -This is intentionally a small, non-persistent kernel for the first Hopper path: +This is the first Hopper tensor-core path: - fixed chunk size BT=64 - - K in {64, 128, 256} + - BV=64, matching cula/ops/chunk_delta_h.py - non-varlen tensors [B, T, H, D] - non-transposed state layout [B, NT, H, K, V] - - optional gk final-state decay + - non-persistent scheduling -It mirrors the Triton bwd_dhu recurrence in FLA's common/chunk_delta_h.py. -The implementation favors clarity and testability over throughput; later -iterations can replace the shared-memory matrix products with WGMMA/TMA tiles. +The recurrence is the Triton bwd_dhu recurrence: + dv2 = dv + K @ dh + dh = decay(dh) + scale * Q^T @ do - W^T @ dv2 + +Each CTA owns one BV tile and one (batch, head). WGMMA computes the three +64x64 GEMMs per chunk; scalar CUDA code only stages operands and applies the +elementwise recurrence. 
""" from __future__ import annotations @@ -35,14 +39,23 @@ import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.hopper_helpers as sm90_utils import torch +from cutlass.cute.nvgpu import cpasync from cutlass.cute.runtime import from_dlpack +from cutlass.cute.typing import Float32, Int64 from cula.utils import USE_FAST_MATH, assert_hopper BT = 64 -BV = 32 -NUM_THREADS = 256 +BV = 64 +NUM_THREADS = 128 + + +def make_thread_cooperative_group(size: int): + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) class ChunkDeltaRuleBwdDHUSm90: @@ -60,7 +73,8 @@ def __init__( scale: float, use_fast_math: bool = True, ): - assert head_dim_k in (64, 128, 256), f"prototype only supports K in {{64, 128, 256}}, got K={head_dim_k}" + assert head_dim_k in (64, 128, 256), f"SM90 bwd_dhu supports K in {{64, 128, 256}}, got {head_dim_k}" + assert head_dim_v % BV == 0, f"SM90 bwd_dhu tensor-core path requires V to be a multiple of {BV}, got {head_dim_v}" self.B = batch_size self.T = seq_len self.H = num_heads @@ -72,10 +86,19 @@ def __init__( self.use_exp2 = use_exp2 self.scale = scale self.use_fast_math = use_fast_math + self.BT = BT - self.BK = head_dim_k self.BV = BV + self.num_v_tiles = (head_dim_v + BV - 1) // BV self.num_threads = NUM_THREADS + self.io_dtype = cutlass.BFloat16 + self.acc_dtype = cutlass.Float32 + self.buffer_align_bytes = 1024 + + self.mma_tiler = (BT, BV, head_dim_k) + self.update_mma_tiler = (BV, head_dim_k, BT) + self.atom_layout_mnk = (1, 1, 1) + self.cluster_shape_mnk = (1, 1, 1) @cute.jit def __call__( @@ -106,8 +129,8 @@ def __call__( NT = (self.T + self.BT - 1) // self.BT q_layout = cute.make_layout( - (self.B, self.T, self.H, self.BK), - stride=(self.T * self.H * self.BK, self.H * self.BK, self.BK, 1), + (self.B, self.T, self.H, self.K), + stride=(self.T * self.H * self.K, self.H * self.K, self.K, 1), ) q = cute.make_tensor(q_ptr, 
q_layout) k = cute.make_tensor(k_ptr, q_layout) @@ -122,17 +145,17 @@ def __call__( dv2 = cute.make_tensor(dv2_ptr, v_layout) gk_layout = cute.make_layout( - (self.B, self.T, self.H, self.BK), - stride=(self.T * self.H * self.BK, self.H * self.BK, self.BK, 1), + (self.B, self.T, self.H, self.K), + stride=(self.T * self.H * self.K, self.H * self.K, self.K, 1), ) gk = cute.make_tensor(gk_ptr, gk_layout) state_layout = cute.make_layout( - (self.B, NT, self.H, self.BK, self.V), + (self.B, NT, self.H, self.K, self.V), stride=( - NT * self.H * self.BK * self.V, - self.H * self.BK * self.V, - self.BK * self.V, + NT * self.H * self.K * self.V, + self.H * self.K * self.V, + self.K * self.V, self.V, 1, ), @@ -140,17 +163,186 @@ def __call__( dh = cute.make_tensor(dh_ptr, state_layout) final_layout = cute.make_layout( - (self.B, self.H, self.BK, self.V), - stride=(self.H * self.BK * self.V, self.BK * self.V, self.V, 1), + (self.B, self.H, self.K, self.V), + stride=(self.H * self.K * self.V, self.K * self.V, self.V, 1), ) dht = cute.make_tensor(dht_ptr, final_layout) dh0 = cute.make_tensor(dh0_ptr, final_layout) - self.kernel(q, k, w, gk, dht, dh0, do, dh, dv, dv2).launch( + tk_layout = cute.make_layout( + (self.T, self.K, (self.H, self.B)), stride=(self.H * self.K, 1, (self.K, self.T * self.H * self.K)) + ) + k_tk = cute.make_tensor(k_ptr, tk_layout) + + kt_layout = cute.make_layout( + (self.K, self.T, (self.H, self.B)), stride=(1, self.H * self.K, (self.K, self.T * self.H * self.K)) + ) + q_kt = cute.make_tensor(q_ptr, kt_layout) + w_kt = cute.make_tensor(w_ptr, kt_layout) + + vt_layout = cute.make_layout( + (self.V, self.T, (self.H, self.B)), stride=(1, self.H * self.V, (self.V, self.T * self.H * self.V)) + ) + do_vt = cute.make_tensor(do_ptr, vt_layout) + dv_vt = cute.make_tensor(dv_ptr, vt_layout) + dv2_vt = cute.make_tensor(dv2_ptr, vt_layout) + + tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.io_dtype, + self.io_dtype, + 
utils.LayoutEnum.ROW_MAJOR.sm90_mma_major_mode(), + utils.LayoutEnum.ROW_MAJOR.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + self.mma_tiler[:2], + ) + + update_tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.io_dtype, + self.io_dtype, + utils.LayoutEnum.COL_MAJOR.sm90_mma_major_mode(), + utils.LayoutEnum.COL_MAJOR.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + self.update_mma_tiler[:2], + ) + + a_smem_layout_staged = sm90_utils.make_smem_layout_a( + utils.LayoutEnum.ROW_MAJOR, + self.mma_tiler, + self.io_dtype, + 1, + ) + b_smem_layout_staged = sm90_utils.make_smem_layout_b( + utils.LayoutEnum.ROW_MAJOR, + self.mma_tiler, + self.io_dtype, + 1, + ) + update_a_smem_layout_staged = sm90_utils.make_smem_layout_a( + utils.LayoutEnum.COL_MAJOR, + self.update_mma_tiler, + self.io_dtype, + 1, + ) + update_b_smem_layout_staged = sm90_utils.make_smem_layout_b( + utils.LayoutEnum.COL_MAJOR, + self.update_mma_tiler, + self.io_dtype, + 1, + ) + dv_smem_layout_staged = cute.make_layout((self.BV, self.BT, 1), stride=(1, self.BV, self.BV * self.BT)) + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp() + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + + tma_atom_k, tma_tensor_k = cpasync.make_tiled_tma_atom( + tma_load_op, + k_tk, + cute.slice_(a_smem_layout_staged, (None, None, 0)), + (self.BT, self.K), + ) + tma_atom_dv, tma_tensor_dv = cpasync.make_tiled_tma_atom( + tma_load_op, + dv_vt, + cute.slice_(dv_smem_layout_staged, (None, None, 0)), + (self.BV, self.BT), + ) + tma_atom_do, tma_tensor_do = cpasync.make_tiled_tma_atom( + tma_load_op, + do_vt, + cute.slice_(update_a_smem_layout_staged, (None, None, 0)), + (self.BV, self.BT), + ) + tma_atom_q, tma_tensor_q = cpasync.make_tiled_tma_atom( + tma_load_op, + q_kt, + cute.slice_(update_b_smem_layout_staged, (None, None, 0)), + (self.K, self.BT), + ) + tma_atom_w, tma_tensor_w = cpasync.make_tiled_tma_atom( + tma_load_op, + w_kt, + cute.slice_(update_b_smem_layout_staged, (None, 
None, 0)), + (self.K, self.BT), + ) + tma_atom_dv2, tma_tensor_dv2 = cpasync.make_tiled_tma_atom( + tma_store_op, + dv2_vt, + cute.slice_(dv_smem_layout_staged, (None, None, 0)), + (self.BV, self.BT), + ) + self.tma_kdv_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(a_smem_layout_staged, (None, None, 0)) + ) + cute.size_in_bytes(self.io_dtype, cute.slice_(dv_smem_layout_staged, (None, None, 0))) + self.tma_qdo_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(update_a_smem_layout_staged, (None, None, 0)) + ) + cute.size_in_bytes(self.io_dtype, cute.slice_(update_b_smem_layout_staged, (None, None, 0))) + self.tma_w_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_b_smem_layout_staged, (None, None, 0))) + + @cute.struct + class SharedStorage: + load_kdv_mbar: cute.struct.MemRange[Int64, 2] + load_qdo_mbar: cute.struct.MemRange[Int64, 2] + load_w_mbar: cute.struct.MemRange[Int64, 2] + sA: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(a_smem_layout_staged)], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(b_smem_layout_staged)], + self.buffer_align_bytes, + ] + sUA: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(update_a_smem_layout_staged)], + self.buffer_align_bytes, + ] + sUB: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(update_b_smem_layout_staged)], + self.buffer_align_bytes, + ] + sDv2T: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(dv_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + self.kernel( + q, + k, + w, + gk, + dht, + dh0, + do, + dh, + dv, + dv2, + tiled_mma, + update_tiled_mma, + a_smem_layout_staged, + b_smem_layout_staged, + update_a_smem_layout_staged, + update_b_smem_layout_staged, + tma_atom_k, + tma_tensor_k, + tma_atom_dv, + tma_tensor_dv, + tma_atom_do, + tma_tensor_do, + tma_atom_q, + tma_tensor_q, + tma_atom_w, + 
tma_tensor_w, + tma_atom_dv2, + tma_tensor_dv2, + ).launch( grid=[cute.ceil_div(self.V, self.BV), self.B * self.H, 1], block=[self.num_threads, 1, 1], - smem=(self.BK * self.BV + self.BT * self.BV) * 4 + 512, + cluster=self.cluster_shape_mnk, stream=stream, + min_blocks_per_mp=1, ) @cute.kernel @@ -166,38 +358,115 @@ def kernel( dh: cute.Tensor, dv: cute.Tensor, dv2: cute.Tensor, + tiled_mma: cute.TiledMma, + update_tiled_mma: cute.TiledMma, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + update_a_smem_layout_staged: cute.ComposedLayout, + update_b_smem_layout_staged: cute.ComposedLayout, + tma_atom_k: cute.CopyAtom, + tma_tensor_k: cute.Tensor, + tma_atom_dv: cute.CopyAtom, + tma_tensor_dv: cute.Tensor, + tma_atom_do: cute.CopyAtom, + tma_tensor_do: cute.Tensor, + tma_atom_q: cute.CopyAtom, + tma_tensor_q: cute.Tensor, + tma_atom_w: cute.CopyAtom, + tma_tensor_w: cute.Tensor, + tma_atom_dv2: cute.CopyAtom, + tma_tensor_dv2: cute.Tensor, ): tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) i_v_tile, i_bh, _ = cute.arch.block_idx() i_b = i_bh // self.H i_h = i_bh - i_b * self.H v_base = i_v_tile * self.BV NT = (self.T + self.BT - 1) // self.BT - smem = cutlass.utils.SmemAllocator() - s_dh = smem.allocate_tensor( - cutlass.Float32, - cute.make_layout((self.BK, self.BV), stride=(self.BV, 1)), - 16, + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + sUA = storage.sUA.get_tensor(update_a_smem_layout_staged.outer, swizzle=update_a_smem_layout_staged.inner) + sUB = storage.sUB.get_tensor(update_b_smem_layout_staged.outer, swizzle=update_b_smem_layout_staged.inner) + sDv2T = storage.sDv2T.get_tensor(cute.make_layout((BV, BT, 1), stride=(1, BV, BV * BT))) + + if 
warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_k) + cpasync.prefetch_descriptor(tma_atom_dv) + cpasync.prefetch_descriptor(tma_atom_do) + cpasync.prefetch_descriptor(tma_atom_q) + cpasync.prefetch_descriptor(tma_atom_w) + cpasync.prefetch_descriptor(tma_atom_dv2) + + load_kdv_P, load_kdv_C = pipeline.PipelineTmaAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_threads // 32), + tx_count=self.tma_kdv_bytes, + barrier_storage=storage.load_kdv_mbar.data_ptr(), + ).make_participants() + load_qdo_P, load_qdo_C = pipeline.PipelineTmaAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_threads // 32), + tx_count=self.tma_qdo_bytes, + barrier_storage=storage.load_qdo_mbar.data_ptr(), + ).make_participants() + load_w_P, load_w_C = pipeline.PipelineTmaAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_threads // 32), + tx_count=self.tma_w_bytes, + barrier_storage=storage.load_w_mbar.data_ptr(), + ).make_participants() + + thr_mma = tiled_mma.get_slice(tidx) + update_thr_mma = update_tiled_mma.get_slice(tidx) + + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCrA = thr_mma.make_fragment_A(tCsA) + tCrB = thr_mma.make_fragment_B(tCsB) + tUsA = update_thr_mma.partition_A(sUA) + tUsB = update_thr_mma.partition_B(sUB) + tUrA = update_thr_mma.make_fragment_A(tUsA) + tUrB = update_thr_mma.make_fragment_B(tUsB) + + cDV = cute.make_identity_tensor((BT, BV)) + tCcDV = thr_mma.partition_C(cDV) + acc_dv = thr_mma.make_fragment_C(thr_mma.partition_shape_C((BT, BV))) + + cState = cute.make_identity_tensor((BV, self.K)) + tUcState = update_thr_mma.partition_C(cState) + rState = update_thr_mma.make_fragment_C(update_thr_mma.partition_shape_C((BV, self.K))) + acc_qdo = 
update_thr_mma.make_fragment_C(update_thr_mma.partition_shape_C((BV, self.K))) + acc_wdv = update_thr_mma.make_fragment_C(update_thr_mma.partition_shape_C((BV, self.K))) + + _, bSG_sK, bSG_gK = self._epilog_partition(tma_atom_k, tma_tensor_k[None, None, (i_h, i_b)], (self.BT, self.K), sA) + _, bSG_sDv, bSG_gDv = self._epilog_partition( + tma_atom_dv, tma_tensor_dv[None, None, (i_h, i_b)], (self.BV, self.BT), sDv2T + ) + _, bSG_sDo, bSG_gDo = self._epilog_partition( + tma_atom_do, tma_tensor_do[None, None, (i_h, i_b)], (self.BV, self.BT), sUA ) - s_dv = smem.allocate_tensor( - cutlass.Float32, - cute.make_layout((self.BT, self.BV), stride=(self.BV, 1)), - 16, + _, bSG_sQ, bSG_gQ = self._epilog_partition(tma_atom_q, tma_tensor_q[None, None, (i_h, i_b)], (self.K, self.BT), sUB) + _, bSG_sW, bSG_gW = self._epilog_partition(tma_atom_w, tma_tensor_w[None, None, (i_h, i_b)], (self.K, self.BT), sUB) + _, bSG_sDv2, bSG_gDv2 = self._epilog_partition( + tma_atom_dv2, tma_tensor_dv2[None, None, (i_h, i_b)], (self.BV, self.BT), sDv2T ) - linear = tidx - while linear < self.BK * self.BV: - k_idx = linear // self.BV - v_rel = linear - k_idx * self.BV + # Initialize carried dh state. + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_idx = tUcState[ei] v_idx = v_base + v_rel - init = cutlass.Float32(0.0) + init = Float32(0.0) if cutlass.const_expr(self.use_dht): - if v_idx < self.V: - init = cutlass.Float32(dht[i_b, i_h, k_idx, v_idx]) - s_dh[k_idx, v_rel] = init - linear += self.num_threads - cute.arch.barrier() + init = dht[i_b, i_h, k_idx, v_idx].to(self.acc_dtype) + rState[ei] = init for chunk_rev in cutlass.range_constexpr(NT): i_t = NT - 1 - chunk_rev @@ -205,65 +474,153 @@ def kernel( chunk_end = cutlass.min(chunk_start + self.BT, self.T) last_idx = chunk_end - 1 - linear = tidx - while linear < self.BK * self.BV: - k_idx = linear // self.BV - v_rel = linear - k_idx * self.BV + # Store dh before applying this chunk's reverse update. 
+ for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_idx = tUcState[ei] v_idx = v_base + v_rel - if v_idx < self.V: - dh[i_b, i_t, i_h, k_idx, v_idx] = s_dh[k_idx, v_rel].to(dh.element_type) - linear += self.num_threads - cute.arch.barrier() - - linear = tidx - while linear < self.BT * self.BV: - t_rel = linear // self.BV - v_rel = linear - t_rel * self.BV + dh[i_b, i_t, i_h, k_idx, v_idx] = rState[ei].to(dh.element_type) + cute.arch.sync_threads() + + # dv2 = dv + K @ dh. + acc_dv.fill(0.0) + if warp_idx == 0: + kdv_h = load_kdv_P.acquire_and_advance() + cute.copy(tma_atom_k, bSG_gK[(None, i_t, 0)], bSG_sK[None, kdv_h.index], tma_bar_ptr=kdv_h.barrier) + cute.copy( + tma_atom_dv, + bSG_gDv[(None, i_v_tile, i_t)], + bSG_sDv[None, kdv_h.index], + tma_bar_ptr=kdv_h.barrier, + ) + + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_idx = tUcState[ei] + sB[v_rel, k_idx, 0] = rState[ei].to(self.io_dtype) + + kdv_wait = load_kdv_C.wait_and_advance() + cute.arch.sync_threads() + + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tCrA, mode=[2]), unroll_full=True): + tiled_mma.set( + cute.nvgpu.warpgroup.Field.ACCUMULATE, + cutlass.Boolean(kp != 0), + ) + cute.gemm( + tiled_mma, + acc_dv, + tCrA[None, None, kp, 0], + tCrB[None, None, kp, 0], + acc_dv, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + kdv_wait.release() + cute.arch.sync_threads() + + for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): + t_rel, v_rel = tCcDV[ei] t_idx = chunk_start + t_rel v_idx = v_base + v_rel - acc = cutlass.Float32(0.0) - if t_idx < self.T and v_idx < self.V: - acc = cutlass.Float32(dv[i_b, t_idx, i_h, v_idx]) - for k_idx in cutlass.range(self.BK, unroll_full=True): - acc += cutlass.Float32(k[i_b, t_idx, i_h, k_idx]) * s_dh[k_idx, v_rel] - dv2[i_b, t_idx, i_h, v_idx] = acc.to(dv2.element_type) - s_dv[t_rel, v_rel] = acc - linear += self.num_threads - cute.arch.barrier() + out = 
Float32(0.0) + if t_idx < self.T: + out = acc_dv[ei] + sDv2T[v_rel, t_rel, 0].to(self.acc_dtype) + sDv2T[v_rel, t_rel, 0] = out.to(self.io_dtype) + cute.arch.fence_proxy("async.shared", space="cta") + cute.arch.sync_threads() + if warp_idx == 0: + cute.copy(tma_atom_dv2, bSG_sDv2[None, 0], bSG_gDv2[(None, i_v_tile, i_t)]) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.sync_threads() + + # Apply gk decay after dv2, before accumulating QO - WV into dh. + if cutlass.const_expr(self.use_gk): + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_idx = tUcState[ei] + gk_last = gk[i_b, last_idx, i_h, k_idx].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + scale = cute.exp2(gk_last, fastmath=self.use_fast_math) + else: + scale = cute.exp(gk_last, fastmath=self.use_fast_math) + rState[ei] = rState[ei] * scale + + # dh += scale * do^T @ q - dv2^T @ w. + if warp_idx == 0: + qdo_h = load_qdo_P.acquire_and_advance() + cute.copy(tma_atom_do, bSG_gDo[(None, i_v_tile, i_t)], bSG_sDo[None, qdo_h.index], tma_bar_ptr=qdo_h.barrier) + cute.copy(tma_atom_q, bSG_gQ[(None, 0, i_t)], bSG_sQ[None, qdo_h.index], tma_bar_ptr=qdo_h.barrier) + qdo_wait = load_qdo_C.wait_and_advance() + cute.arch.sync_threads() + + acc_qdo.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_qdo, + tUrA[None, None, kp, 0], + tUrB[None, None, kp, 0], + acc_qdo, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + qdo_wait.release() + cute.arch.sync_threads() + + if warp_idx == 0: + w_h = load_w_P.acquire_and_advance() + cute.copy(tma_atom_w, bSG_gW[(None, 0, i_t)], bSG_sW[None, w_h.index], tma_bar_ptr=w_h.barrier) linear = tidx - while linear < self.BK * self.BV: - k_idx = linear // self.BV - v_rel = linear - 
k_idx * self.BV - v_idx = v_base + v_rel - acc = s_dh[k_idx, v_rel] - if v_idx < self.V: - if cutlass.const_expr(self.use_gk): - gk_last = cutlass.Float32(gk[i_b, last_idx, i_h, k_idx]) - if cutlass.const_expr(self.use_exp2): - acc *= cute.exp2(gk_last, fastmath=self.use_fast_math) - else: - acc *= cute.exp(gk_last, fastmath=self.use_fast_math) - for t_rel in cutlass.range(self.BT, unroll_full=True): - t_idx = chunk_start + t_rel - if t_idx < self.T: - q_term = cutlass.Float32(q[i_b, t_idx, i_h, k_idx]) - do_term = cutlass.Float32(do[i_b, t_idx, i_h, v_idx]) - w_term = cutlass.Float32(w[i_b, t_idx, i_h, k_idx]) - acc += q_term * do_term * self.scale - w_term * s_dv[t_rel, v_rel] - s_dh[k_idx, v_rel] = acc + while linear < self.BV * self.BT: + v_rel = linear // self.BT + t_rel = linear - v_rel * self.BT + sUA[v_rel, t_rel, 0] = sDv2T[v_rel, t_rel, 0] linear += self.num_threads - cute.arch.barrier() + w_wait = load_w_C.wait_and_advance() + cute.arch.sync_threads() + + acc_wdv.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_wdv, + tUrA[None, None, kp, 0], + tUrB[None, None, kp, 0], + acc_wdv, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + w_wait.release() + + for ei in cutlass.range(cute.size(rState), unroll_full=True): + update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] + rState[ei] = rState[ei] + update + cute.arch.sync_threads() if cutlass.const_expr(self.use_dh0): - linear = tidx - while linear < self.BK * self.BV: - k_idx = linear // self.BV - v_rel = linear - k_idx * self.BV + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_idx = tUcState[ei] v_idx = v_base + v_rel - if v_idx < self.V: - dh0[i_b, i_h, k_idx, v_idx] = s_dh[k_idx, v_rel] - linear += self.num_threads + dh0[i_b, i_h, k_idx, v_idx] = rState[ei] + 
+ @cute.jit + def _epilog_partition(self, atom, gC_mnl, epi_tile, sC): + gC_epi = cute.flat_divide(gC_mnl, epi_tile) + sC_g = cute.group_modes(sC, 0, 2) + gC_g = cute.group_modes(gC_epi, 0, 2) + bSG_sC, bSG_gC = cpasync.tma_partition( + atom, + 0, + cute.make_layout(1), + sC_g, + gC_g, + ) + return atom, bSG_sC, bSG_gC def _as_cute(tensor: torch.Tensor): @@ -343,26 +700,28 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( use_exp2: bool = False, transpose_state_layout: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: - """FLA-compatible wrapper for the current SM90 bwd_dhu prototype.""" + """FLA-compatible wrapper for the SM90 WGMMA bwd_dhu path.""" del chunk_indices assert_hopper(q.device) if cu_seqlens is not None: - raise NotImplementedError("SM90 bwd_dhu prototype only supports non-varlen tensors.") + raise NotImplementedError("SM90 bwd_dhu currently only supports non-varlen tensors.") if transpose_state_layout: - raise NotImplementedError("SM90 bwd_dhu prototype only supports [B, NT, H, K, V] state layout.") + raise NotImplementedError("SM90 bwd_dhu currently only supports [B, NT, H, K, V] state layout.") if g is not None: - raise NotImplementedError("SM90 bwd_dhu prototype supports gk gating, not scalar g gating yet.") + raise NotImplementedError("SM90 bwd_dhu currently supports gk gating, not scalar g gating.") if chunk_size != BT: - raise NotImplementedError(f"SM90 bwd_dhu prototype only supports chunk_size={BT}.") + raise NotImplementedError(f"SM90 bwd_dhu only supports chunk_size={BT}.") B, T, H, K = q.shape V = do.shape[-1] if K not in (64, 128, 256): - raise NotImplementedError(f"SM90 bwd_dhu prototype only supports K in {{64, 128, 256}}, got K={K}.") + raise NotImplementedError(f"SM90 bwd_dhu only supports K in {{64, 128, 256}}, got K={K}.") + if V % BV != 0: + raise NotImplementedError(f"SM90 bwd_dhu WGMMA path requires V to be a multiple of {BV}, got V={V}.") if q.dtype != torch.bfloat16 or k.dtype != torch.bfloat16 or w.dtype != 
torch.bfloat16: - raise TypeError("q, k, and w must be bfloat16 for the SM90 bwd_dhu prototype.") + raise TypeError("q, k, and w must be bfloat16 for the SM90 bwd_dhu path.") if do.dtype != torch.bfloat16 or dv.dtype != torch.bfloat16: - raise TypeError("do and dv must be bfloat16 for the SM90 bwd_dhu prototype.") + raise TypeError("do and dv must be bfloat16 for the SM90 bwd_dhu path.") if not q.is_contiguous() or not k.is_contiguous() or not w.is_contiguous(): raise ValueError("q, k, and w must be contiguous.") if not do.is_contiguous() or not dv.is_contiguous(): @@ -400,5 +759,4 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( return dh, dh0, dv2 -# Shorter alias for users who import this module directly. chunk_gated_delta_rule_bwd_dhu = chunk_gated_delta_rule_bwd_dhu_sm90 diff --git a/tests/test_chunk_delta_h_bwd_sm90.py b/tests/test_chunk_delta_h_bwd_sm90.py index 6e0a0ba..cfdee3d 100644 --- a/tests/test_chunk_delta_h_bwd_sm90.py +++ b/tests/test_chunk_delta_h_bwd_sm90.py @@ -2,7 +2,7 @@ # Copyright 2025-2026 Ant Group Co., Ltd. 
# SPDX-License-Identifier: Apache-2.0 -"""Correctness tests for the SM90 CuTe DSL bwd_dhu prototype.""" +"""Correctness tests for the SM90 CuTe DSL WGMMA bwd_dhu path.""" import os import sys @@ -92,7 +92,7 @@ def _run_case(B, T, H, K, V, use_gk=False, use_dht=False, use_h0=False, use_exp2 @pytest.mark.parametrize("T", [64, 128]) -@pytest.mark.parametrize("V", [32, 64]) +@pytest.mark.parametrize("V", [64, 128]) def test_bwd_dhu_no_gating(T, V): _run_case(B=1, T=T, H=1, K=64, V=V) From 9188c7f8952f372bfc7e303df94ab428a2fb89ef Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Wed, 13 May 2026 00:28:10 +0800 Subject: [PATCH 03/28] format --- benchmarks/bench_chunk_delta_h_bwd_sm90.py | 215 ++++++++++++--- cula/ops/chunk_delta_h_bwd.py | 293 +++++++++++++++++---- tests/test_chunk_delta_h_bwd_sm90.py | 244 ++++++++++++++++- 3 files changed, 656 insertions(+), 96 deletions(-) diff --git a/benchmarks/bench_chunk_delta_h_bwd_sm90.py b/benchmarks/bench_chunk_delta_h_bwd_sm90.py index ddf8799..095dc49 100644 --- a/benchmarks/bench_chunk_delta_h_bwd_sm90.py +++ b/benchmarks/bench_chunk_delta_h_bwd_sm90.py @@ -8,7 +8,7 @@ Current kernel scope: - non-varlen - K in {64, 128, 256}, BT=64, BV=64 - - state layout [B, NT, H, K, V] + - state layout [B, NT, H, K, V] or [B, NT, H, V, K] - optional gk/dht/h0 Example: @@ -16,6 +16,7 @@ """ import argparse +import math import pathlib import sys @@ -43,47 +44,55 @@ def time_kernel(fn, warmup: int, iters: int) -> float: def make_inputs(args): torch.manual_seed(args.seed) - B, T, H, K, V = args.B, args.T, args.H, args.K, args.V + seq_lens = getattr(args, "seq_lens", None) + is_varlen = seq_lens is not None + B = 1 if is_varlen else args.B + T = sum(seq_lens) if is_varlen else args.T + N = len(seq_lens) if is_varlen else B + H, K, V = args.H, args.K, args.V q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 w = torch.randn(B, T, H, K, 
dtype=torch.bfloat16, device="cuda") * 0.1 do = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 dv = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 + cu_seqlens = None + if is_varlen: + cu = [0] + for seq_len in seq_lens: + cu.append(cu[-1] + seq_len) + cu_seqlens = torch.tensor(cu, dtype=torch.int32, device="cuda") + g = None + if args.g: + if is_varlen: + g = torch.empty(B, T, H, dtype=torch.float32, device="cuda") + for i in range(N): + bos, eos = cu_seqlens[i].item(), cu_seqlens[i + 1].item() + seg = torch.randn(B, eos - bos, H, dtype=torch.float32, device="cuda") * 0.01 + g[:, bos:eos] = -torch.abs(seg).cumsum(dim=1) + else: + g = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device="cuda") * 0.01).cumsum(dim=1) gk = None if args.gk: - gk = -torch.abs(torch.randn(B, T, H, K, dtype=torch.float32, device="cuda") * 0.01).cumsum(dim=1) + if is_varlen: + gk = torch.empty(B, T, H, K, dtype=torch.float32, device="cuda") + for i in range(N): + bos, eos = cu_seqlens[i].item(), cu_seqlens[i + 1].item() + seg = torch.randn(B, eos - bos, H, K, dtype=torch.float32, device="cuda") * 0.01 + gk[:, bos:eos] = -torch.abs(seg).cumsum(dim=1) + else: + gk = -torch.abs(torch.randn(B, T, H, K, dtype=torch.float32, device="cuda") * 0.01).cumsum(dim=1) + state_shape = (N, H, V, K) if args.transpose_state else (N, H, K, V) dht = None if args.dht: - dht = torch.randn(B, H, K, V, dtype=torch.float32, device="cuda") * 0.01 - h0 = torch.empty(B, H, K, V, dtype=torch.float32, device="cuda") if args.h0 else None - return q, k, w, do, dv, gk, dht, h0 + dht = torch.randn(state_shape, dtype=torch.float32, device="cuda") * 0.01 + h0 = torch.empty(state_shape, dtype=torch.float32, device="cuda") if args.h0 else None + return q, k, w, do, dv, g, gk, dht, h0, cu_seqlens -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--B", type=int, default=1) - parser.add_argument("--T", type=int, default=1024) - parser.add_argument("--H", 
type=int, default=8) - parser.add_argument("--K", type=int, default=128, choices=[64, 128, 256]) - parser.add_argument("--V", type=int, default=64) - parser.add_argument("--gk", action="store_true") - parser.add_argument("--dht", action="store_true") - parser.add_argument("--h0", action="store_true") - parser.add_argument("--use-exp2", action="store_true") - parser.add_argument("--warmup", type=int, default=10) - parser.add_argument("--iters", type=int, default=100) - parser.add_argument("--seed", type=int, default=42) - args = parser.parse_args() - - if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] != 9: - raise RuntimeError("This benchmark requires an SM90/Hopper GPU.") - if args.T % 64 != 0: - raise ValueError("Use T as a multiple of 64 for this benchmark.") - if args.V % 64 != 0: - raise ValueError("Use V as a multiple of 64 for the SM90 WGMMA path.") - - q, k, w, do, dv, gk, dht, h0 = make_inputs(args) +def run_one(args): + q, k, w, do, dv, g, gk, dht, h0, cu_seqlens = make_inputs(args) scale = args.K**-0.5 + is_varlen = cu_seqlens is not None def run_fla(): return fla_bwd_dhu( @@ -92,12 +101,15 @@ def run_fla(): w=w, do=do, dv=dv, + g=g, gk=gk, h0=h0, dht=dht, scale=scale, + cu_seqlens=cu_seqlens.long() if cu_seqlens is not None else None, chunk_size=64, use_exp2=args.use_exp2, + transpose_state_layout=args.transpose_state, ) def run_cute(): @@ -107,12 +119,15 @@ def run_cute(): w=w, do=do, dv=dv, + g=g, gk=gk, h0=h0, dht=dht, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=64, use_exp2=args.use_exp2, + transpose_state_layout=args.transpose_state, ) ref = run_fla() @@ -120,18 +135,152 @@ def run_cute(): torch.cuda.synchronize() max_dh = (ref[0].float() - got[0].float()).abs().max().item() max_dv = (ref[2].float() - got[2].float()).abs().max().item() + max_dh0 = None + if ref[1] is not None: + max_dh0 = (ref[1].float() - got[1].float()).abs().max().item() fla_ms = time_kernel(run_fla, args.warmup, args.iters) cute_ms = 
time_kernel(run_cute, args.warmup, args.iters) + shape_tag = f"seq_lens={args.seq_lens}" if is_varlen else f"B={args.B} T={args.T}" print( - f"bwd_dhu SM90 B={args.B} T={args.T} H={args.H} K={args.K} V={args.V} " - f"gk={args.gk} dht={args.dht} h0={args.h0} exp2={args.use_exp2}" + f"bwd_dhu SM90 {shape_tag} H={args.H} K={args.K} V={args.V} " + f"g={args.g} gk={args.gk} dht={args.dht} h0={args.h0} exp2={args.use_exp2} transpose={args.transpose_state}" ) - print(f"max_diff dh={max_dh:.6f} dv2={max_dv:.6f}") + if max_dh0 is None: + print(f"max_diff dh={max_dh:.6f} dv2={max_dv:.6f}") + else: + print(f"max_diff dh={max_dh:.6f} dh0={max_dh0:.6f} dv2={max_dv:.6f}") print(f"FLA Triton: {fla_ms:.4f} ms") print(f"CuTe DSL : {cute_ms:.4f} ms") print(f"speedup : {fla_ms / cute_ms:.3f}x") + return { + "B": args.B, + "T": args.T, + "seq_lens": args.seq_lens, + "H": args.H, + "K": args.K, + "V": args.V, + "g": args.g, + "gk": args.gk, + "dht": args.dht, + "h0": args.h0, + "exp2": args.use_exp2, + "transpose": args.transpose_state, + "max_dh": max_dh, + "max_dh0": max_dh0, + "max_dv": max_dv, + "fla_ms": fla_ms, + "cute_ms": cute_ms, + "speedup": fla_ms / cute_ms, + } + + +def suite_configs(kind: str): + quick = [ + dict(B=1, T=512, H=4, K=64, V=64, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), + dict(B=1, T=512, H=4, K=128, V=64, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), + dict(B=1, T=512, H=4, K=128, V=128, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), + dict(B=1, T=512, H=2, K=256, V=64, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), + ] + if kind == "quick": + return quick + return quick + [ + dict( + seq_lens=[50, 192, 100], + H=2, + K=64, + V=64, + g=False, + gk=True, + dht=True, + h0=False, + use_exp2=True, + transpose_state=False, + ), + dict( + seq_lens=[33, 128, 200], + H=1, + K=128, + V=64, + g=True, + gk=False, + dht=True, + h0=True, + 
use_exp2=True, + transpose_state=False, + ), + dict(B=1, T=512, H=4, K=64, V=64, g=True, gk=False, dht=True, h0=False, use_exp2=True, transpose_state=False), + dict(B=1, T=512, H=2, K=128, V=64, g=True, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), + dict(B=2, T=1024, H=4, K=128, V=64, g=False, gk=True, dht=True, h0=True, use_exp2=True, transpose_state=False), + dict(B=1, T=2048, H=8, K=128, V=64, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), + dict(B=1, T=1024, H=8, K=64, V=128, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), + dict(B=1, T=512, H=4, K=128, V=64, g=False, gk=True, dht=True, h0=True, use_exp2=True, transpose_state=True), + ] + + +def _fmt_optional(value): + return "n/a" if value is None else f"{value:.6f}" + + +def print_suite(results): + print("\n" + "=" * 118) + print(" bwd_dhu SM90 Suite: CuTe DSL vs FLA Triton") + print("=" * 118) + for r in results: + flags = ",".join(name for name in ("g", "gk", "dht", "h0", "exp2", "transpose") if r[name]) + shape = f"seqs={r['seq_lens']!s:<17s}" if r["seq_lens"] is not None else f"B={r['B']:2d} T={r['T']:5d}" + print( + f" {shape} H={r['H']:2d} K={r['K']:3d} V={r['V']:3d} [{flags:<16s}] | " + f"diff dh={r['max_dh']:.6f} dh0={_fmt_optional(r['max_dh0'])} dv2={r['max_dv']:.6f} | " + f"FLA={r['fla_ms']:.4f}ms CuTe={r['cute_ms']:.4f}ms speedup={r['speedup']:.3f}x" + ) + speedups = [r["speedup"] for r in results if r["speedup"] > 0] + geo = math.exp(sum(math.log(s) for s in speedups) / len(speedups)) + print("-" * 118) + print(f" Geometric mean speedup: {geo:.3f}x") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--suite", choices=["none", "quick", "full"], default="none") + parser.add_argument("--B", type=int, default=1) + parser.add_argument("--T", type=int, default=1024) + parser.add_argument("--seq-lens", type=str, default=None) + parser.add_argument("--H", type=int, default=8) + 
parser.add_argument("--K", type=int, default=128, choices=[64, 128, 256]) + parser.add_argument("--V", type=int, default=64) + parser.add_argument("--g", action="store_true") + parser.add_argument("--gk", action="store_true") + parser.add_argument("--dht", action="store_true") + parser.add_argument("--h0", action="store_true") + parser.add_argument("--use-exp2", action="store_true") + parser.add_argument("--transpose-state", action="store_true") + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--iters", type=int, default=100) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + if args.seq_lens is not None: + args.seq_lens = [int(x) for x in args.seq_lens.split(",") if x] + + if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] != 9: + raise RuntimeError("This benchmark requires an SM90/Hopper GPU.") + if args.suite == "none": + if args.seq_lens is None and args.T % 64 != 0: + raise ValueError("Use T as a multiple of 64 for this benchmark.") + if args.V % 64 != 0: + raise ValueError("Use V as a multiple of 64 for the SM90 WGMMA path.") + run_one(args) + return + + results = [] + for cfg in suite_configs(args.suite): + case_args = argparse.Namespace(**vars(args)) + case_args.seq_lens = None + for key, value in cfg.items(): + setattr(case_args, key, value) + results.append(run_one(case_args)) + print_suite(results) if __name__ == "__main__": diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 55cb7b5..92f6522 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -19,7 +19,7 @@ - fixed chunk size BT=64 - BV=64, matching cula/ops/chunk_delta_h.py - non-varlen tensors [B, T, H, D] - - non-transposed state layout [B, NT, H, K, V] + - state layout [B, NT, H, K, V] or [B, NT, H, V, K] - non-persistent scheduling The recurrence is the Triton bwd_dhu recurrence: @@ -45,7 +45,8 @@ import torch from cutlass.cute.nvgpu import cpasync from 
cutlass.cute.runtime import from_dlpack -from cutlass.cute.typing import Float32, Int64 +from cutlass.cute.typing import Float32, Int32, Int64 +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets from cula.utils import USE_FAST_MATH, assert_hopper @@ -63,13 +64,18 @@ def __init__( self, batch_size: int, seq_len: int, + num_sequences: int, + total_chunks: int, num_heads: int, head_dim_k: int, head_dim_v: int, + is_varlen: bool, + use_g: bool, use_gk: bool, use_dht: bool, use_dh0: bool, use_exp2: bool, + transpose_state_layout: bool, scale: float, use_fast_math: bool = True, ): @@ -77,13 +83,18 @@ def __init__( assert head_dim_v % BV == 0, f"SM90 bwd_dhu tensor-core path requires V to be a multiple of {BV}, got {head_dim_v}" self.B = batch_size self.T = seq_len + self.N = num_sequences + self.NT = total_chunks self.H = num_heads self.K = head_dim_k self.V = head_dim_v + self.is_varlen = is_varlen + self.use_g = use_g self.use_gk = use_gk self.use_dht = use_dht self.use_dh0 = use_dh0 self.use_exp2 = use_exp2 + self.transpose_state_layout = transpose_state_layout self.scale = scale self.use_fast_math = use_fast_math @@ -106,6 +117,7 @@ def __call__( q_in: cute.Tensor, k_in: cute.Tensor, w_in: cute.Tensor, + g_in: cute.Tensor, gk_in: cute.Tensor, dht_in: cute.Tensor, dh0_in: cute.Tensor, @@ -113,11 +125,14 @@ def __call__( dh_in: cute.Tensor, dv_in: cute.Tensor, dv2_in: cute.Tensor, + cu_seqlens_in: cute.Tensor, + chunk_offsets_in: cute.Tensor, stream: cuda.CUstream, ): q_ptr = q_in.iterator k_ptr = k_in.iterator w_ptr = w_in.iterator + g_ptr = g_in.iterator gk_ptr = gk_in.iterator dht_ptr = dht_in.iterator dh0_ptr = dh0_in.iterator @@ -125,8 +140,10 @@ def __call__( dh_ptr = dh_in.iterator dv_ptr = dv_in.iterator dv2_ptr = dv2_in.iterator + cu_seqlens_ptr = cu_seqlens_in.iterator + chunk_offsets_ptr = chunk_offsets_in.iterator - NT = (self.T + self.BT - 1) // self.BT + NT_total = self.NT q_layout = cute.make_layout( (self.B, self.T, self.H, self.K), 
@@ -144,28 +161,54 @@ def __call__( dv = cute.make_tensor(dv_ptr, v_layout) dv2 = cute.make_tensor(dv2_ptr, v_layout) + g_layout = cute.make_layout( + (self.B, self.T, self.H), + stride=(self.T * self.H, self.H, 1), + ) + g = cute.make_tensor(g_ptr, g_layout) + gk_layout = cute.make_layout( (self.B, self.T, self.H, self.K), stride=(self.T * self.H * self.K, self.H * self.K, self.K, 1), ) gk = cute.make_tensor(gk_ptr, gk_layout) - - state_layout = cute.make_layout( - (self.B, NT, self.H, self.K, self.V), - stride=( - NT * self.H * self.K * self.V, - self.H * self.K * self.V, - self.K * self.V, - self.V, - 1, - ), - ) + cu_seqlens = cute.make_tensor(cu_seqlens_ptr, cute.make_layout((self.N + 1,))) + chunk_offsets = cute.make_tensor(chunk_offsets_ptr, cute.make_layout((self.N + 1,))) + + if cutlass.const_expr(self.transpose_state_layout): + state_layout = cute.make_layout( + (self.B, NT_total, self.H, self.V, self.K), + stride=( + NT_total * self.H * self.K * self.V, + self.H * self.K * self.V, + self.K * self.V, + self.K, + 1, + ), + ) + else: + state_layout = cute.make_layout( + (self.B, NT_total, self.H, self.K, self.V), + stride=( + NT_total * self.H * self.K * self.V, + self.H * self.K * self.V, + self.K * self.V, + self.V, + 1, + ), + ) dh = cute.make_tensor(dh_ptr, state_layout) - final_layout = cute.make_layout( - (self.B, self.H, self.K, self.V), - stride=(self.H * self.K * self.V, self.K * self.V, self.V, 1), - ) + if cutlass.const_expr(self.transpose_state_layout): + final_layout = cute.make_layout( + (self.N, self.H, self.V, self.K), + stride=(self.H * self.K * self.V, self.K * self.V, self.K, 1), + ) + else: + final_layout = cute.make_layout( + (self.N, self.H, self.K, self.V), + stride=(self.H * self.K * self.V, self.K * self.V, self.V, 1), + ) dht = cute.make_tensor(dht_ptr, final_layout) dh0 = cute.make_tensor(dh0_ptr, final_layout) @@ -312,6 +355,7 @@ class SharedStorage: q, k, w, + g, gk, dht, dh0, @@ -319,6 +363,8 @@ class SharedStorage: dh, dv, 
dv2, + cu_seqlens, + chunk_offsets, tiled_mma, update_tiled_mma, a_smem_layout_staged, @@ -338,7 +384,7 @@ class SharedStorage: tma_atom_dv2, tma_tensor_dv2, ).launch( - grid=[cute.ceil_div(self.V, self.BV), self.B * self.H, 1], + grid=[cute.ceil_div(self.V, self.BV), self.N * self.H, 1], block=[self.num_threads, 1, 1], cluster=self.cluster_shape_mnk, stream=stream, @@ -351,6 +397,7 @@ def kernel( q: cute.Tensor, k: cute.Tensor, w: cute.Tensor, + g: cute.Tensor, gk: cute.Tensor, dht: cute.Tensor, dh0: cute.Tensor, @@ -358,6 +405,8 @@ def kernel( dh: cute.Tensor, dv: cute.Tensor, dv2: cute.Tensor, + cu_seqlens: cute.Tensor, + chunk_offsets: cute.Tensor, tiled_mma: cute.TiledMma, update_tiled_mma: cute.TiledMma, a_smem_layout_staged: cute.ComposedLayout, @@ -380,10 +429,22 @@ def kernel( tidx, _, _ = cute.arch.thread_idx() warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) i_v_tile, i_bh, _ = cute.arch.block_idx() - i_b = i_bh // self.H - i_h = i_bh - i_b * self.H - v_base = i_v_tile * self.BV + i_n = i_bh // self.H + i_h = i_bh - i_n * self.H + data_b = i_n + state_b = i_n + seq_start = Int32(0) + seq_len = self.T NT = (self.T + self.BT - 1) // self.BT + chunk_base = Int32(0) + if cutlass.const_expr(self.is_varlen): + data_b = Int32(0) + state_b = Int32(0) + seq_start = cu_seqlens[i_n] + seq_len = cu_seqlens[i_n + 1] - seq_start + NT = (seq_len + self.BT - 1) // self.BT + chunk_base = chunk_offsets[i_n] + v_base = i_v_tile * self.BV smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) @@ -446,17 +507,38 @@ def kernel( acc_qdo = update_thr_mma.make_fragment_C(update_thr_mma.partition_shape_C((BV, self.K))) acc_wdv = update_thr_mma.make_fragment_C(update_thr_mma.partition_shape_C((BV, self.K))) - _, bSG_sK, bSG_gK = self._epilog_partition(tma_atom_k, tma_tensor_k[None, None, (i_h, i_b)], (self.BT, self.K), sA) + if cutlass.const_expr(self.is_varlen): + tma_tensor_k_use = cute.domain_offset((seq_start, 0, (0, 0)), tma_tensor_k) + 
tma_tensor_dv_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_dv) + tma_tensor_do_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_do) + tma_tensor_q_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_q) + tma_tensor_w_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_w) + tma_tensor_dv2_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_dv2) + else: + tma_tensor_k_use = tma_tensor_k + tma_tensor_dv_use = tma_tensor_dv + tma_tensor_do_use = tma_tensor_do + tma_tensor_q_use = tma_tensor_q + tma_tensor_w_use = tma_tensor_w + tma_tensor_dv2_use = tma_tensor_dv2 + + _, bSG_sK, bSG_gK = self._epilog_partition( + tma_atom_k, tma_tensor_k_use[None, None, (i_h, data_b)], (self.BT, self.K), sA + ) _, bSG_sDv, bSG_gDv = self._epilog_partition( - tma_atom_dv, tma_tensor_dv[None, None, (i_h, i_b)], (self.BV, self.BT), sDv2T + tma_atom_dv, tma_tensor_dv_use[None, None, (i_h, data_b)], (self.BV, self.BT), sDv2T ) _, bSG_sDo, bSG_gDo = self._epilog_partition( - tma_atom_do, tma_tensor_do[None, None, (i_h, i_b)], (self.BV, self.BT), sUA + tma_atom_do, tma_tensor_do_use[None, None, (i_h, data_b)], (self.BV, self.BT), sUA + ) + _, bSG_sQ, bSG_gQ = self._epilog_partition( + tma_atom_q, tma_tensor_q_use[None, None, (i_h, data_b)], (self.K, self.BT), sUB + ) + _, bSG_sW, bSG_gW = self._epilog_partition( + tma_atom_w, tma_tensor_w_use[None, None, (i_h, data_b)], (self.K, self.BT), sUB ) - _, bSG_sQ, bSG_gQ = self._epilog_partition(tma_atom_q, tma_tensor_q[None, None, (i_h, i_b)], (self.K, self.BT), sUB) - _, bSG_sW, bSG_gW = self._epilog_partition(tma_atom_w, tma_tensor_w[None, None, (i_h, i_b)], (self.K, self.BT), sUB) _, bSG_sDv2, bSG_gDv2 = self._epilog_partition( - tma_atom_dv2, tma_tensor_dv2[None, None, (i_h, i_b)], (self.BV, self.BT), sDv2T + tma_atom_dv2, tma_tensor_dv2_use[None, None, (i_h, data_b)], (self.BV, self.BT), sDv2T ) # Initialize carried dh state. 
@@ -465,20 +547,34 @@ def kernel( v_idx = v_base + v_rel init = Float32(0.0) if cutlass.const_expr(self.use_dht): - init = dht[i_b, i_h, k_idx, v_idx].to(self.acc_dtype) + if cutlass.const_expr(self.transpose_state_layout): + init = dht[i_n, i_h, v_idx, k_idx].to(self.acc_dtype) + else: + init = dht[i_n, i_h, k_idx, v_idx].to(self.acc_dtype) rState[ei] = init - for chunk_rev in cutlass.range_constexpr(NT): + for chunk_rev in cutlass.range(0, NT, unroll=0): i_t = NT - 1 - chunk_rev chunk_start = i_t * self.BT - chunk_end = cutlass.min(chunk_start + self.BT, self.T) + chunk_end = cutlass.min(chunk_start + self.BT, seq_len) last_idx = chunk_end - 1 + g_last = Float32(0.0) + g_last_exp = Float32(1.0) + if cutlass.const_expr(self.use_g): + g_last = g[data_b, seq_start + last_idx, i_h].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + g_last_exp = cute.exp2(g_last, fastmath=self.use_fast_math) + else: + g_last_exp = cute.exp(g_last, fastmath=self.use_fast_math) # Store dh before applying this chunk's reverse update. for ei in cutlass.range(cute.size(rState), unroll_full=True): v_rel, k_idx = tUcState[ei] v_idx = v_base + v_rel - dh[i_b, i_t, i_h, k_idx, v_idx] = rState[ei].to(dh.element_type) + if cutlass.const_expr(self.transpose_state_layout): + dh[state_b, chunk_base + i_t, i_h, v_idx, k_idx] = rState[ei].to(dh.element_type) + else: + dh[state_b, chunk_base + i_t, i_h, k_idx, v_idx] = rState[ei].to(dh.element_type) cute.arch.sync_threads() # dv2 = dv + K @ dh. 
@@ -523,8 +619,16 @@ def kernel( t_idx = chunk_start + t_rel v_idx = v_base + v_rel out = Float32(0.0) - if t_idx < self.T: - out = acc_dv[ei] + sDv2T[v_rel, t_rel, 0].to(self.acc_dtype) + if t_idx < seq_len: + out = acc_dv[ei] + if cutlass.const_expr(self.use_g): + g_cur = g[data_b, seq_start + t_idx, i_h].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + g_decay = cute.exp2(g_last - g_cur, fastmath=self.use_fast_math) + else: + g_decay = cute.exp(g_last - g_cur, fastmath=self.use_fast_math) + out = out * g_decay + out = out + sDv2T[v_rel, t_rel, 0].to(self.acc_dtype) sDv2T[v_rel, t_rel, 0] = out.to(self.io_dtype) cute.arch.fence_proxy("async.shared", space="cta") cute.arch.sync_threads() @@ -534,16 +638,19 @@ def kernel( cute.arch.cp_async_bulk_wait_group(0, read=True) cute.arch.sync_threads() - # Apply gk decay after dv2, before accumulating QO - WV into dh. + # Apply state decay after dv2, before accumulating QO - WV into dh. + if cutlass.const_expr(self.use_g): + for ei in cutlass.range(cute.size(rState), unroll_full=True): + rState[ei] = rState[ei] * g_last_exp if cutlass.const_expr(self.use_gk): for ei in cutlass.range(cute.size(rState), unroll_full=True): v_rel, k_idx = tUcState[ei] - gk_last = gk[i_b, last_idx, i_h, k_idx].to(self.acc_dtype) + gk_last = gk[data_b, seq_start + last_idx, i_h, k_idx].to(self.acc_dtype) if cutlass.const_expr(self.use_exp2): - scale = cute.exp2(gk_last, fastmath=self.use_fast_math) + k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) else: - scale = cute.exp(gk_last, fastmath=self.use_fast_math) - rState[ei] = rState[ei] * scale + k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) + rState[ei] = rState[ei] * k_decay # dh += scale * do^T @ q - dv2^T @ w. 
if warp_idx == 0: @@ -553,6 +660,24 @@ def kernel( qdo_wait = load_qdo_C.wait_and_advance() cute.arch.sync_threads() + if cutlass.const_expr(self.use_g): + linear_q = tidx + while linear_q < self.K * self.BT: + k_rel = linear_q // self.BT + t_rel = linear_q - k_rel * self.BT + t_idx = chunk_start + t_rel + q_scaled = Float32(0.0) + if t_idx < seq_len: + g_cur = g[data_b, seq_start + t_idx, i_h].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + g_exp = cute.exp2(g_cur, fastmath=self.use_fast_math) + else: + g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) + q_scaled = sUB[k_rel, t_rel, 0].to(self.acc_dtype) * g_exp + sUB[k_rel, t_rel, 0] = q_scaled.to(self.io_dtype) + linear_q += self.num_threads + cute.arch.sync_threads() + acc_qdo.fill(0.0) cute.nvgpu.warpgroup.fence() for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): @@ -606,7 +731,10 @@ def kernel( for ei in cutlass.range(cute.size(rState), unroll_full=True): v_rel, k_idx = tUcState[ei] v_idx = v_base + v_rel - dh0[i_b, i_h, k_idx, v_idx] = rState[ei] + if cutlass.const_expr(self.transpose_state_layout): + dh0[i_n, i_h, v_idx, k_idx] = rState[ei] + else: + dh0[i_n, i_h, k_idx, v_idx] = rState[ei] @cute.jit def _epilog_partition(self, atom, gC_mnl, epi_tile, sC): @@ -631,25 +759,35 @@ def _as_cute(tensor: torch.Tensor): def _compile_bwd_dhu_sm90( B: int, T: int, + N: int, + NT: int, H: int, K: int, V: int, + is_varlen: bool, + use_g: bool, use_gk: bool, use_dht: bool, use_dh0: bool, use_exp2: bool, + transpose_state_layout: bool, scale: float, ): kernel = ChunkDeltaRuleBwdDHUSm90( batch_size=B, seq_len=T, + num_sequences=N, + total_chunks=NT, num_heads=H, head_dim_k=K, head_dim_v=V, + is_varlen=is_varlen, + use_g=use_g, use_gk=use_gk, use_dht=use_dht, use_dh0=use_dh0, use_exp2=use_exp2, + transpose_state_layout=transpose_state_layout, scale=scale, use_fast_math=USE_FAST_MATH, ) @@ -660,10 +798,18 @@ def _compile_bwd_dhu_sm90( do_fake = torch.empty(B, T, H, V, device="cuda", 
dtype=torch.bfloat16) dv_fake = torch.empty_like(do_fake) dv2_fake = torch.empty_like(do_fake) + g_fake = torch.empty(B, T, H, device="cuda", dtype=torch.float32) gk_fake = torch.empty(B, T, H, K, device="cuda", dtype=torch.float32) - dht_fake = torch.empty(B, H, K, V, device="cuda", dtype=torch.float32) - dh0_fake = torch.empty_like(dht_fake) - dh_fake = torch.empty(B, math.ceil(T / BT), H, K, V, device="cuda", dtype=torch.bfloat16) + if transpose_state_layout: + dht_fake = torch.empty(N, H, V, K, device="cuda", dtype=torch.float32) + dh0_fake = torch.empty_like(dht_fake) + dh_fake = torch.empty(B, NT, H, V, K, device="cuda", dtype=torch.bfloat16) + else: + dht_fake = torch.empty(N, H, K, V, device="cuda", dtype=torch.float32) + dh0_fake = torch.empty_like(dht_fake) + dh_fake = torch.empty(B, NT, H, K, V, device="cuda", dtype=torch.bfloat16) + cu_fake = torch.empty(N + 1, device="cuda", dtype=torch.int32) + offsets_fake = torch.empty(N + 1, device="cuda", dtype=torch.int32) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) return cute.compile( @@ -671,6 +817,7 @@ def _compile_bwd_dhu_sm90( _as_cute(q_fake), _as_cute(k_fake), _as_cute(w_fake), + _as_cute(g_fake), _as_cute(gk_fake), _as_cute(dht_fake), _as_cute(dh0_fake), @@ -678,6 +825,8 @@ def _compile_bwd_dhu_sm90( _as_cute(dh_fake), _as_cute(dv_fake), _as_cute(dv2_fake), + _as_cute(cu_fake), + _as_cute(offsets_fake), stream=stream, options="--enable-tvm-ffi", ) @@ -701,19 +850,15 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( transpose_state_layout: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: """FLA-compatible wrapper for the SM90 WGMMA bwd_dhu path.""" - del chunk_indices assert_hopper(q.device) - if cu_seqlens is not None: - raise NotImplementedError("SM90 bwd_dhu currently only supports non-varlen tensors.") - if transpose_state_layout: - raise NotImplementedError("SM90 bwd_dhu currently only supports [B, NT, H, K, V] state layout.") - if g is not None: - raise 
NotImplementedError("SM90 bwd_dhu currently supports gk gating, not scalar g gating.") if chunk_size != BT: raise NotImplementedError(f"SM90 bwd_dhu only supports chunk_size={BT}.") B, T, H, K = q.shape V = do.shape[-1] + is_varlen = cu_seqlens is not None + if is_varlen and B != 1: + raise ValueError("varlen mode expects packed inputs with shape [1, total_T, H, D].") if K not in (64, 128, 256): raise NotImplementedError(f"SM90 bwd_dhu only supports K in {{64, 128, 256}}, got K={K}.") if V % BV != 0: @@ -726,36 +871,70 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( raise ValueError("q, k, and w must be contiguous.") if not do.is_contiguous() or not dv.is_contiguous(): raise ValueError("do and dv must be contiguous.") - - NT = math.ceil(T / BT) + if h0 is not None and (h0.dtype != torch.float32 or not h0.is_contiguous()): + raise ValueError("h0 must be contiguous float32.") + if cu_seqlens is not None and (cu_seqlens.device != q.device or not cu_seqlens.is_contiguous()): + raise ValueError("cu_seqlens must be contiguous and on the same CUDA device as q.") + if chunk_indices is not None and (chunk_indices.device != q.device or not chunk_indices.is_contiguous()): + raise ValueError("chunk_indices must be contiguous and on the same CUDA device as q.") + + if is_varlen: + if chunk_indices is None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + N = len(cu_seqlens) - 1 + NT = len(chunk_indices) + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT).int() + cu_seqlens_arg = cu_seqlens.int() if cu_seqlens.dtype != torch.int32 else cu_seqlens + else: + N = B + NT = math.ceil(T / BT) + cu_seqlens_arg = torch.arange(B + 1, device=q.device, dtype=torch.int32) * T + chunk_offsets = torch.arange(B + 1, device=q.device, dtype=torch.int32) * NT scale_value = 1.0 if scale is None else float(scale) - dh = q.new_empty(B, NT, H, K, V) - dh0 = torch.empty(B, H, K, V, device=q.device, dtype=torch.float32) if h0 is not None else None + state_shape = (N, H, V, K) if 
transpose_state_layout else (N, H, K, V) + dh = q.new_empty(B, NT, H, V, K) if transpose_state_layout else q.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None dv2 = torch.empty_like(dv) + g_arg = g if g is not None else torch.empty(B, T, H, device=q.device, dtype=torch.float32) gk_arg = gk if gk is not None else torch.empty(B, T, H, K, device=q.device, dtype=torch.float32) - dht_arg = dht if dht is not None else torch.empty(B, H, K, V, device=q.device, dtype=torch.float32) - dh0_arg = dh0 if dh0 is not None else torch.empty(B, H, K, V, device=q.device, dtype=torch.float32) + dht_arg = dht if dht is not None else torch.empty(state_shape, device=q.device, dtype=torch.float32) + dh0_arg = dh0 if dh0 is not None else torch.empty(state_shape, device=q.device, dtype=torch.float32) + if g is not None and (g.dtype != torch.float32 or not g.is_contiguous()): + raise ValueError("g must be contiguous float32.") + if g is not None and tuple(g.shape) != (B, T, H): + raise ValueError(f"g must have shape {(B, T, H)}, got {tuple(g.shape)}.") if gk is not None and (gk.dtype != torch.float32 or not gk.is_contiguous()): raise ValueError("gk must be contiguous float32.") + if gk is not None and tuple(gk.shape) != (B, T, H, K): + raise ValueError(f"gk must have shape {(B, T, H, K)}, got {tuple(gk.shape)}.") if dht is not None and (dht.dtype != torch.float32 or not dht.is_contiguous()): raise ValueError("dht must be contiguous float32.") + if dht is not None and tuple(dht.shape) != state_shape: + raise ValueError(f"dht must have shape {state_shape} for this state layout, got {tuple(dht.shape)}.") + if h0 is not None and tuple(h0.shape) != state_shape: + raise ValueError(f"h0 must have shape {state_shape} for this state layout, got {tuple(h0.shape)}.") compiled = _compile_bwd_dhu_sm90( B, T, + N, + NT, H, K, V, + is_varlen, + g is not None, gk is not None, dht is not None, h0 is not None, use_exp2, + transpose_state_layout, 
scale_value, ) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - compiled(q, k, w, gk_arg, dht_arg, dh0_arg, do, dh, dv, dv2, stream) + compiled(q, k, w, g_arg, gk_arg, dht_arg, dh0_arg, do, dh, dv, dv2, cu_seqlens_arg, chunk_offsets, stream) return dh, dh0, dv2 diff --git a/tests/test_chunk_delta_h_bwd_sm90.py b/tests/test_chunk_delta_h_bwd_sm90.py index cfdee3d..a5a4d1e 100644 --- a/tests/test_chunk_delta_h_bwd_sm90.py +++ b/tests/test_chunk_delta_h_bwd_sm90.py @@ -27,7 +27,19 @@ def _is_sm90() -> bool: ] -def _make_inputs(B, T, H, K, V, use_gk=False, use_dht=False, use_h0=False, seed=42): +def _make_inputs( + B, + T, + H, + K, + V, + use_g=False, + use_gk=False, + use_dht=False, + use_h0=False, + seed=42, + transpose_state_layout=False, +): torch.manual_seed(seed) q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 @@ -35,23 +47,54 @@ def _make_inputs(B, T, H, K, V, use_gk=False, use_dht=False, use_h0=False, seed= do = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 dv = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 + g = None + if use_g: + g = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device="cuda") * 0.01).cumsum(dim=1) + gk = None if use_gk: gk = -torch.abs(torch.randn(B, T, H, K, dtype=torch.float32, device="cuda") * 0.01).cumsum(dim=1) dht = None if use_dht: - dht = torch.randn(B, H, K, V, dtype=torch.float32, device="cuda") * 0.01 + state_shape = (B, H, V, K) if transpose_state_layout else (B, H, K, V) + dht = torch.randn(state_shape, dtype=torch.float32, device="cuda") * 0.01 h0 = None if use_h0: - h0 = torch.empty(B, H, K, V, dtype=torch.float32, device="cuda") + state_shape = (B, H, V, K) if transpose_state_layout else (B, H, K, V) + h0 = torch.empty(state_shape, dtype=torch.float32, device="cuda") - return q, k, w, do, dv, gk, dht, h0 + return q, k, w, do, dv, g, gk, dht, h0 -def _run_case(B, T, H, 
K, V, use_gk=False, use_dht=False, use_h0=False, use_exp2=False): - q, k, w, do, dv, gk, dht, h0 = _make_inputs(B, T, H, K, V, use_gk, use_dht, use_h0) +def _run_case( + B, + T, + H, + K, + V, + use_g=False, + use_gk=False, + use_dht=False, + use_h0=False, + use_exp2=False, + transpose_state_layout=False, + seed=42, +): + q, k, w, do, dv, g, gk, dht, h0 = _make_inputs( + B, + T, + H, + K, + V, + use_g, + use_gk, + use_dht, + use_h0, + seed=seed, + transpose_state_layout=transpose_state_layout, + ) scale = K**-0.5 ref_dh, ref_dh0, ref_dv2 = fla_bwd_dhu( @@ -60,12 +103,14 @@ def _run_case(B, T, H, K, V, use_gk=False, use_dht=False, use_h0=False, use_exp2 w=w, do=do, dv=dv, + g=g, gk=gk, h0=h0, dht=dht, scale=scale, chunk_size=64, use_exp2=use_exp2, + transpose_state_layout=transpose_state_layout, ) got_dh, got_dh0, got_dv2 = chunk_gated_delta_rule_bwd_dhu_sm90( @@ -74,12 +119,14 @@ def _run_case(B, T, H, K, V, use_gk=False, use_dht=False, use_h0=False, use_exp2 w=w, do=do, dv=dv, + g=g, gk=gk, h0=h0, dht=dht, scale=scale, chunk_size=64, use_exp2=use_exp2, + transpose_state_layout=transpose_state_layout, ) torch.testing.assert_close(got_dh.float(), ref_dh.float(), atol=3e-2, rtol=3e-2) @@ -91,6 +138,121 @@ def _run_case(B, T, H, K, V, use_gk=False, use_dht=False, use_h0=False, use_exp2 assert got_dh0 is None +def _make_varlen_inputs( + seq_lens, + H, + K, + V, + use_g=False, + use_gk=False, + use_dht=False, + use_h0=False, + seed=42, + transpose_state_layout=False, +): + torch.manual_seed(seed) + T_total = sum(seq_lens) + N = len(seq_lens) + cu = [0] + for seq_len in seq_lens: + cu.append(cu[-1] + seq_len) + + q = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 + k = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 + w = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 + do = torch.randn(1, T_total, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 + dv = torch.randn(1, T_total, H, V, 
dtype=torch.bfloat16, device="cuda") * 0.1 + + g = None + if use_g: + g = torch.empty(1, T_total, H, dtype=torch.float32, device="cuda") + for i in range(N): + bos, eos = cu[i], cu[i + 1] + seg = torch.randn(1, eos - bos, H, dtype=torch.float32, device="cuda") * 0.01 + g[:, bos:eos] = -torch.abs(seg).cumsum(dim=1) + + gk = None + if use_gk: + gk = torch.empty(1, T_total, H, K, dtype=torch.float32, device="cuda") + for i in range(N): + bos, eos = cu[i], cu[i + 1] + seg = torch.randn(1, eos - bos, H, K, dtype=torch.float32, device="cuda") * 0.01 + gk[:, bos:eos] = -torch.abs(seg).cumsum(dim=1) + + state_shape = (N, H, V, K) if transpose_state_layout else (N, H, K, V) + dht = torch.randn(state_shape, dtype=torch.float32, device="cuda") * 0.01 if use_dht else None + h0 = torch.empty(state_shape, dtype=torch.float32, device="cuda") if use_h0 else None + cu_seqlens = torch.tensor(cu, dtype=torch.int32, device="cuda") + return q, k, w, do, dv, g, gk, dht, h0, cu_seqlens + + +def _run_varlen_case( + seq_lens, + H, + K, + V, + use_g=False, + use_gk=False, + use_dht=False, + use_h0=False, + use_exp2=False, + transpose_state_layout=False, + seed=42, +): + q, k, w, do, dv, g, gk, dht, h0, cu_seqlens = _make_varlen_inputs( + seq_lens, + H, + K, + V, + use_g=use_g, + use_gk=use_gk, + use_dht=use_dht, + use_h0=use_h0, + seed=seed, + transpose_state_layout=transpose_state_layout, + ) + scale = K**-0.5 + ref_dh, ref_dh0, ref_dv2 = fla_bwd_dhu( + q=q, + k=k, + w=w, + do=do, + dv=dv, + g=g, + gk=gk, + h0=h0, + dht=dht, + scale=scale, + cu_seqlens=cu_seqlens.long(), + chunk_size=64, + use_exp2=use_exp2, + transpose_state_layout=transpose_state_layout, + ) + got_dh, got_dh0, got_dv2 = chunk_gated_delta_rule_bwd_dhu_sm90( + q=q, + k=k, + w=w, + do=do, + dv=dv, + g=g, + gk=gk, + h0=h0, + dht=dht, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=64, + use_exp2=use_exp2, + transpose_state_layout=transpose_state_layout, + ) + torch.testing.assert_close(got_dh.float(), ref_dh.float(), 
atol=3e-2, rtol=3e-2) + torch.testing.assert_close(got_dv2.float(), ref_dv2.float(), atol=3e-2, rtol=3e-2) + if use_h0: + assert got_dh0 is not None + torch.testing.assert_close(got_dh0, ref_dh0, atol=3e-2, rtol=3e-2) + else: + assert got_dh0 is None + + @pytest.mark.parametrize("T", [64, 128]) @pytest.mark.parametrize("V", [64, 128]) def test_bwd_dhu_no_gating(T, V): @@ -101,9 +263,79 @@ def test_bwd_dhu_with_gk_exp2_and_dht(): _run_case(B=1, T=128, H=2, K=64, V=64, use_gk=True, use_dht=True, use_exp2=True) +def test_bwd_dhu_with_scalar_g_exp2_and_dht(): + _run_case(B=1, T=128, H=2, K=64, V=64, use_g=True, use_dht=True, use_exp2=True) + + +def test_bwd_dhu_with_scalar_g_and_gk(): + _run_case(B=1, T=128, H=1, K=128, V=64, use_g=True, use_gk=True, use_dht=True, use_exp2=True) + + def test_bwd_dhu_k128_with_gk_exp2_and_dht(): _run_case(B=1, T=128, H=1, K=128, V=64, use_gk=True, use_dht=True, use_exp2=True) def test_bwd_dhu_returns_dh0(): _run_case(B=2, T=128, H=1, K=64, V=64, use_h0=True) + + +@pytest.mark.parametrize( + "case", + [ + dict(B=1, T=256, H=4, K=64, V=64, use_gk=True, use_dht=True, use_exp2=True), + dict(B=1, T=256, H=4, K=64, V=64, use_g=True, use_dht=True, use_exp2=True), + dict(B=1, T=256, H=2, K=128, V=64, use_g=True, use_gk=True, use_dht=True, use_exp2=True), + dict(B=2, T=256, H=2, K=128, V=64, use_gk=True, use_dht=True, use_h0=True, use_exp2=True), + dict(B=1, T=512, H=4, K=128, V=128, use_gk=True, use_dht=True, use_exp2=True), + dict(B=1, T=128, H=2, K=256, V=64, use_gk=True, use_dht=True, use_exp2=True), + ], + ids=[ + "k64-v64-multihead-gk-dht", + "k64-v64-multihead-g-dht", + "k128-v64-g-and-gk", + "k128-v64-batch-h0", + "k128-v128-long", + "k256-v64", + ], +) +def test_bwd_dhu_forward_aligned_cases(case): + _run_case(**case, seed=123) + + +def test_bwd_dhu_transpose_state_layout(): + _run_case( + B=1, + T=128, + H=2, + K=128, + V=64, + use_gk=True, + use_dht=True, + use_h0=True, + use_exp2=True, + transpose_state_layout=True, + ) + + 
+@pytest.mark.parametrize( + "case", + [ + dict(seq_lens=[64, 128], H=1, K=64, V=64), + dict(seq_lens=[50, 192, 100], H=2, K=64, V=64, use_gk=True, use_dht=True, use_exp2=True), + dict(seq_lens=[33, 128, 200], H=1, K=128, V=64, use_g=True, use_dht=True, use_h0=True, use_exp2=True), + dict( + seq_lens=[96, 129], + H=1, + K=128, + V=64, + use_gk=True, + use_dht=True, + use_h0=True, + use_exp2=True, + transpose_state_layout=True, + ), + ], + ids=["basic", "gk-dht", "g-h0", "transpose-gk-h0"], +) +def test_bwd_dhu_varlen_cases(case): + _run_varlen_case(**case, seed=321) From 70a92846f1c5597d4c58fabddf8d972bb5e7bc07 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Wed, 13 May 2026 11:08:05 +0800 Subject: [PATCH 04/28] format --- benchmarks/bench_chunk_delta_h_bwd_sm90.py | 669 ++++++++++++++------- cula/ops/chunk_delta_h_bwd.py | 571 +++++++++++------- tests/test_chunk_delta_h_bwd_sm90.py | 411 ++++++------- 3 files changed, 982 insertions(+), 669 deletions(-) diff --git a/benchmarks/bench_chunk_delta_h_bwd_sm90.py b/benchmarks/bench_chunk_delta_h_bwd_sm90.py index 095dc49..4edebbc 100644 --- a/benchmarks/bench_chunk_delta_h_bwd_sm90.py +++ b/benchmarks/bench_chunk_delta_h_bwd_sm90.py @@ -3,117 +3,325 @@ # SPDX-License-Identifier: Apache-2.0 """ -Benchmark the SM90 CuTe DSL WGMMA bwd_dhu path against FLA Triton. +bench_chunk_delta_h_bwd_sm90.py - Benchmark: SM90 CuTe DSL bwd_dhu kernel + vs FLA Triton baseline. 
-Current kernel scope: - - non-varlen - - K in {64, 128, 256}, BT=64, BV=64 - - state layout [B, NT, H, K, V] or [B, NT, H, V, K] - - optional gk/dht/h0 +This mirrors benchmarks/bench_chunk_delta_h.py as closely as the backward API +allows: + - non-varlen and varlen modes + - K=128, V=128, BT=64, dtype=bf16 + - same default B/T/H and varlen sequence-count ranges as fwd + - dht/dh0 map to fwd initial_state/output_final_state -Example: - python benchmarks/bench_chunk_delta_h_bwd_sm90.py --B 1 --T 1024 --H 8 --K 128 --V 64 --gk --dht +Usage: + python benchmarks/bench_chunk_delta_h_bwd_sm90.py --mode both + python benchmarks/bench_chunk_delta_h_bwd_sm90.py --preset fwd --mode non-varlen """ import argparse import math +import os import pathlib import sys +os.environ.setdefault("CUDA_HOME", "/usr/local/cuda") +os.environ.setdefault("FLA_USE_FAST_OPS", os.getenv("CULA_USE_FAST_MATH", "1")) + sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) +import numpy as np import torch from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu as fla_bwd_dhu +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets from cula.ops.chunk_delta_h_bwd import chunk_gated_delta_rule_bwd_dhu_sm90 +if hasattr(sys.stdout, "reconfigure"): + sys.stdout.reconfigure(line_buffering=True) + + +K, V, BT = 128, 128, 64 +dtype = torch.bfloat16 +device = "cuda" -def time_kernel(fn, warmup: int, iters: int) -> float: +WARMUP = 5 +N_ITERS = 30 +NCU_MODE = False + + +def time_kernel(fn, warmup=None, n_iters=None): + if warmup is None: + warmup = 1 if NCU_MODE else WARMUP + if n_iters is None: + n_iters = 1 if NCU_MODE else N_ITERS for _ in range(warmup): fn() torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(iters): + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() + for _ in 
range(n_iters): fn() - end.record() + end_evt.record() torch.cuda.synchronize() - return start.elapsed_time(end) / iters - - -def make_inputs(args): - torch.manual_seed(args.seed) - seq_lens = getattr(args, "seq_lens", None) - is_varlen = seq_lens is not None - B = 1 if is_varlen else args.B - T = sum(seq_lens) if is_varlen else args.T - N = len(seq_lens) if is_varlen else B - H, K, V = args.H, args.K, args.V - q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 - k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 - w = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 - do = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 - dv = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 - cu_seqlens = None - if is_varlen: - cu = [0] - for seq_len in seq_lens: - cu.append(cu[-1] + seq_len) - cu_seqlens = torch.tensor(cu, dtype=torch.int32, device="cuda") + return start_evt.elapsed_time(end_evt) / n_iters + + +def accuracy_stats(ref, out): + diff = (ref.float() - out.float()).abs() + return diff.max().item(), diff.mean().item() + + +def bwd_accuracy_stats(ref_result, cute_result): + ref_dh, ref_dh0, ref_dv2 = ref_result + got_dh, got_dh0, got_dv2 = cute_result + dh_max, dh_mean = accuracy_stats(ref_dh, got_dh) + dv2_max, dv2_mean = accuracy_stats(ref_dv2, got_dv2) + dh0_max, dh0_mean = 0.0, 0.0 + if ref_dh0 is not None: + dh0_max, dh0_mean = accuracy_stats(ref_dh0, got_dh0) + return { + "dh_max": dh_max, + "dh_mean": dh_mean, + "dh0_max": dh0_max, + "dh0_mean": dh0_mean, + "dv2_max": dv2_max, + "dv2_mean": dv2_mean, + "max_diff": max(dh_max, dh0_max, dv2_max), + "mean_diff": max(dh_mean, dh0_mean, dv2_mean), + } + + +def make_non_varlen_inputs(B, T, H, use_g, use_gk, use_dht, use_dh0, transpose_state=False, seed=42): + torch.manual_seed(seed) + torch.cuda.empty_cache() + + q = torch.randn(B, T, H, K, device=device, dtype=dtype) * 0.1 + k = torch.randn(B, T, H, K, device=device, 
dtype=dtype) * 0.1 + w = torch.randn(B, T, H, K, device=device, dtype=dtype) * 0.1 + do = torch.randn(B, T, H, V, device=device, dtype=dtype) * 0.1 + dv = torch.randn(B, T, H, V, device=device, dtype=dtype) * 0.1 + g = None - if args.g: - if is_varlen: - g = torch.empty(B, T, H, dtype=torch.float32, device="cuda") - for i in range(N): - bos, eos = cu_seqlens[i].item(), cu_seqlens[i + 1].item() - seg = torch.randn(B, eos - bos, H, dtype=torch.float32, device="cuda") * 0.01 - g[:, bos:eos] = -torch.abs(seg).cumsum(dim=1) - else: - g = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device="cuda") * 0.01).cumsum(dim=1) + if use_g: + g = -torch.abs(torch.randn(B, T, H, device=device, dtype=torch.float32) * 0.01).cumsum(dim=1) + gk = None - if args.gk: - if is_varlen: - gk = torch.empty(B, T, H, K, dtype=torch.float32, device="cuda") - for i in range(N): - bos, eos = cu_seqlens[i].item(), cu_seqlens[i + 1].item() - seg = torch.randn(B, eos - bos, H, K, dtype=torch.float32, device="cuda") * 0.01 - gk[:, bos:eos] = -torch.abs(seg).cumsum(dim=1) - else: - gk = -torch.abs(torch.randn(B, T, H, K, dtype=torch.float32, device="cuda") * 0.01).cumsum(dim=1) - state_shape = (N, H, V, K) if args.transpose_state else (N, H, K, V) - dht = None - if args.dht: - dht = torch.randn(state_shape, dtype=torch.float32, device="cuda") * 0.01 - h0 = torch.empty(state_shape, dtype=torch.float32, device="cuda") if args.h0 else None - return q, k, w, do, dv, g, gk, dht, h0, cu_seqlens - - -def run_one(args): - q, k, w, do, dv, g, gk, dht, h0, cu_seqlens = make_inputs(args) - scale = args.K**-0.5 - is_varlen = cu_seqlens is not None - - def run_fla(): - return fla_bwd_dhu( - q=q, - k=k, - w=w, - do=do, - dv=dv, - g=g, - gk=gk, - h0=h0, - dht=dht, - scale=scale, - cu_seqlens=cu_seqlens.long() if cu_seqlens is not None else None, - chunk_size=64, - use_exp2=args.use_exp2, - transpose_state_layout=args.transpose_state, + if use_gk: + gk = -torch.abs(torch.randn(B, T, H, K, device=device, 
dtype=torch.float32) * 0.01).cumsum(dim=1) + + state_shape = (B, H, V, K) if transpose_state else (B, H, K, V) + dht = torch.randn(state_shape, device=device, dtype=torch.float32) * 0.01 if use_dht else None + dh0 = torch.empty(state_shape, device=device, dtype=torch.float32) if use_dh0 else None + return q, k, w, do, dv, g, gk, dht, dh0 + + +def generate_seq_lens(num_seqs, total_T, ratio, seed=42): + rng = np.random.RandomState(seed) + log_weights = rng.uniform(0, np.log(ratio), num_seqs) + weights = np.exp(log_weights) + raw_lens = weights / weights.sum() * total_T + seq_lens = np.maximum(np.round(raw_lens).astype(int), 1) + diff = total_T - seq_lens.sum() + if diff > 0: + indices = np.argsort(seq_lens) + for i in range(abs(diff)): + seq_lens[indices[i % num_seqs]] += 1 + elif diff < 0: + indices = np.argsort(-seq_lens) + for i in range(abs(diff)): + seq_lens[indices[i % num_seqs]] -= 1 + assert seq_lens.sum() == total_T + return list(seq_lens) + + +def make_varlen_inputs(num_seqs, total_T, H, ratio, use_g, use_gk, use_dht, use_dh0, seed=42): + seq_lens = generate_seq_lens(num_seqs, total_T, ratio, seed=seed) + cu_seqlens_list = [0] + for seq_len in seq_lens: + cu_seqlens_list.append(cu_seqlens_list[-1] + seq_len) + cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device) + cu_seqlens_long = cu_seqlens.long() + + chunk_indices = prepare_chunk_indices(cu_seqlens_long, BT) + chunk_offsets = prepare_chunk_offsets(cu_seqlens_long, BT).int() + + torch.manual_seed(seed) + torch.cuda.empty_cache() + + q = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1 + k = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1 + w = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1 + do = torch.randn(1, total_T, H, V, device=device, dtype=dtype) * 0.1 + dv = torch.randn(1, total_T, H, V, device=device, dtype=dtype) * 0.1 + + g = None + if use_g: + g_raw = torch.randn(1, total_T, H, device=device, dtype=torch.float32) * 0.01 + 
g = torch.zeros_like(g_raw) + for i in range(num_seqs): + bos = cu_seqlens[i].item() + eos = cu_seqlens[i + 1].item() + g[:, bos:eos] = -torch.abs(g_raw[:, bos:eos]).cumsum(dim=1) + + gk = None + if use_gk: + gk_raw = torch.randn(1, total_T, H, K, device=device, dtype=torch.float32) * 0.01 + gk = torch.zeros_like(gk_raw) + for i in range(num_seqs): + bos = cu_seqlens[i].item() + eos = cu_seqlens[i + 1].item() + gk[:, bos:eos] = -torch.abs(gk_raw[:, bos:eos]).cumsum(dim=1) + + state_shape = (num_seqs, H, K, V) + dht = torch.randn(state_shape, device=device, dtype=torch.float32) * 0.01 if use_dht else None + dh0 = torch.empty(state_shape, device=device, dtype=torch.float32) if use_dh0 else None + return seq_lens, cu_seqlens, cu_seqlens_long, chunk_indices, chunk_offsets, q, k, w, do, dv, g, gk, dht, dh0 + + +def run_fla(q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens_long=None): + return fla_bwd_dhu( + q=q, + k=k, + w=w, + do=do, + dv=dv, + g=g, + gk=gk, + h0=dh0, + dht=dht, + scale=K**-0.5, + cu_seqlens=cu_seqlens_long, + chunk_size=BT, + use_exp2=True, + ) + + +def run_cute(q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens=None, chunk_indices=None, chunk_offsets=None): + return chunk_gated_delta_rule_bwd_dhu_sm90( + q=q, + k=k, + w=w, + do=do, + dv=dv, + g=g, + gk=gk, + h0=dh0, + dht=dht, + scale=K**-0.5, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + chunk_size=BT, + use_exp2=True, + ) + + +def flags_str(use_g, use_gk, use_dht, use_dh0): + flags = [] + if use_g: + flags.append("g") + if use_gk: + flags.append("gk") + if use_dht: + flags.append("dht") + if use_dh0: + flags.append("dh0") + return f" [{','.join(flags)}]" if flags else "" + + +def bench_non_varlen(configs): + print("\n" + "=" * 80) + print(" Non-Varlen Benchmark: CuTe DSL (SM90) bwd_dhu vs FLA Triton") + print("=" * 80) + results = [] + + for B, T, H, use_g, use_gk, use_dht, use_dh0 in configs: + q, k, w, do, dv, g, gk, dht, dh0 = make_non_varlen_inputs(B, T, H, 
use_g, use_gk, use_dht, use_dh0) + + ref = run_fla(q, k, w, do, dv, g, gk, dht, dh0) + got = run_cute(q, k, w, do, dv, g, gk, dht, dh0) + torch.cuda.synchronize() + acc = bwd_accuracy_stats(ref, got) + + def run_fla_case(q=q, k=k, w=w, do=do, dv=dv, g=g, gk=gk, dht=dht, dh0=dh0): + run_fla(q, k, w, do, dv, g, gk, dht, dh0) + + def run_cute_case(q=q, k=k, w=w, do=do, dv=dv, g=g, gk=gk, dht=dht, dh0=dh0): + run_cute(q, k, w, do, dv, g, gk, dht, dh0) + + ms_fla = time_kernel(run_fla_case) + ms_cute = time_kernel(run_cute_case) + speedup = ms_fla / ms_cute if ms_cute > 0 else float("inf") + flag_str = flags_str(use_g, use_gk, use_dht, use_dh0) + + r = { + "B": B, + "T": T, + "H": H, + "flags": flag_str, + "ms_fla": ms_fla, + "ms_cute": ms_cute, + "speedup": speedup, + **acc, + } + results.append(r) + print( + f" B={B:2d} T={T:5d} H={H:3d}{flag_str:<18s} | " + f"max={acc['max_diff']:.6f} mean={acc['mean_diff']:.8f} " + f"(dh={acc['dh_max']:.6f} dh0={acc['dh0_max']:.6f} dv2={acc['dv2_max']:.6f}) | " + f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | speedup={speedup:.2f}x" + ) + + return results + + +def bench_varlen(configs): + print("\n" + "=" * 80) + print(" Varlen Benchmark: CuTe DSL (SM90) bwd_dhu vs FLA Triton") + print("=" * 80) + results = [] + + for num_seqs, total_T, H, ratio, use_g, use_gk, use_dht, use_dh0 in configs: + ( + seq_lens, + cu_seqlens, + cu_seqlens_long, + chunk_indices, + chunk_offsets, + q, + k, + w, + do, + dv, + g, + gk, + dht, + dh0, + ) = make_varlen_inputs(num_seqs, total_T, H, ratio, use_g, use_gk, use_dht, use_dh0) + + ref = run_fla(q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens_long=cu_seqlens_long) + got = run_cute( + q, + k, + w, + do, + dv, + g, + gk, + dht, + dh0, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, ) + torch.cuda.synchronize() + acc = bwd_accuracy_stats(ref, got) + + def run_fla_case(q=q, k=k, w=w, do=do, dv=dv, g=g, gk=gk, dht=dht, dh0=dh0, cu=cu_seqlens_long): + run_fla(q, k, w, do, dv, 
g, gk, dht, dh0, cu_seqlens_long=cu) - def run_cute(): - return chunk_gated_delta_rule_bwd_dhu_sm90( + def run_cute_case( q=q, k=k, w=w, @@ -121,166 +329,165 @@ def run_cute(): dv=dv, g=g, gk=gk, - h0=h0, dht=dht, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=64, - use_exp2=args.use_exp2, - transpose_state_layout=args.transpose_state, - ) + dh0=dh0, + cu=cu_seqlens, + ci=chunk_indices, + co=chunk_offsets, + ): + run_cute(q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens=cu, chunk_indices=ci, chunk_offsets=co) - ref = run_fla() - got = run_cute() - torch.cuda.synchronize() - max_dh = (ref[0].float() - got[0].float()).abs().max().item() - max_dv = (ref[2].float() - got[2].float()).abs().max().item() - max_dh0 = None - if ref[1] is not None: - max_dh0 = (ref[1].float() - got[1].float()).abs().max().item() - - fla_ms = time_kernel(run_fla, args.warmup, args.iters) - cute_ms = time_kernel(run_cute, args.warmup, args.iters) - - shape_tag = f"seq_lens={args.seq_lens}" if is_varlen else f"B={args.B} T={args.T}" - print( - f"bwd_dhu SM90 {shape_tag} H={args.H} K={args.K} V={args.V} " - f"g={args.g} gk={args.gk} dht={args.dht} h0={args.h0} exp2={args.use_exp2} transpose={args.transpose_state}" - ) - if max_dh0 is None: - print(f"max_diff dh={max_dh:.6f} dv2={max_dv:.6f}") - else: - print(f"max_diff dh={max_dh:.6f} dh0={max_dh0:.6f} dv2={max_dv:.6f}") - print(f"FLA Triton: {fla_ms:.4f} ms") - print(f"CuTe DSL : {cute_ms:.4f} ms") - print(f"speedup : {fla_ms / cute_ms:.3f}x") - return { - "B": args.B, - "T": args.T, - "seq_lens": args.seq_lens, - "H": args.H, - "K": args.K, - "V": args.V, - "g": args.g, - "gk": args.gk, - "dht": args.dht, - "h0": args.h0, - "exp2": args.use_exp2, - "transpose": args.transpose_state, - "max_dh": max_dh, - "max_dh0": max_dh0, - "max_dv": max_dv, - "fla_ms": fla_ms, - "cute_ms": cute_ms, - "speedup": fla_ms / cute_ms, - } + ms_fla = time_kernel(run_fla_case) + ms_cute = time_kernel(run_cute_case) + speedup = ms_fla / ms_cute if ms_cute > 0 else 
float("inf") + min_l, max_l = min(seq_lens), max(seq_lens) + avg_l = total_T // num_seqs + tag = f"{num_seqs}seqs T={total_T} [{min_l}..{max_l}] avg={avg_l}" + flag_str = flags_str(use_g, use_gk, use_dht, use_dh0) -def suite_configs(kind: str): - quick = [ - dict(B=1, T=512, H=4, K=64, V=64, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), - dict(B=1, T=512, H=4, K=128, V=64, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), - dict(B=1, T=512, H=4, K=128, V=128, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), - dict(B=1, T=512, H=2, K=256, V=64, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), - ] - if kind == "quick": - return quick - return quick + [ - dict( - seq_lens=[50, 192, 100], - H=2, - K=64, - V=64, - g=False, - gk=True, - dht=True, - h0=False, - use_exp2=True, - transpose_state=False, - ), - dict( - seq_lens=[33, 128, 200], - H=1, - K=128, - V=64, - g=True, - gk=False, - dht=True, - h0=True, - use_exp2=True, - transpose_state=False, - ), - dict(B=1, T=512, H=4, K=64, V=64, g=True, gk=False, dht=True, h0=False, use_exp2=True, transpose_state=False), - dict(B=1, T=512, H=2, K=128, V=64, g=True, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), - dict(B=2, T=1024, H=4, K=128, V=64, g=False, gk=True, dht=True, h0=True, use_exp2=True, transpose_state=False), - dict(B=1, T=2048, H=8, K=128, V=64, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), - dict(B=1, T=1024, H=8, K=64, V=128, g=False, gk=True, dht=True, h0=False, use_exp2=True, transpose_state=False), - dict(B=1, T=512, H=4, K=128, V=64, g=False, gk=True, dht=True, h0=True, use_exp2=True, transpose_state=True), - ] - - -def _fmt_optional(value): - return "n/a" if value is None else f"{value:.6f}" - - -def print_suite(results): - print("\n" + "=" * 118) - print(" bwd_dhu SM90 Suite: CuTe DSL vs FLA Triton") - print("=" * 118) - for r in results: - 
flags = ",".join(name for name in ("g", "gk", "dht", "h0", "exp2", "transpose") if r[name]) - shape = f"seqs={r['seq_lens']!s:<17s}" if r["seq_lens"] is not None else f"B={r['B']:2d} T={r['T']:5d}" + r = { + "tag": tag, + "T_total": total_T, + "H": H, + "n_seqs": num_seqs, + "flags": flag_str, + "ms_fla": ms_fla, + "ms_cute": ms_cute, + "speedup": speedup, + **acc, + } + results.append(r) print( - f" {shape} H={r['H']:2d} K={r['K']:3d} V={r['V']:3d} [{flags:<16s}] | " - f"diff dh={r['max_dh']:.6f} dh0={_fmt_optional(r['max_dh0'])} dv2={r['max_dv']:.6f} | " - f"FLA={r['fla_ms']:.4f}ms CuTe={r['cute_ms']:.4f}ms speedup={r['speedup']:.3f}x" + f" {tag:40s} H={H:3d}{flag_str:<18s} | " + f"max={acc['max_diff']:.6f} mean={acc['mean_diff']:.8f} " + f"(dh={acc['dh_max']:.6f} dh0={acc['dh0_max']:.6f} dv2={acc['dv2_max']:.6f}) | " + f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | speedup={speedup:.2f}x" ) - speedups = [r["speedup"] for r in results if r["speedup"] > 0] - geo = math.exp(sum(math.log(s) for s in speedups) / len(speedups)) - print("-" * 118) - print(f" Geometric mean speedup: {geo:.3f}x") + + return results + + +def print_report(nv_results, vl_results): + sep = "=" * 120 + print(f"\n\n{sep}") + print(" BENCHMARK REPORT: chunk_delta_rule_bwd_dhu") + print(" CuTe DSL (Hopper SM90) vs FLA Triton") + print(f" K={K} V={V} BT={BT} dtype=bf16") + wu = 1 if NCU_MODE else WARMUP + ni = 1 if NCU_MODE else N_ITERS + ncu_tag = " [NCU mode]" if NCU_MODE else "" + print(f" Warmup={wu} Iters={ni}{ncu_tag}") + print(sep) + + if nv_results: + print("\n [Non-Varlen]") + print(f" {'-' * 112}") + print(f" {'Config':<37s} | {'max_diff':>10s} {'mean_diff':>12s} | {'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s}") + print(f" {'-' * 112}") + for r in nv_results: + label = f"B={r['B']:2d} T={r['T']:5d} H={r['H']:3d}{r['flags']}" + print( + f" {label:<37s} | {r['max_diff']:10.6f} {r['mean_diff']:12.8f} | " + f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x" + ) + print(f" {'-' * 
112}") + speedups = [r["speedup"] for r in nv_results] + geo = math.exp(sum(math.log(s) for s in speedups) / len(speedups)) + print(f" {'Geometric mean':<37s} | {'':>10s} {'':>12s} | {'':>9s} {'':>9s} {geo:7.2f}x") + + if vl_results: + print("\n [Varlen]") + print(f" {'-' * 120}") + print(f" {'Config':>60s} | {'max_diff':>10s} {'mean_diff':>12s} | {'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s}") + print(f" {'-' * 120}") + for r in vl_results: + label = f"{r['tag']} H={r['H']:3d}{r['flags']}" + print( + f" {label:>60s} | {r['max_diff']:10.6f} {r['mean_diff']:12.8f} | " + f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x" + ) + print(f" {'-' * 120}") + speedups = [r["speedup"] for r in vl_results] + geo = math.exp(sum(math.log(s) for s in speedups) / len(speedups)) + print(f" {'Geometric mean':>60s} | {'':>10s} {'':>12s} | {'':>9s} {'':>9s} {geo:7.2f}x") + + print(f"\n{sep}\n") def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--suite", choices=["none", "quick", "full"], default="none") - parser.add_argument("--B", type=int, default=1) - parser.add_argument("--T", type=int, default=1024) - parser.add_argument("--seq-lens", type=str, default=None) - parser.add_argument("--H", type=int, default=8) - parser.add_argument("--K", type=int, default=128, choices=[64, 128, 256]) - parser.add_argument("--V", type=int, default=64) - parser.add_argument("--g", action="store_true") - parser.add_argument("--gk", action="store_true") - parser.add_argument("--dht", action="store_true") - parser.add_argument("--h0", action="store_true") - parser.add_argument("--use-exp2", action="store_true") - parser.add_argument("--transpose-state", action="store_true") - parser.add_argument("--warmup", type=int, default=10) - parser.add_argument("--iters", type=int, default=100) - parser.add_argument("--seed", type=int, default=42) + parser = argparse.ArgumentParser(description="bench_chunk_delta_h_bwd_sm90: CuTe DSL (SM90) vs FLA Triton") + parser.add_argument( 
+ "--mode", + type=str, + default="both", + choices=["non-varlen", "varlen", "both"], + help="Which benchmark mode to run (default: both)", + ) + parser.add_argument( + "--preset", + type=str, + default="representative", + choices=["representative", "fwd"], + help="representative runs a short subset; fwd mirrors bench_chunk_delta_h.py's large default configs", + ) + parser.add_argument("--warmup", type=int, default=None, help="Override warmup iterations") + parser.add_argument("--iters", type=int, default=None, help="Override timed iterations") + parser.add_argument("--ncu", action="store_true", help="NCU profiling mode: warmup=1, iters=1") args = parser.parse_args() - if args.seq_lens is not None: - args.seq_lens = [int(x) for x in args.seq_lens.split(",") if x] if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] != 9: raise RuntimeError("This benchmark requires an SM90/Hopper GPU.") - if args.suite == "none": - if args.seq_lens is None and args.T % 64 != 0: - raise ValueError("Use T as a multiple of 64 for this benchmark.") - if args.V % 64 != 0: - raise ValueError("Use V as a multiple of 64 for the SM90 WGMMA path.") - run_one(args) - return - results = [] - for cfg in suite_configs(args.suite): - case_args = argparse.Namespace(**vars(args)) - case_args.seq_lens = None - for key, value in cfg.items(): - setattr(case_args, key, value) - results.append(run_one(case_args)) - print_suite(results) + global NCU_MODE, WARMUP, N_ITERS + if args.ncu: + NCU_MODE = True + print("[NCU mode] warmup=1, iters=1") + if args.warmup is not None: + WARMUP = args.warmup + if args.iters is not None: + N_ITERS = args.iters + + if args.preset == "fwd": + # Matches bench_chunk_delta_h.py's default dimensions. 
+ # Tuple: (B, T, H, use_g, use_gk, use_dht, use_dh0) + non_varlen_configs = [ + (1, 8192, 64, False, True, True, True), + (2, 8192, 64, False, True, True, True), + (4, 8192, 64, False, True, True, True), + (8, 8192, 64, False, True, True, True), + ] + + # Tuple: (num_seqs, total_T, H, ratio, use_g, use_gk, use_dht, use_dh0) + varlen_configs = [ + (20, 8192, 64, 2.0, False, True, True, True), + (25, 8192, 64, 3.0, False, True, True, True), + (20, 8192, 64, 4.0, False, True, True, True), + (20, 32768, 64, 2.0, False, True, True, True), + (25, 32768, 64, 3.0, False, True, True, True), + ] + else: + # Short representative subset for day-to-day iteration. + # Tuple: (B, T, H, use_g, use_gk, use_dht, use_dh0) + non_varlen_configs = [ + (1, 512, 4, False, True, True, False), + (1, 512, 4, True, False, True, False), + (2, 1024, 64, False, True, True, True), + (1, 2048, 64, False, True, True, False), + ] + + # Tuple: (num_seqs, total_T, H, ratio, use_g, use_gk, use_dht, use_dh0) + varlen_configs = [ + (3, 512, 2, 3.0, False, True, True, False), + (4, 768, 2, 4.0, True, False, True, True), + ] + + nv_res, vl_res = [], [] + if args.mode in ("non-varlen", "both"): + nv_res = bench_non_varlen(non_varlen_configs) + if args.mode in ("varlen", "both"): + vl_res = bench_varlen(varlen_configs) + print_report(nv_res, vl_res) if __name__ == "__main__": diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 92f6522..412f3b9 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -15,10 +15,10 @@ """ SM90 CuTe DSL implementation for chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64. 
-This is the first Hopper tensor-core path: +This Hopper tensor-core path is scoped to match cula/ops/chunk_delta_h.py: - fixed chunk size BT=64 - - BV=64, matching cula/ops/chunk_delta_h.py - - non-varlen tensors [B, T, H, D] + - K=V=128, BV=64 + - non-varlen tensors [B, T, H, D] and packed varlen tensors - state layout [B, NT, H, K, V] or [B, NT, H, V, K] - non-persistent scheduling @@ -52,6 +52,7 @@ BT = 64 BV = 64 +BK = 128 NUM_THREADS = 128 @@ -79,8 +80,9 @@ def __init__( scale: float, use_fast_math: bool = True, ): - assert head_dim_k in (64, 128, 256), f"SM90 bwd_dhu supports K in {{64, 128, 256}}, got {head_dim_k}" - assert head_dim_v % BV == 0, f"SM90 bwd_dhu tensor-core path requires V to be a multiple of {BV}, got {head_dim_v}" + assert head_dim_k == 128 and head_dim_v == 128, ( + f"SM90 bwd_dhu currently aligns with ChunkDeltaRuleFwdH and requires K=V=128, got K={head_dim_k}, V={head_dim_v}" + ) self.B = batch_size self.T = seq_len self.N = num_sequences @@ -100,16 +102,24 @@ def __init__( self.BT = BT self.BV = BV + self.BK = head_dim_k + self.num_k_blocks = head_dim_k // self.BK self.num_v_tiles = (head_dim_v + BV - 1) // BV self.num_threads = NUM_THREADS + self.num_regs_compute = 232 + self.input_stage = 2 self.io_dtype = cutlass.BFloat16 self.acc_dtype = cutlass.Float32 self.buffer_align_bytes = 1024 - self.mma_tiler = (BT, BV, head_dim_k) - self.update_mma_tiler = (BV, head_dim_k, BT) + self.mma_tiler = (BT, BV, self.BK) + self.update_mma_tiler = (BV, self.BK, BT) self.atom_layout_mnk = (1, 1, 1) self.cluster_shape_mnk = (1, 1, 1) + self.gk_precompute_bar = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.num_threads, + ) @cute.jit def __call__( @@ -222,6 +232,7 @@ def __call__( ) q_kt = cute.make_tensor(q_ptr, kt_layout) w_kt = cute.make_tensor(w_ptr, kt_layout) + gk_kt = cute.make_tensor(gk_ptr, kt_layout) vt_layout = cute.make_layout( (self.V, self.T, (self.H, self.B)), stride=(1, self.H * self.V, (self.V, self.T * self.H * self.V)) @@ 
-254,7 +265,7 @@ def __call__( utils.LayoutEnum.ROW_MAJOR, self.mma_tiler, self.io_dtype, - 1, + self.input_stage, ) b_smem_layout_staged = sm90_utils.make_smem_layout_b( utils.LayoutEnum.ROW_MAJOR, @@ -266,15 +277,22 @@ def __call__( utils.LayoutEnum.COL_MAJOR, self.update_mma_tiler, self.io_dtype, - 1, + self.input_stage, ) update_b_smem_layout_staged = sm90_utils.make_smem_layout_b( utils.LayoutEnum.COL_MAJOR, self.update_mma_tiler, self.io_dtype, - 1, + self.input_stage, + ) + dv_smem_layout_staged = cute.make_layout( + (self.BV, self.BT, self.input_stage), + stride=(1, self.BV, self.BV * self.BT), + ) + gk_smem_layout_staged = cute.make_layout( + (self.BK, 1, self.input_stage), + stride=(1, self.BK, self.BK), ) - dv_smem_layout_staged = cute.make_layout((self.BV, self.BT, 1), stride=(1, self.BV, self.BV * self.BT)) tma_load_op = cpasync.CopyBulkTensorTileG2SOp() tma_store_op = cpasync.CopyBulkTensorTileS2GOp() @@ -283,7 +301,7 @@ def __call__( tma_load_op, k_tk, cute.slice_(a_smem_layout_staged, (None, None, 0)), - (self.BT, self.K), + (self.BT, self.BK), ) tma_atom_dv, tma_tensor_dv = cpasync.make_tiled_tma_atom( tma_load_op, @@ -301,13 +319,19 @@ def __call__( tma_load_op, q_kt, cute.slice_(update_b_smem_layout_staged, (None, None, 0)), - (self.K, self.BT), + (self.BK, self.BT), ) tma_atom_w, tma_tensor_w = cpasync.make_tiled_tma_atom( tma_load_op, w_kt, cute.slice_(update_b_smem_layout_staged, (None, None, 0)), - (self.K, self.BT), + (self.BK, self.BT), + ) + tma_atom_gk, tma_tensor_gk = cpasync.make_tiled_tma_atom( + tma_load_op, + gk_kt, + cute.slice_(gk_smem_layout_staged, (None, None, 0)), + (self.BK, 1), ) tma_atom_dv2, tma_tensor_dv2 = cpasync.make_tiled_tma_atom( tma_store_op, @@ -315,19 +339,21 @@ def __call__( cute.slice_(dv_smem_layout_staged, (None, None, 0)), (self.BV, self.BT), ) - self.tma_kdv_bytes = cute.size_in_bytes( - self.io_dtype, cute.slice_(a_smem_layout_staged, (None, None, 0)) - ) + cute.size_in_bytes(self.io_dtype, 
cute.slice_(dv_smem_layout_staged, (None, None, 0))) - self.tma_qdo_bytes = cute.size_in_bytes( - self.io_dtype, cute.slice_(update_a_smem_layout_staged, (None, None, 0)) - ) + cute.size_in_bytes(self.io_dtype, cute.slice_(update_b_smem_layout_staged, (None, None, 0))) + self.tma_k_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(a_smem_layout_staged, (None, None, 0))) + self.tma_dv_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(dv_smem_layout_staged, (None, None, 0))) + self.tma_do_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_a_smem_layout_staged, (None, None, 0))) + self.tma_q_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_b_smem_layout_staged, (None, None, 0))) self.tma_w_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_b_smem_layout_staged, (None, None, 0))) + self.tma_gk_bytes = cute.size_in_bytes(cutlass.Float32, cute.slice_(gk_smem_layout_staged, (None, None, 0))) @cute.struct class SharedStorage: - load_kdv_mbar: cute.struct.MemRange[Int64, 2] - load_qdo_mbar: cute.struct.MemRange[Int64, 2] - load_w_mbar: cute.struct.MemRange[Int64, 2] + load_k_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] + load_dv_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] + load_do_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] + load_q_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] + load_w_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] + load_gk_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] sA: cute.struct.Align[ cute.struct.MemRange[self.io_dtype, cute.cosize(a_smem_layout_staged)], self.buffer_align_bytes, @@ -340,10 +366,22 @@ class SharedStorage: cute.struct.MemRange[self.io_dtype, cute.cosize(update_a_smem_layout_staged)], self.buffer_align_bytes, ] + sDo: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(update_a_smem_layout_staged)], + self.buffer_align_bytes, + ] + sGK: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, BK * 
self.input_stage], + 128, + ] sUB: cute.struct.Align[ cute.struct.MemRange[self.io_dtype, cute.cosize(update_b_smem_layout_staged)], self.buffer_align_bytes, ] + sW: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(update_b_smem_layout_staged)], + self.buffer_align_bytes, + ] sDv2T: cute.struct.Align[ cute.struct.MemRange[self.io_dtype, cute.cosize(dv_smem_layout_staged)], self.buffer_align_bytes, @@ -381,6 +419,8 @@ class SharedStorage: tma_tensor_q, tma_atom_w, tma_tensor_w, + tma_atom_gk, + tma_tensor_gk, tma_atom_dv2, tma_tensor_dv2, ).launch( @@ -423,6 +463,8 @@ def kernel( tma_tensor_q: cute.Tensor, tma_atom_w: cute.CopyAtom, tma_tensor_w: cute.Tensor, + tma_atom_gk: cute.CopyAtom, + tma_tensor_gk: cute.Tensor, tma_atom_dv2: cute.CopyAtom, tma_tensor_dv2: cute.Tensor, ): @@ -452,8 +494,11 @@ def kernel( sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) sUA = storage.sUA.get_tensor(update_a_smem_layout_staged.outer, swizzle=update_a_smem_layout_staged.inner) + sDo = storage.sDo.get_tensor(update_a_smem_layout_staged.outer, swizzle=update_a_smem_layout_staged.inner) + sGK = storage.sGK.get_tensor(cute.make_layout((BK, 1, self.input_stage), stride=(1, BK, BK))) sUB = storage.sUB.get_tensor(update_b_smem_layout_staged.outer, swizzle=update_b_smem_layout_staged.inner) - sDv2T = storage.sDv2T.get_tensor(cute.make_layout((BV, BT, 1), stride=(1, BV, BV * BT))) + sW = storage.sW.get_tensor(update_b_smem_layout_staged.outer, swizzle=update_b_smem_layout_staged.inner) + sDv2T = storage.sDv2T.get_tensor(cute.make_layout((BV, BT, self.input_stage), stride=(1, BV, BV * BT))) if warp_idx == 0: cpasync.prefetch_descriptor(tma_atom_k) @@ -461,51 +506,51 @@ def kernel( cpasync.prefetch_descriptor(tma_atom_do) cpasync.prefetch_descriptor(tma_atom_q) cpasync.prefetch_descriptor(tma_atom_w) + 
cpasync.prefetch_descriptor(tma_atom_gk) cpasync.prefetch_descriptor(tma_atom_dv2) - load_kdv_P, load_kdv_C = pipeline.PipelineTmaAsync.create( - num_stages=1, + load_k_P, load_k_C = pipeline.PipelineTmaAsync.create( + num_stages=self.input_stage, producer_group=make_thread_cooperative_group(1), consumer_group=make_thread_cooperative_group(self.num_threads // 32), - tx_count=self.tma_kdv_bytes, - barrier_storage=storage.load_kdv_mbar.data_ptr(), + tx_count=self.tma_k_bytes, + barrier_storage=storage.load_k_mbar.data_ptr(), ).make_participants() - load_qdo_P, load_qdo_C = pipeline.PipelineTmaAsync.create( - num_stages=1, + load_dv_P, load_dv_C = pipeline.PipelineTmaAsync.create( + num_stages=self.input_stage, producer_group=make_thread_cooperative_group(1), consumer_group=make_thread_cooperative_group(self.num_threads // 32), - tx_count=self.tma_qdo_bytes, - barrier_storage=storage.load_qdo_mbar.data_ptr(), + tx_count=self.tma_dv_bytes, + barrier_storage=storage.load_dv_mbar.data_ptr(), + ).make_participants() + load_do_P, load_do_C = pipeline.PipelineTmaAsync.create( + num_stages=self.input_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_threads // 32), + tx_count=self.tma_do_bytes, + barrier_storage=storage.load_do_mbar.data_ptr(), + ).make_participants() + load_q_P, load_q_C = pipeline.PipelineTmaAsync.create( + num_stages=self.input_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_threads // 32), + tx_count=self.tma_q_bytes, + barrier_storage=storage.load_q_mbar.data_ptr(), ).make_participants() load_w_P, load_w_C = pipeline.PipelineTmaAsync.create( - num_stages=1, + num_stages=self.input_stage, producer_group=make_thread_cooperative_group(1), consumer_group=make_thread_cooperative_group(self.num_threads // 32), tx_count=self.tma_w_bytes, barrier_storage=storage.load_w_mbar.data_ptr(), ).make_participants() - - thr_mma = 
tiled_mma.get_slice(tidx) - update_thr_mma = update_tiled_mma.get_slice(tidx) - - tCsA = thr_mma.partition_A(sA) - tCsB = thr_mma.partition_B(sB) - tCrA = thr_mma.make_fragment_A(tCsA) - tCrB = thr_mma.make_fragment_B(tCsB) - tUsA = update_thr_mma.partition_A(sUA) - tUsB = update_thr_mma.partition_B(sUB) - tUrA = update_thr_mma.make_fragment_A(tUsA) - tUrB = update_thr_mma.make_fragment_B(tUsB) - - cDV = cute.make_identity_tensor((BT, BV)) - tCcDV = thr_mma.partition_C(cDV) - acc_dv = thr_mma.make_fragment_C(thr_mma.partition_shape_C((BT, BV))) - - cState = cute.make_identity_tensor((BV, self.K)) - tUcState = update_thr_mma.partition_C(cState) - rState = update_thr_mma.make_fragment_C(update_thr_mma.partition_shape_C((BV, self.K))) - acc_qdo = update_thr_mma.make_fragment_C(update_thr_mma.partition_shape_C((BV, self.K))) - acc_wdv = update_thr_mma.make_fragment_C(update_thr_mma.partition_shape_C((BV, self.K))) + load_gk_P, load_gk_C = pipeline.PipelineTmaAsync.create( + num_stages=self.input_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_threads // 32), + tx_count=self.tma_gk_bytes, + barrier_storage=storage.load_gk_mbar.data_ptr(), + ).make_participants() if cutlass.const_expr(self.is_varlen): tma_tensor_k_use = cute.domain_offset((seq_start, 0, (0, 0)), tma_tensor_k) @@ -513,6 +558,7 @@ def kernel( tma_tensor_do_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_do) tma_tensor_q_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_q) tma_tensor_w_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_w) + tma_tensor_gk_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_gk) tma_tensor_dv2_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_dv2) else: tma_tensor_k_use = tma_tensor_k @@ -520,41 +566,109 @@ def kernel( tma_tensor_do_use = tma_tensor_do tma_tensor_q_use = tma_tensor_q tma_tensor_w_use = tma_tensor_w + tma_tensor_gk_use = tma_tensor_gk 
tma_tensor_dv2_use = tma_tensor_dv2 _, bSG_sK, bSG_gK = self._epilog_partition( - tma_atom_k, tma_tensor_k_use[None, None, (i_h, data_b)], (self.BT, self.K), sA + tma_atom_k, tma_tensor_k_use[None, None, (i_h, data_b)], (self.BT, self.BK), sA ) _, bSG_sDv, bSG_gDv = self._epilog_partition( tma_atom_dv, tma_tensor_dv_use[None, None, (i_h, data_b)], (self.BV, self.BT), sDv2T ) _, bSG_sDo, bSG_gDo = self._epilog_partition( - tma_atom_do, tma_tensor_do_use[None, None, (i_h, data_b)], (self.BV, self.BT), sUA + tma_atom_do, tma_tensor_do_use[None, None, (i_h, data_b)], (self.BV, self.BT), sDo ) _, bSG_sQ, bSG_gQ = self._epilog_partition( - tma_atom_q, tma_tensor_q_use[None, None, (i_h, data_b)], (self.K, self.BT), sUB + tma_atom_q, tma_tensor_q_use[None, None, (i_h, data_b)], (self.BK, self.BT), sUB ) _, bSG_sW, bSG_gW = self._epilog_partition( - tma_atom_w, tma_tensor_w_use[None, None, (i_h, data_b)], (self.K, self.BT), sUB + tma_atom_w, tma_tensor_w_use[None, None, (i_h, data_b)], (self.BK, self.BT), sW + ) + _, bSG_sGK, bSG_gGK = self._epilog_partition( + tma_atom_gk, tma_tensor_gk_use[None, None, (i_h, data_b)], (self.BK, 1), sGK ) _, bSG_sDv2, bSG_gDv2 = self._epilog_partition( tma_atom_dv2, tma_tensor_dv2_use[None, None, (i_h, data_b)], (self.BV, self.BT), sDv2T ) - # Initialize carried dh state. 
- for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_idx = tUcState[ei] - v_idx = v_base + v_rel - init = Float32(0.0) - if cutlass.const_expr(self.use_dht): - if cutlass.const_expr(self.transpose_state_layout): - init = dht[i_n, i_h, v_idx, k_idx].to(self.acc_dtype) - else: - init = dht[i_n, i_h, k_idx, v_idx].to(self.acc_dtype) - rState[ei] = init + cute.arch.setmaxregister_increase(self.num_regs_compute) + + thr_mma = tiled_mma.get_slice(tidx) + update_thr_mma = update_tiled_mma.get_slice(tidx) + + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCrA = thr_mma.make_fragment_A(tCsA) + tCrB = thr_mma.make_fragment_B(tCsB) + tUsA = update_thr_mma.partition_A(sUA) + tUsDo = update_thr_mma.partition_A(sDo) + tUsB = update_thr_mma.partition_B(sUB) + tWsB = update_thr_mma.partition_B(sW) + tUrA = update_thr_mma.make_fragment_A(tUsA) + tUrDo = update_thr_mma.make_fragment_A(tUsDo) + tUrB = update_thr_mma.make_fragment_B(tUsB) + tWrB = update_thr_mma.make_fragment_B(tWsB) + + cDV = cute.make_identity_tensor((BT, BV)) + tCcDV = thr_mma.partition_C(cDV) + acc_dv = thr_mma.make_fragment_C(thr_mma.partition_shape_C((BT, BV))) + + cState = cute.make_identity_tensor((BV, self.BK)) + tUcState = update_thr_mma.partition_C(cState) + state_shape = update_thr_mma.partition_shape_C((BV, self.BK)) + rState0 = update_thr_mma.make_fragment_C(state_shape) + if cutlass.const_expr(self.num_k_blocks == 1): + rStates = (rState0,) + elif cutlass.const_expr(self.num_k_blocks == 2): + rState1 = update_thr_mma.make_fragment_C(state_shape) + rStates = (rState0, rState1) + else: + rState1 = update_thr_mma.make_fragment_C(state_shape) + rState2 = update_thr_mma.make_fragment_C(state_shape) + rState3 = update_thr_mma.make_fragment_C(state_shape) + rStates = (rState0, rState1, rState2, rState3) + acc_qdo = update_thr_mma.make_fragment_C(state_shape) + acc_wdv = update_thr_mma.make_fragment_C(state_shape) + + # Initialize carried dh state in register blocks. 
+ for k_block in cutlass.range_constexpr(self.num_k_blocks): + k_base = k_block * self.BK + rState = rStates[k_block] + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + v_idx = v_base + v_rel + k_idx = k_base + k_rel + init = Float32(0.0) + if cutlass.const_expr(self.use_dht): + if cutlass.const_expr(self.transpose_state_layout): + init = dht[i_n, i_h, v_idx, k_idx].to(self.acc_dtype) + else: + init = dht[i_n, i_h, k_idx, v_idx].to(self.acc_dtype) + rState[ei] = init + + if warp_idx == 0 and NT > 0: + first_chunk = NT - 1 + k_h = load_k_P.acquire_and_advance() + cute.copy(tma_atom_k, bSG_gK[(None, first_chunk, 0)], bSG_sK[None, k_h.index], tma_bar_ptr=k_h.barrier) + dv_h = load_dv_P.acquire_and_advance() + cute.copy( + tma_atom_dv, + bSG_gDv[(None, i_v_tile, first_chunk)], + bSG_sDv[None, dv_h.index], + tma_bar_ptr=dv_h.barrier, + ) + if cutlass.const_expr(self.use_gk): + gk_h = load_gk_P.acquire_and_advance() + cute.copy( + tma_atom_gk, + bSG_gGK[(None, 0, seq_len - 1)], + bSG_sGK[None, gk_h.index], + tma_bar_ptr=gk_h.barrier, + ) for chunk_rev in cutlass.range(0, NT, unroll=0): i_t = NT - 1 - chunk_rev + next_i_t = i_t - 1 chunk_start = i_t * self.BT chunk_end = cutlass.min(chunk_start + self.BT, seq_len) last_idx = chunk_end - 1 @@ -568,50 +682,75 @@ def kernel( g_last_exp = cute.exp(g_last, fastmath=self.use_fast_math) # Store dh before applying this chunk's reverse update. 
- for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_idx = tUcState[ei] - v_idx = v_base + v_rel - if cutlass.const_expr(self.transpose_state_layout): - dh[state_b, chunk_base + i_t, i_h, v_idx, k_idx] = rState[ei].to(dh.element_type) - else: - dh[state_b, chunk_base + i_t, i_h, k_idx, v_idx] = rState[ei].to(dh.element_type) - cute.arch.sync_threads() + for k_block in cutlass.range_constexpr(self.num_k_blocks): + k_base = k_block * self.BK + rState = rStates[k_block] + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + v_idx = v_base + v_rel + k_idx = k_base + k_rel + if cutlass.const_expr(self.transpose_state_layout): + dh[state_b, chunk_base + i_t, i_h, v_idx, k_idx] = rState[ei].to(dh.element_type) + else: + dh[state_b, chunk_base + i_t, i_h, k_idx, v_idx] = rState[ei].to(dh.element_type) - # dv2 = dv + K @ dh. - acc_dv.fill(0.0) - if warp_idx == 0: - kdv_h = load_kdv_P.acquire_and_advance() - cute.copy(tma_atom_k, bSG_gK[(None, i_t, 0)], bSG_sK[None, kdv_h.index], tma_bar_ptr=kdv_h.barrier) + if warp_idx == 0 and next_i_t >= 0: + k_h = load_k_P.acquire_and_advance() + cute.copy(tma_atom_k, bSG_gK[(None, next_i_t, 0)], bSG_sK[None, k_h.index], tma_bar_ptr=k_h.barrier) + dv_h = load_dv_P.acquire_and_advance() cute.copy( tma_atom_dv, - bSG_gDv[(None, i_v_tile, i_t)], - bSG_sDv[None, kdv_h.index], - tma_bar_ptr=kdv_h.barrier, + bSG_gDv[(None, i_v_tile, next_i_t)], + bSG_sDv[None, dv_h.index], + tma_bar_ptr=dv_h.barrier, ) + if cutlass.const_expr(self.use_gk): + next_gk_idx = cutlass.min(next_i_t * self.BT + self.BT, seq_len) - 1 + gk_h = load_gk_P.acquire_and_advance() + cute.copy( + tma_atom_gk, + bSG_gGK[(None, 0, next_gk_idx)], + bSG_sGK[None, gk_h.index], + tma_bar_ptr=gk_h.barrier, + ) - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_idx = tUcState[ei] - sB[v_rel, k_idx, 0] = rState[ei].to(self.io_dtype) + # dv2 = dv + K @ dh. 
+ acc_dv.fill(0.0) + for k_block in cutlass.range_constexpr(self.num_k_blocks): + rState = rStates[k_block] + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + sB[v_rel, k_rel, 0] = rState[ei].to(self.io_dtype) - kdv_wait = load_kdv_C.wait_and_advance() - cute.arch.sync_threads() + k_wait = load_k_C.wait_and_advance() + cute.arch.sync_threads() - cute.nvgpu.warpgroup.fence() - for kp in cutlass.range(cute.size(tCrA, mode=[2]), unroll_full=True): - tiled_mma.set( - cute.nvgpu.warpgroup.Field.ACCUMULATE, - cutlass.Boolean(kp != 0), - ) - cute.gemm( - tiled_mma, - acc_dv, - tCrA[None, None, kp, 0], - tCrB[None, None, kp, 0], - acc_dv, - ) - cute.nvgpu.warpgroup.commit_group() - cute.nvgpu.warpgroup.wait_group(0) - kdv_wait.release() + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tCrA, mode=[2]), unroll_full=True): + tiled_mma.set( + cute.nvgpu.warpgroup.Field.ACCUMULATE, + cutlass.Boolean((k_block != 0) or (kp != 0)), + ) + cute.gemm( + tiled_mma, + acc_dv, + tCrA[None, None, kp, k_wait.index], + tCrB[None, None, kp, 0], + acc_dv, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + k_wait.release() + + if warp_idx == 0: + do_h = load_do_P.acquire_and_advance() + cute.copy(tma_atom_do, bSG_gDo[(None, i_v_tile, i_t)], bSG_sDo[None, do_h.index], tma_bar_ptr=do_h.barrier) + q_h = load_q_P.acquire_and_advance() + cute.copy(tma_atom_q, bSG_gQ[(None, 0, i_t)], bSG_sQ[None, q_h.index], tma_bar_ptr=q_h.barrier) + w_h = load_w_P.acquire_and_advance() + cute.copy(tma_atom_w, bSG_gW[(None, 0, i_t)], bSG_sW[None, w_h.index], tma_bar_ptr=w_h.barrier) + + dv_wait = load_dv_C.wait_and_advance() cute.arch.sync_threads() for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): @@ -628,113 +767,129 @@ def kernel( else: g_decay = cute.exp(g_last - g_cur, fastmath=self.use_fast_math) out = out * g_decay - out = out + sDv2T[v_rel, t_rel, 0].to(self.acc_dtype) - sDv2T[v_rel, t_rel, 0] = 
out.to(self.io_dtype) + out = out + sDv2T[v_rel, t_rel, dv_wait.index].to(self.acc_dtype) + sDv2T[v_rel, t_rel, dv_wait.index] = out.to(self.io_dtype) cute.arch.fence_proxy("async.shared", space="cta") cute.arch.sync_threads() - if warp_idx == 0: - cute.copy(tma_atom_dv2, bSG_sDv2[None, 0], bSG_gDv2[(None, i_v_tile, i_t)]) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) - cute.arch.sync_threads() - - # Apply state decay after dv2, before accumulating QO - WV into dh. - if cutlass.const_expr(self.use_g): - for ei in cutlass.range(cute.size(rState), unroll_full=True): - rState[ei] = rState[ei] * g_last_exp - if cutlass.const_expr(self.use_gk): - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_idx = tUcState[ei] - gk_last = gk[data_b, seq_start + last_idx, i_h, k_idx].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) - else: - k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) - rState[ei] = rState[ei] * k_decay + remaining = chunk_end - chunk_start + if remaining < self.BT: + linear_store = tidx + while linear_store < self.BV * self.BT: + v_rel = linear_store // self.BT + t_rel = linear_store - v_rel * self.BT + if t_rel < remaining: + dv2[data_b, seq_start + chunk_start + t_rel, i_h, v_base + v_rel] = sDv2T[v_rel, t_rel, dv_wait.index] + linear_store += self.num_threads + else: + if warp_idx == 0: + cute.copy(tma_atom_dv2, bSG_sDv2[None, dv_wait.index], bSG_gDv2[(None, i_v_tile, i_t)]) + cute.arch.cp_async_bulk_commit_group() # dh += scale * do^T @ q - dv2^T @ w. 
- if warp_idx == 0: - qdo_h = load_qdo_P.acquire_and_advance() - cute.copy(tma_atom_do, bSG_gDo[(None, i_v_tile, i_t)], bSG_sDo[None, qdo_h.index], tma_bar_ptr=qdo_h.barrier) - cute.copy(tma_atom_q, bSG_gQ[(None, 0, i_t)], bSG_sQ[None, qdo_h.index], tma_bar_ptr=qdo_h.barrier) - qdo_wait = load_qdo_C.wait_and_advance() - cute.arch.sync_threads() - - if cutlass.const_expr(self.use_g): - linear_q = tidx - while linear_q < self.K * self.BT: - k_rel = linear_q // self.BT - t_rel = linear_q - k_rel * self.BT - t_idx = chunk_start + t_rel - q_scaled = Float32(0.0) - if t_idx < seq_len: - g_cur = g[data_b, seq_start + t_idx, i_h].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - g_exp = cute.exp2(g_cur, fastmath=self.use_fast_math) - else: - g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) - q_scaled = sUB[k_rel, t_rel, 0].to(self.acc_dtype) * g_exp - sUB[k_rel, t_rel, 0] = q_scaled.to(self.io_dtype) - linear_q += self.num_threads - cute.arch.sync_threads() - - acc_qdo.fill(0.0) - cute.nvgpu.warpgroup.fence() - for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): - update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) - cute.gemm( - update_tiled_mma, - acc_qdo, - tUrA[None, None, kp, 0], - tUrB[None, None, kp, 0], - acc_qdo, - ) - cute.nvgpu.warpgroup.commit_group() - cute.nvgpu.warpgroup.wait_group(0) - qdo_wait.release() - cute.arch.sync_threads() - - if warp_idx == 0: - w_h = load_w_P.acquire_and_advance() - cute.copy(tma_atom_w, bSG_gW[(None, 0, i_t)], bSG_sW[None, w_h.index], tma_bar_ptr=w_h.barrier) - linear = tidx while linear < self.BV * self.BT: v_rel = linear // self.BT t_rel = linear - v_rel * self.BT - sUA[v_rel, t_rel, 0] = sDv2T[v_rel, t_rel, 0] + sUA[v_rel, t_rel, 0] = sDv2T[v_rel, t_rel, dv_wait.index] linear += self.num_threads - w_wait = load_w_C.wait_and_advance() - cute.arch.sync_threads() + do_wait = load_do_C.wait_and_advance() - acc_wdv.fill(0.0) - cute.nvgpu.warpgroup.fence() - for kp 
in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): - update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) - cute.gemm( - update_tiled_mma, - acc_wdv, - tUrA[None, None, kp, 0], - tUrB[None, None, kp, 0], - acc_wdv, - ) - cute.nvgpu.warpgroup.commit_group() - cute.nvgpu.warpgroup.wait_group(0) - w_wait.release() + if cutlass.const_expr(self.use_g or self.is_varlen): + linear_do = tidx + while linear_do < self.BV * self.BT: + v_rel = linear_do // self.BT + t_rel = linear_do - v_rel * self.BT + t_idx = chunk_start + t_rel + do_scaled = Float32(0.0) + if t_idx < seq_len: + do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) + if cutlass.const_expr(self.use_g): + g_cur = g[data_b, seq_start + t_idx, i_h].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + g_exp = cute.exp2(g_cur, fastmath=self.use_fast_math) + else: + g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) + do_scaled = do_scaled * g_exp + sDo[v_rel, t_rel, do_wait.index] = do_scaled.to(self.io_dtype) + linear_do += self.num_threads + cute.arch.sync_threads() - for ei in cutlass.range(cute.size(rState), unroll_full=True): - update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] - rState[ei] = rState[ei] + update - cute.arch.sync_threads() + for k_block in cutlass.range_constexpr(self.num_k_blocks): + rState = rStates[k_block] + q_wait = load_q_C.wait_and_advance() + + acc_qdo.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_qdo, + tUrDo[None, None, kp, do_wait.index], + tUrB[None, None, kp, q_wait.index], + acc_qdo, + ) + cute.nvgpu.warpgroup.commit_group() + + # QDO does not consume rState, so hide g/gk state decay under its WGMMA latency. 
+ if cutlass.const_expr(self.use_g): + for ei in cutlass.range(cute.size(rState), unroll_full=True): + rState[ei] = rState[ei] * g_last_exp + if cutlass.const_expr(self.use_gk): + gk_wait = load_gk_C.wait_and_advance() + gk_last = sGK[tidx, 0, gk_wait.index].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) + else: + k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) + sGK[tidx, 0, gk_wait.index] = k_decay + self.gk_precompute_bar.arrive_and_wait() + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait.index] + gk_wait.release() + + cute.nvgpu.warpgroup.wait_group(0) + q_wait.release() + + w_wait = load_w_C.wait_and_advance() + + acc_wdv.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_wdv, + tUrA[None, None, kp, 0], + tWrB[None, None, kp, w_wait.index], + acc_wdv, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + w_wait.release() + + for ei in cutlass.range(cute.size(rState), unroll_full=True): + update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] + rState[ei] = rState[ei] + update + do_wait.release() + if remaining >= self.BT: + if warp_idx == 0: + cute.arch.cp_async_bulk_wait_group(0, read=True) + dv_wait.release() if cutlass.const_expr(self.use_dh0): - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_idx = tUcState[ei] - v_idx = v_base + v_rel - if cutlass.const_expr(self.transpose_state_layout): - dh0[i_n, i_h, v_idx, k_idx] = rState[ei] - else: - dh0[i_n, i_h, k_idx, v_idx] = rState[ei] + for k_block in cutlass.range_constexpr(self.num_k_blocks): + k_base = k_block * self.BK + rState = rStates[k_block] + for ei in cutlass.range(cute.size(rState), 
unroll_full=True): + v_rel, k_rel = tUcState[ei] + v_idx = v_base + v_rel + k_idx = k_base + k_rel + if cutlass.const_expr(self.transpose_state_layout): + dh0[i_n, i_h, v_idx, k_idx] = rState[ei] + else: + dh0[i_n, i_h, k_idx, v_idx] = rState[ei] @cute.jit def _epilog_partition(self, atom, gC_mnl, epi_tile, sC): @@ -846,6 +1001,7 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( cu_seqlens: torch.Tensor | None = None, chunk_size: int = BT, chunk_indices: torch.Tensor | None = None, + chunk_offsets: torch.Tensor | None = None, use_exp2: bool = False, transpose_state_layout: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: @@ -859,10 +1015,8 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( is_varlen = cu_seqlens is not None if is_varlen and B != 1: raise ValueError("varlen mode expects packed inputs with shape [1, total_T, H, D].") - if K not in (64, 128, 256): - raise NotImplementedError(f"SM90 bwd_dhu only supports K in {{64, 128, 256}}, got K={K}.") - if V % BV != 0: - raise NotImplementedError(f"SM90 bwd_dhu WGMMA path requires V to be a multiple of {BV}, got V={V}.") + if K != 128 or V != 128: + raise NotImplementedError(f"SM90 bwd_dhu currently aligns with fwd and only supports K=V=128, got K={K}, V={V}.") if q.dtype != torch.bfloat16 or k.dtype != torch.bfloat16 or w.dtype != torch.bfloat16: raise TypeError("q, k, and w must be bfloat16 for the SM90 bwd_dhu path.") if do.dtype != torch.bfloat16 or dv.dtype != torch.bfloat16: @@ -877,13 +1031,18 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( raise ValueError("cu_seqlens must be contiguous and on the same CUDA device as q.") if chunk_indices is not None and (chunk_indices.device != q.device or not chunk_indices.is_contiguous()): raise ValueError("chunk_indices must be contiguous and on the same CUDA device as q.") + if chunk_offsets is not None and (chunk_offsets.device != q.device or not chunk_offsets.is_contiguous()): + raise ValueError("chunk_offsets must be contiguous and on the same CUDA device 
as q.") if is_varlen: if chunk_indices is None: chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) N = len(cu_seqlens) - 1 NT = len(chunk_indices) - chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT).int() + if chunk_offsets is None: + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT).int() + elif chunk_offsets.dtype != torch.int32: + chunk_offsets = chunk_offsets.int() cu_seqlens_arg = cu_seqlens.int() if cu_seqlens.dtype != torch.int32 else cu_seqlens else: N = B diff --git a/tests/test_chunk_delta_h_bwd_sm90.py b/tests/test_chunk_delta_h_bwd_sm90.py index a5a4d1e..cd93a8d 100644 --- a/tests/test_chunk_delta_h_bwd_sm90.py +++ b/tests/test_chunk_delta_h_bwd_sm90.py @@ -2,7 +2,12 @@ # Copyright 2025-2026 Ant Group Co., Ltd. # SPDX-License-Identifier: Apache-2.0 -"""Correctness tests for the SM90 CuTe DSL WGMMA bwd_dhu path.""" +"""Representative correctness tests for the SM90 CuTe DSL WGMMA bwd_dhu path. + +These cases follow the same logic as tests/test_chunk_delta_h.py but avoid the +full Cartesian sweep during kernel iteration. For bwd_dhu, fwd's +initial_state/output_final_state pair maps to dht/dh0. 
+""" import os import sys @@ -16,6 +21,11 @@ from cula.ops.chunk_delta_h_bwd import chunk_gated_delta_rule_bwd_dhu_sm90 +BT = 64 +ATOL = 3e-2 +RTOL = 3e-2 +device = "cuda" + def _is_sm90() -> bool: return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 9 @@ -27,77 +37,21 @@ def _is_sm90() -> bool: ] -def _make_inputs( - B, - T, - H, - K, - V, - use_g=False, - use_gk=False, - use_dht=False, - use_h0=False, - seed=42, - transpose_state_layout=False, -): - torch.manual_seed(seed) - q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 - k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 - w = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 - do = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 - dv = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 - - g = None - if use_g: - g = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device="cuda") * 0.01).cumsum(dim=1) - - gk = None - if use_gk: - gk = -torch.abs(torch.randn(B, T, H, K, dtype=torch.float32, device="cuda") * 0.01).cumsum(dim=1) - - dht = None - if use_dht: - state_shape = (B, H, V, K) if transpose_state_layout else (B, H, K, V) - dht = torch.randn(state_shape, dtype=torch.float32, device="cuda") * 0.01 - - h0 = None - if use_h0: - state_shape = (B, H, V, K) if transpose_state_layout else (B, H, K, V) - h0 = torch.empty(state_shape, dtype=torch.float32, device="cuda") - - return q, k, w, do, dv, g, gk, dht, h0 - - -def _run_case( - B, - T, - H, - K, - V, - use_g=False, - use_gk=False, - use_dht=False, - use_h0=False, - use_exp2=False, +def run_fla_ref( + q, + k, + w, + do, + dv, + g=None, + gk=None, + dht=None, + dh0=None, + cu_seqlens=None, + use_exp2=True, transpose_state_layout=False, - seed=42, ): - q, k, w, do, dv, g, gk, dht, h0 = _make_inputs( - B, - T, - H, - K, - V, - use_g, - use_gk, - use_dht, - use_h0, - seed=seed, - transpose_state_layout=transpose_state_layout, - ) - scale = 
K**-0.5 - - ref_dh, ref_dh0, ref_dv2 = fla_bwd_dhu( + return fla_bwd_dhu( q=q, k=k, w=w, @@ -105,15 +59,31 @@ def _run_case( dv=dv, g=g, gk=gk, - h0=h0, + h0=dh0, dht=dht, - scale=scale, - chunk_size=64, + scale=q.shape[-1] ** -0.5, + cu_seqlens=cu_seqlens.long() if cu_seqlens is not None else None, + chunk_size=BT, use_exp2=use_exp2, transpose_state_layout=transpose_state_layout, ) - got_dh, got_dh0, got_dv2 = chunk_gated_delta_rule_bwd_dhu_sm90( + +def run_cute_dsl( + q, + k, + w, + do, + dv, + g=None, + gk=None, + dht=None, + dh0=None, + cu_seqlens=None, + use_exp2=True, + transpose_state_layout=False, +): + return chunk_gated_delta_rule_bwd_dhu_sm90( q=q, k=k, w=w, @@ -121,221 +91,198 @@ def _run_case( dv=dv, g=g, gk=gk, - h0=h0, + h0=dh0, dht=dht, - scale=scale, - chunk_size=64, + scale=q.shape[-1] ** -0.5, + cu_seqlens=cu_seqlens, + chunk_size=BT, use_exp2=use_exp2, transpose_state_layout=transpose_state_layout, ) - torch.testing.assert_close(got_dh.float(), ref_dh.float(), atol=3e-2, rtol=3e-2) - torch.testing.assert_close(got_dv2.float(), ref_dv2.float(), atol=3e-2, rtol=3e-2) - if use_h0: + +def _assert_bwd_close(got, ref, expect_dh0, msg): + got_dh, got_dh0, got_dv2 = got + ref_dh, ref_dh0, ref_dv2 = ref + torch.testing.assert_close(got_dh.float(), ref_dh.float(), atol=ATOL, rtol=RTOL, msg=f"{msg}: dh") + torch.testing.assert_close(got_dv2.float(), ref_dv2.float(), atol=ATOL, rtol=RTOL, msg=f"{msg}: dv2") + if expect_dh0: assert got_dh0 is not None - torch.testing.assert_close(got_dh0, ref_dh0, atol=3e-2, rtol=3e-2) + torch.testing.assert_close(got_dh0.float(), ref_dh0.float(), atol=ATOL, rtol=RTOL, msg=f"{msg}: dh0") else: assert got_dh0 is None -def _make_varlen_inputs( - seq_lens, +def _make_inputs( + B, + T, H, K, V, use_g=False, use_gk=False, - use_dht=False, - use_h0=False, + use_state=False, seed=42, transpose_state_layout=False, ): torch.manual_seed(seed) + q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) * 0.1 + k = torch.randn(B, 
T, H, K, dtype=torch.bfloat16, device=device) * 0.1 + w = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) * 0.1 + do = torch.randn(B, T, H, V, dtype=torch.bfloat16, device=device) * 0.1 + dv = torch.randn(B, T, H, V, dtype=torch.bfloat16, device=device) * 0.1 + + g = None + if use_g: + g = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device=device) * 0.01).cumsum(dim=1) + + gk = None + if use_gk: + gk = -torch.abs(torch.randn(B, T, H, K, dtype=torch.float32, device=device) * 0.01).cumsum(dim=1) + + state_shape = (B, H, V, K) if transpose_state_layout else (B, H, K, V) + dht = torch.randn(state_shape, dtype=torch.float32, device=device) * 0.01 if use_state else None + dh0 = torch.empty(state_shape, dtype=torch.float32, device=device) if use_state else None + return q, k, w, do, dv, g, gk, dht, dh0 + + +def _make_varlen_inputs(seq_lens, H, K, V, use_g=False, use_gk=False, use_state=False, seed=42): T_total = sum(seq_lens) - N = len(seq_lens) + num_seqs = len(seq_lens) cu = [0] for seq_len in seq_lens: cu.append(cu[-1] + seq_len) - q = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 - k = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 - w = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device="cuda") * 0.1 - do = torch.randn(1, T_total, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 - dv = torch.randn(1, T_total, H, V, dtype=torch.bfloat16, device="cuda") * 0.1 + torch.manual_seed(seed) + q = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device=device) * 0.1 + k = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device=device) * 0.1 + w = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device=device) * 0.1 + do = torch.randn(1, T_total, H, V, dtype=torch.bfloat16, device=device) * 0.1 + dv = torch.randn(1, T_total, H, V, dtype=torch.bfloat16, device=device) * 0.1 g = None if use_g: - g = torch.empty(1, T_total, H, dtype=torch.float32, device="cuda") - for i in range(N): + g = 
torch.empty(1, T_total, H, dtype=torch.float32, device=device) + for i in range(num_seqs): bos, eos = cu[i], cu[i + 1] - seg = torch.randn(1, eos - bos, H, dtype=torch.float32, device="cuda") * 0.01 + seg = torch.randn(1, eos - bos, H, dtype=torch.float32, device=device) * 0.01 g[:, bos:eos] = -torch.abs(seg).cumsum(dim=1) gk = None if use_gk: - gk = torch.empty(1, T_total, H, K, dtype=torch.float32, device="cuda") - for i in range(N): + gk = torch.empty(1, T_total, H, K, dtype=torch.float32, device=device) + for i in range(num_seqs): bos, eos = cu[i], cu[i + 1] - seg = torch.randn(1, eos - bos, H, K, dtype=torch.float32, device="cuda") * 0.01 + seg = torch.randn(1, eos - bos, H, K, dtype=torch.float32, device=device) * 0.01 gk[:, bos:eos] = -torch.abs(seg).cumsum(dim=1) - state_shape = (N, H, V, K) if transpose_state_layout else (N, H, K, V) - dht = torch.randn(state_shape, dtype=torch.float32, device="cuda") * 0.01 if use_dht else None - h0 = torch.empty(state_shape, dtype=torch.float32, device="cuda") if use_h0 else None - cu_seqlens = torch.tensor(cu, dtype=torch.int32, device="cuda") - return q, k, w, do, dv, g, gk, dht, h0, cu_seqlens + state_shape = (num_seqs, H, K, V) + dht = torch.randn(state_shape, dtype=torch.float32, device=device) * 0.01 if use_state else None + dh0 = torch.empty(state_shape, dtype=torch.float32, device=device) if use_state else None + cu_seqlens = torch.tensor(cu, dtype=torch.int32, device=device) + return q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens -def _run_varlen_case( - seq_lens, - H, - K, - V, - use_g=False, - use_gk=False, - use_dht=False, - use_h0=False, - use_exp2=False, - transpose_state_layout=False, - seed=42, -): - q, k, w, do, dv, g, gk, dht, h0, cu_seqlens = _make_varlen_inputs( - seq_lens, - H, - K, - V, - use_g=use_g, - use_gk=use_gk, - use_dht=use_dht, - use_h0=use_h0, - seed=seed, - transpose_state_layout=transpose_state_layout, - ) - scale = K**-0.5 - ref_dh, ref_dh0, ref_dv2 = fla_bwd_dhu( - q=q, - k=k, - w=w, - 
do=do, - dv=dv, - g=g, - gk=gk, - h0=h0, - dht=dht, - scale=scale, - cu_seqlens=cu_seqlens.long(), - chunk_size=64, - use_exp2=use_exp2, - transpose_state_layout=transpose_state_layout, - ) - got_dh, got_dh0, got_dv2 = chunk_gated_delta_rule_bwd_dhu_sm90( - q=q, - k=k, - w=w, - do=do, - dv=dv, - g=g, - gk=gk, - h0=h0, - dht=dht, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=64, - use_exp2=use_exp2, - transpose_state_layout=transpose_state_layout, - ) - torch.testing.assert_close(got_dh.float(), ref_dh.float(), atol=3e-2, rtol=3e-2) - torch.testing.assert_close(got_dv2.float(), ref_dv2.float(), atol=3e-2, rtol=3e-2) - if use_h0: - assert got_dh0 is not None - torch.testing.assert_close(got_dh0, ref_dh0, atol=3e-2, rtol=3e-2) - else: - assert got_dh0 is None - - -@pytest.mark.parametrize("T", [64, 128]) -@pytest.mark.parametrize("V", [64, 128]) -def test_bwd_dhu_no_gating(T, V): - _run_case(B=1, T=T, H=1, K=64, V=V) - - -def test_bwd_dhu_with_gk_exp2_and_dht(): - _run_case(B=1, T=128, H=2, K=64, V=64, use_gk=True, use_dht=True, use_exp2=True) +@pytest.mark.parametrize( + "case", + [ + dict(B=1, T=64, H=1, K=128, V=128, use_gk=False, use_state=False), + dict(B=1, T=128, H=4, K=128, V=128, use_gk=True, use_state=False), + dict(B=2, T=256, H=4, K=128, V=128, use_gk=True, use_state=True), + dict(B=1, T=1024, H=64, K=128, V=128, use_gk=True, use_state=False), + ], + ids=["minimal", "multihead-gk", "batch-state", "long-h64"], +) +def test_dhu_against_fla(case): + B, T, H, K, V = case["B"], case["T"], case["H"], case["K"], case["V"] + use_gk, use_state = case["use_gk"], case["use_state"] + q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs(B, T, H, K, V, use_gk=use_gk, use_state=use_state) + ref = run_fla_ref(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0) + got = run_cute_dsl(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0) + _assert_bwd_close(got, ref, use_state, f"B={B} H={H} T={T} gk={use_gk} state={use_state}") -def test_bwd_dhu_with_scalar_g_exp2_and_dht(): - 
_run_case(B=1, T=128, H=2, K=64, V=64, use_g=True, use_dht=True, use_exp2=True) +def test_dv2_no_gating(): + B, T, H, K, V = 4, 512, 4, 128, 128 + q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs(B, T, H, K, V) + ref = run_fla_ref(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0) + got = run_cute_dsl(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0) + _assert_bwd_close(got, ref, False, f"dv2 no-gating B={B} T={T} H={H}") -def test_bwd_dhu_with_scalar_g_and_gk(): - _run_case(B=1, T=128, H=1, K=128, V=64, use_g=True, use_gk=True, use_dht=True, use_exp2=True) +@pytest.mark.parametrize( + "case", + [ + dict(seq_lens=[50, 192, 100], H=2, use_g=False, use_gk=True, use_state=False), + dict(seq_lens=[33, 128, 200, 95], H=1, use_g=True, use_gk=False, use_state=True), + ], + ids=["gk-dht", "g-dh0"], +) +def test_varlen_against_fla(case): + K, V = 128, 128 + seq_lens = case["seq_lens"] + H = case["H"] + use_g = case["use_g"] + use_gk = case["use_gk"] + use_state = case["use_state"] + q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens = _make_varlen_inputs( + seq_lens, H, K, V, use_g=use_g, use_gk=use_gk, use_state=use_state + ) + ref = run_fla_ref(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0, cu_seqlens=cu_seqlens) + got = run_cute_dsl(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0, cu_seqlens=cu_seqlens) + _assert_bwd_close(got, ref, use_state, f"varlen seqs={seq_lens} H={H} gk={use_gk} state={use_state}") -def test_bwd_dhu_k128_with_gk_exp2_and_dht(): - _run_case(B=1, T=128, H=1, K=128, V=64, use_gk=True, use_dht=True, use_exp2=True) +def test_varlen_vs_nonvarlen(): + H, K, V = 2, 128, 128 + T = 256 + q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs(1, T, H, K, V, use_gk=True, use_state=True) + dh_nv, dh0_nv, dv2_nv = run_cute_dsl(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0) + cu_seqlens = torch.tensor([0, T], dtype=torch.int32, device=device) + dh_vl, dh0_vl, dv2_vl = run_cute_dsl(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0, cu_seqlens=cu_seqlens) -def test_bwd_dhu_returns_dh0(): - 
_run_case(B=2, T=128, H=1, K=64, V=64, use_h0=True) + torch.testing.assert_close(dh_nv.float(), dh_vl.float(), atol=1e-6, rtol=1e-6) + torch.testing.assert_close(dv2_nv.float(), dv2_vl.float(), atol=1e-6, rtol=1e-6) + torch.testing.assert_close(dh0_nv.float(), dh0_vl.float(), atol=1e-6, rtol=1e-6) @pytest.mark.parametrize( - "case", + "use_g,use_gk", [ - dict(B=1, T=256, H=4, K=64, V=64, use_gk=True, use_dht=True, use_exp2=True), - dict(B=1, T=256, H=4, K=64, V=64, use_g=True, use_dht=True, use_exp2=True), - dict(B=1, T=256, H=2, K=128, V=64, use_g=True, use_gk=True, use_dht=True, use_exp2=True), - dict(B=2, T=256, H=2, K=128, V=64, use_gk=True, use_dht=True, use_h0=True, use_exp2=True), - dict(B=1, T=512, H=4, K=128, V=128, use_gk=True, use_dht=True, use_exp2=True), - dict(B=1, T=128, H=2, K=256, V=64, use_gk=True, use_dht=True, use_exp2=True), - ], - ids=[ - "k64-v64-multihead-gk-dht", - "k64-v64-multihead-g-dht", - "k128-v64-g-and-gk", - "k128-v64-batch-h0", - "k128-v128-long", - "k256-v64", + (True, False), + (True, True), ], ) -def test_bwd_dhu_forward_aligned_cases(case): - _run_case(**case, seed=123) +def test_scalar_g_features(use_g, use_gk): + q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs( + B=1, + T=128, + H=2, + K=128, + V=128, + use_g=use_g, + use_gk=use_gk, + use_state=True, + seed=123, + ) + ref = run_fla_ref(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0) + got = run_cute_dsl(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0) + _assert_bwd_close(got, ref, True, f"scalar-g g={use_g} gk={use_gk}") -def test_bwd_dhu_transpose_state_layout(): - _run_case( +def test_transpose_state_layout(): + q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs( B=1, T=128, H=2, K=128, - V=64, + V=128, use_gk=True, - use_dht=True, - use_h0=True, - use_exp2=True, + use_state=True, + seed=456, transpose_state_layout=True, ) - - -@pytest.mark.parametrize( - "case", - [ - dict(seq_lens=[64, 128], H=1, K=64, V=64), - dict(seq_lens=[50, 192, 100], H=2, K=64, V=64, use_gk=True, 
use_dht=True, use_exp2=True), - dict(seq_lens=[33, 128, 200], H=1, K=128, V=64, use_g=True, use_dht=True, use_h0=True, use_exp2=True), - dict( - seq_lens=[96, 129], - H=1, - K=128, - V=64, - use_gk=True, - use_dht=True, - use_h0=True, - use_exp2=True, - transpose_state_layout=True, - ), - ], - ids=["basic", "gk-dht", "g-h0", "transpose-gk-h0"], -) -def test_bwd_dhu_varlen_cases(case): - _run_varlen_case(**case, seed=321) + ref = run_fla_ref(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0, transpose_state_layout=True) + got = run_cute_dsl(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0, transpose_state_layout=True) + _assert_bwd_close(got, ref, True, "transpose state layout") From e2d04dc87319611ad26d7f5de3fe81a87f202c7e Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Wed, 13 May 2026 20:56:24 +0800 Subject: [PATCH 05/28] format --- benchmarks/bench_chunk_delta_h_bwd_sm90.py | 135 +++++- cula/ops/chunk_delta_h_bwd.py | 529 ++++++++++++--------- 2 files changed, 421 insertions(+), 243 deletions(-) diff --git a/benchmarks/bench_chunk_delta_h_bwd_sm90.py b/benchmarks/bench_chunk_delta_h_bwd_sm90.py index 4edebbc..33d1ceb 100644 --- a/benchmarks/bench_chunk_delta_h_bwd_sm90.py +++ b/benchmarks/bench_chunk_delta_h_bwd_sm90.py @@ -16,6 +16,7 @@ Usage: python benchmarks/bench_chunk_delta_h_bwd_sm90.py --mode both python benchmarks/bench_chunk_delta_h_bwd_sm90.py --preset fwd --mode non-varlen + python benchmarks/bench_chunk_delta_h_bwd_sm90.py --preset focused --mode non-varlen """ import argparse @@ -34,7 +35,9 @@ from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu as fla_bwd_dhu from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets -from cula.ops.chunk_delta_h_bwd import chunk_gated_delta_rule_bwd_dhu_sm90 +import cula.ops.chunk_delta_h_bwd as bwd_mod + +chunk_gated_delta_rule_bwd_dhu_sm90 = bwd_mod.chunk_gated_delta_rule_bwd_dhu_sm90 if hasattr(sys.stdout, "reconfigure"): sys.stdout.reconfigure(line_buffering=True) @@ 
-69,26 +72,32 @@ def time_kernel(fn, warmup=None, n_iters=None): def accuracy_stats(ref, out): diff = (ref.float() - out.float()).abs() - return diff.max().item(), diff.mean().item() + max_abs = diff.max().item() + rel_linf = max_abs / max(ref.float().abs().max().item(), 1e-6) + return max_abs, diff.mean().item(), rel_linf def bwd_accuracy_stats(ref_result, cute_result): ref_dh, ref_dh0, ref_dv2 = ref_result got_dh, got_dh0, got_dv2 = cute_result - dh_max, dh_mean = accuracy_stats(ref_dh, got_dh) - dv2_max, dv2_mean = accuracy_stats(ref_dv2, got_dv2) - dh0_max, dh0_mean = 0.0, 0.0 + dh_max, dh_mean, dh_rel = accuracy_stats(ref_dh, got_dh) + dv2_max, dv2_mean, dv2_rel = accuracy_stats(ref_dv2, got_dv2) + dh0_max, dh0_mean, dh0_rel = 0.0, 0.0, 0.0 if ref_dh0 is not None: - dh0_max, dh0_mean = accuracy_stats(ref_dh0, got_dh0) + dh0_max, dh0_mean, dh0_rel = accuracy_stats(ref_dh0, got_dh0) return { "dh_max": dh_max, "dh_mean": dh_mean, + "dh_rel": dh_rel, "dh0_max": dh0_max, "dh0_mean": dh0_mean, + "dh0_rel": dh0_rel, "dv2_max": dv2_max, "dv2_mean": dv2_mean, + "dv2_rel": dv2_rel, "max_diff": max(dh_max, dh0_max, dv2_max), "mean_diff": max(dh_mean, dh0_mean, dv2_mean), + "max_rel": max(dh_rel, dh0_rel, dv2_rel), } @@ -230,6 +239,51 @@ def flags_str(use_g, use_gk, use_dht, use_dh0): return f" [{','.join(flags)}]" if flags else "" +FOCUSED_FEATURE_MODES = { + "A": (False, False, False, False), + "B": (False, False, True, True), + "C": (True, False, False, False), + "D": (False, True, False, False), + "E": (True, True, False, False), +} + + +def focused_feature_label(use_g, use_gk, use_dht, use_dh0): + for name, flags in FOCUSED_FEATURE_MODES.items(): + if flags == (use_g, use_gk, use_dht, use_dh0): + return name + return "-" + + +def build_focused_non_varlen_configs(feature_mode="all"): + modes = FOCUSED_FEATURE_MODES.items() + if feature_mode != "all": + modes = [(feature_mode, FOCUSED_FEATURE_MODES[feature_mode])] + + configs = [] + for B in (1, 2, 4): + for H in (16, 
32): + for T in (2048, 4096, 8192, 16384): + for _, flags in modes: + use_g, use_gk, use_dht, use_dh0 = flags + configs.append((B, T, H, use_g, use_gk, use_dht, use_dh0)) + return configs + + +def filter_non_varlen_configs(configs, only_b=None, only_h=None, only_t=None): + if only_b is not None: + configs = [cfg for cfg in configs if cfg[0] == only_b] + if only_t is not None: + configs = [cfg for cfg in configs if cfg[1] == only_t] + if only_h is not None: + configs = [cfg for cfg in configs if cfg[2] == only_h] + return configs + + +def _compile_cache_misses(): + return bwd_mod._compile_bwd_dhu_sm90.cache_info().misses + + def bench_non_varlen(configs): print("\n" + "=" * 80) print(" Non-Varlen Benchmark: CuTe DSL (SM90) bwd_dhu vs FLA Triton") @@ -240,7 +294,9 @@ def bench_non_varlen(configs): q, k, w, do, dv, g, gk, dht, dh0 = make_non_varlen_inputs(B, T, H, use_g, use_gk, use_dht, use_dh0) ref = run_fla(q, k, w, do, dv, g, gk, dht, dh0) + misses_before = _compile_cache_misses() got = run_cute(q, k, w, do, dv, g, gk, dht, dh0) + compiled_new = _compile_cache_misses() > misses_before torch.cuda.synchronize() acc = bwd_accuracy_stats(ref, got) @@ -254,12 +310,15 @@ def run_cute_case(q=q, k=k, w=w, do=do, dv=dv, g=g, gk=gk, dht=dht, dh0=dh0): ms_cute = time_kernel(run_cute_case) speedup = ms_fla / ms_cute if ms_cute > 0 else float("inf") flag_str = flags_str(use_g, use_gk, use_dht, use_dh0) + feature_mode = focused_feature_label(use_g, use_gk, use_dht, use_dh0) r = { "B": B, "T": T, "H": H, + "feature_mode": feature_mode, "flags": flag_str, + "compiled_new": compiled_new, "ms_fla": ms_fla, "ms_cute": ms_cute, "speedup": speedup, @@ -267,10 +326,11 @@ def run_cute_case(q=q, k=k, w=w, do=do, dv=dv, g=g, gk=gk, dht=dht, dh0=dh0): } results.append(r) print( - f" B={B:2d} T={T:5d} H={H:3d}{flag_str:<18s} | " - f"max={acc['max_diff']:.6f} mean={acc['mean_diff']:.8f} " - f"(dh={acc['dh_max']:.6f} dh0={acc['dh0_max']:.6f} dv2={acc['dv2_max']:.6f}) | " - 
f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | speedup={speedup:.2f}x" + f" B={B:2d} T={T:5d} H={H:3d} mode={feature_mode}{flag_str:<18s} | " + f"abs(dh={acc['dh_max']:.6f} dh0={acc['dh0_max']:.6f} dv2={acc['dv2_max']:.6f}) " + f"rel(dh={acc['dh_rel']:.3e} dh0={acc['dh0_rel']:.3e} dv2={acc['dv2_rel']:.3e}) | " + f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms speedup={speedup:.2f}x | " + f"compiled={'yes' if compiled_new else 'no'}" ) return results @@ -351,7 +411,9 @@ def run_cute_case( "T_total": total_T, "H": H, "n_seqs": num_seqs, + "feature_mode": focused_feature_label(use_g, use_gk, use_dht, use_dh0), "flags": flag_str, + "compiled_new": False, "ms_fla": ms_fla, "ms_cute": ms_cute, "speedup": speedup, @@ -382,19 +444,23 @@ def print_report(nv_results, vl_results): if nv_results: print("\n [Non-Varlen]") - print(f" {'-' * 112}") - print(f" {'Config':<37s} | {'max_diff':>10s} {'mean_diff':>12s} | {'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s}") - print(f" {'-' * 112}") + print(f" {'-' * 132}") + print( + f" {'Config':<45s} | {'max_abs':>10s} {'max_rel':>10s} | " + f"{'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s} {'Compiled':>8s}" + ) + print(f" {'-' * 132}") for r in nv_results: - label = f"B={r['B']:2d} T={r['T']:5d} H={r['H']:3d}{r['flags']}" + label = f"B={r['B']:2d} T={r['T']:5d} H={r['H']:3d} mode={r.get('feature_mode', '-')}{r['flags']}" print( - f" {label:<37s} | {r['max_diff']:10.6f} {r['mean_diff']:12.8f} | " - f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x" + f" {label:<45s} | {r['max_diff']:10.6f} {r['max_rel']:10.3e} | " + f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x " + f"{'yes' if r.get('compiled_new') else 'no':>8s}" ) - print(f" {'-' * 112}") + print(f" {'-' * 132}") speedups = [r["speedup"] for r in nv_results] geo = math.exp(sum(math.log(s) for s in speedups) / len(speedups)) - print(f" {'Geometric mean':<37s} | {'':>10s} {'':>12s} | {'':>9s} {'':>9s} {geo:7.2f}x") + print(f" {'Geometric mean':<45s} | {'':>10s} 
{'':>10s} | {'':>9s} {'':>9s} {geo:7.2f}x {'':>8s}") if vl_results: print("\n [Varlen]") @@ -428,9 +494,20 @@ def main(): "--preset", type=str, default="representative", - choices=["representative", "fwd"], - help="representative runs a short subset; fwd mirrors bench_chunk_delta_h.py's large default configs", + choices=["representative", "fwd", "focused"], + help="representative runs a short subset; fwd mirrors bench_chunk_delta_h.py; focused runs the long non-varlen matrix", + ) + parser.add_argument( + "--feature-mode", + type=str, + default="all", + choices=["all", "A", "B", "C", "D", "E"], + help="For --preset focused: A=no gates/state, B=dht+dh0, C=g, D=gk, E=g+gk", ) + parser.add_argument("--max-configs", type=int, default=None, help="Run only the first N configs from the selected preset") + parser.add_argument("--filter-b", type=int, default=None, help="Only run non-varlen configs with this B") + parser.add_argument("--filter-h", type=int, default=None, help="Only run non-varlen configs with this H") + parser.add_argument("--filter-t", type=int, default=None, help="Only run non-varlen configs with this T") parser.add_argument("--warmup", type=int, default=None, help="Override warmup iterations") parser.add_argument("--iters", type=int, default=None, help="Override timed iterations") parser.add_argument("--ncu", action="store_true", help="NCU profiling mode: warmup=1, iters=1") @@ -448,7 +525,12 @@ def main(): if args.iters is not None: N_ITERS = args.iters - if args.preset == "fwd": + if args.preset == "focused": + # Focused non-varlen long-token matrix requested for SM90 bwd_dhu tuning. + # Tuple: (B, T, H, use_g, use_gk, use_dht, use_dh0) + non_varlen_configs = build_focused_non_varlen_configs(args.feature_mode) + varlen_configs = [] + elif args.preset == "fwd": # Matches bench_chunk_delta_h.py's default dimensions. 
# Tuple: (B, T, H, use_g, use_gk, use_dht, use_dh0) non_varlen_configs = [ @@ -482,6 +564,17 @@ def main(): (4, 768, 2, 4.0, True, False, True, True), ] + non_varlen_configs = filter_non_varlen_configs( + non_varlen_configs, + only_b=args.filter_b, + only_h=args.filter_h, + only_t=args.filter_t, + ) + + if args.max_configs is not None: + non_varlen_configs = non_varlen_configs[: args.max_configs] + varlen_configs = varlen_configs[: args.max_configs] + nv_res, vl_res = [], [] if args.mode in ("non-varlen", "both"): nv_res = bench_non_varlen(non_varlen_configs) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 412f3b9..d3fb552 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -53,7 +53,7 @@ BT = 64 BV = 64 BK = 128 -NUM_THREADS = 128 +NUM_THREADS = 192 def make_thread_cooperative_group(size: int): @@ -105,8 +105,14 @@ def __init__( self.BK = head_dim_k self.num_k_blocks = head_dim_k // self.BK self.num_v_tiles = (head_dim_v + BV - 1) // BV + self.threads_per_warp = 32 + self.num_compute_warps = 4 + self.num_compute_threads = self.threads_per_warp * self.num_compute_warps + self.load_warp_id = 4 + self.store_warp_id = 5 self.num_threads = NUM_THREADS self.num_regs_compute = 232 + self.num_regs_other = 40 self.input_stage = 2 self.io_dtype = cutlass.BFloat16 self.acc_dtype = cutlass.Float32 @@ -118,7 +124,7 @@ def __init__( self.cluster_shape_mnk = (1, 1, 1) self.gk_precompute_bar = pipeline.NamedBarrier( barrier_id=1, - num_threads=self.num_threads, + num_threads=self.num_compute_threads, ) @cute.jit @@ -208,6 +214,19 @@ def __call__( ), ) dh = cute.make_tensor(dh_ptr, state_layout) + if cutlass.const_expr(self.transpose_state_layout): + dh_tma_layout = cute.make_layout( + (self.K, self.V, (NT_total, self.H, self.B)), + stride=(1, self.K, (self.H * self.K * self.V, self.K * self.V, NT_total * self.H * self.K * self.V)), + ) + dh_tma_tile = (self.BK, self.BV) + else: + dh_tma_layout = cute.make_layout( + 
(self.V, self.K, (NT_total, self.H, self.B)), + stride=(1, self.V, (self.H * self.K * self.V, self.K * self.V, NT_total * self.H * self.K * self.V)), + ) + dh_tma_tile = (self.BV, self.BK) + dh_tma = cute.make_tensor(dh_ptr, dh_tma_layout) if cutlass.const_expr(self.transpose_state_layout): final_layout = cute.make_layout( @@ -285,15 +304,16 @@ def __call__( self.io_dtype, self.input_stage, ) - dv_smem_layout_staged = cute.make_layout( - (self.BV, self.BT, self.input_stage), - stride=(1, self.BV, self.BV * self.BT), - ) gk_smem_layout_staged = cute.make_layout( (self.BK, 1, self.input_stage), stride=(1, self.BK, self.BK), ) - + dh_smem_layout_staged = sm90_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.COL_MAJOR, + dh_tma_tile, + 1, + ) tma_load_op = cpasync.CopyBulkTensorTileG2SOp() tma_store_op = cpasync.CopyBulkTensorTileS2GOp() @@ -306,7 +326,7 @@ def __call__( tma_atom_dv, tma_tensor_dv = cpasync.make_tiled_tma_atom( tma_load_op, dv_vt, - cute.slice_(dv_smem_layout_staged, (None, None, 0)), + cute.slice_(update_a_smem_layout_staged, (None, None, 0)), (self.BV, self.BT), ) tma_atom_do, tma_tensor_do = cpasync.make_tiled_tma_atom( @@ -333,14 +353,20 @@ def __call__( cute.slice_(gk_smem_layout_staged, (None, None, 0)), (self.BK, 1), ) + tma_atom_dh, tma_tensor_dh = cpasync.make_tiled_tma_atom( + tma_store_op, + dh_tma, + cute.slice_(dh_smem_layout_staged, (None, None, 0)), + dh_tma_tile, + ) tma_atom_dv2, tma_tensor_dv2 = cpasync.make_tiled_tma_atom( tma_store_op, dv2_vt, - cute.slice_(dv_smem_layout_staged, (None, None, 0)), + cute.slice_(update_a_smem_layout_staged, (None, None, 0)), (self.BV, self.BT), ) self.tma_k_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(a_smem_layout_staged, (None, None, 0))) - self.tma_dv_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(dv_smem_layout_staged, (None, None, 0))) + self.tma_dv_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_a_smem_layout_staged, (None, None, 0))) 
self.tma_do_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_a_smem_layout_staged, (None, None, 0))) self.tma_q_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_b_smem_layout_staged, (None, None, 0))) self.tma_w_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_b_smem_layout_staged, (None, None, 0))) @@ -382,8 +408,8 @@ class SharedStorage: cute.struct.MemRange[self.io_dtype, cute.cosize(update_b_smem_layout_staged)], self.buffer_align_bytes, ] - sDv2T: cute.struct.Align[ - cute.struct.MemRange[self.io_dtype, cute.cosize(dv_smem_layout_staged)], + sDh: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(dh_smem_layout_staged)], self.buffer_align_bytes, ] @@ -421,6 +447,9 @@ class SharedStorage: tma_tensor_w, tma_atom_gk, tma_tensor_gk, + tma_atom_dh, + tma_tensor_dh, + dh_smem_layout_staged, tma_atom_dv2, tma_tensor_dv2, ).launch( @@ -465,6 +494,9 @@ def kernel( tma_tensor_w: cute.Tensor, tma_atom_gk: cute.CopyAtom, tma_tensor_gk: cute.Tensor, + tma_atom_dh: cute.CopyAtom, + tma_tensor_dh: cute.Tensor, + dh_smem_layout_staged: cute.ComposedLayout, tma_atom_dv2: cute.CopyAtom, tma_tensor_dv2: cute.Tensor, ): @@ -498,59 +530,63 @@ def kernel( sGK = storage.sGK.get_tensor(cute.make_layout((BK, 1, self.input_stage), stride=(1, BK, BK))) sUB = storage.sUB.get_tensor(update_b_smem_layout_staged.outer, swizzle=update_b_smem_layout_staged.inner) sW = storage.sW.get_tensor(update_b_smem_layout_staged.outer, swizzle=update_b_smem_layout_staged.inner) - sDv2T = storage.sDv2T.get_tensor(cute.make_layout((BV, BT, self.input_stage), stride=(1, BV, BV * BT))) + sDh = storage.sDh.get_tensor(dh_smem_layout_staged.outer, swizzle=dh_smem_layout_staged.inner) - if warp_idx == 0: + if warp_idx == self.load_warp_id: cpasync.prefetch_descriptor(tma_atom_k) cpasync.prefetch_descriptor(tma_atom_dv) cpasync.prefetch_descriptor(tma_atom_do) cpasync.prefetch_descriptor(tma_atom_q) cpasync.prefetch_descriptor(tma_atom_w) - 
cpasync.prefetch_descriptor(tma_atom_gk) + if cutlass.const_expr(self.use_gk): + cpasync.prefetch_descriptor(tma_atom_gk) + if warp_idx == self.store_warp_id: + cpasync.prefetch_descriptor(tma_atom_dh) cpasync.prefetch_descriptor(tma_atom_dv2) load_k_P, load_k_C = pipeline.PipelineTmaAsync.create( num_stages=self.input_stage, producer_group=make_thread_cooperative_group(1), - consumer_group=make_thread_cooperative_group(self.num_threads // 32), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), tx_count=self.tma_k_bytes, barrier_storage=storage.load_k_mbar.data_ptr(), ).make_participants() load_dv_P, load_dv_C = pipeline.PipelineTmaAsync.create( num_stages=self.input_stage, producer_group=make_thread_cooperative_group(1), - consumer_group=make_thread_cooperative_group(self.num_threads // 32), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), tx_count=self.tma_dv_bytes, barrier_storage=storage.load_dv_mbar.data_ptr(), ).make_participants() load_do_P, load_do_C = pipeline.PipelineTmaAsync.create( num_stages=self.input_stage, producer_group=make_thread_cooperative_group(1), - consumer_group=make_thread_cooperative_group(self.num_threads // 32), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), tx_count=self.tma_do_bytes, barrier_storage=storage.load_do_mbar.data_ptr(), ).make_participants() load_q_P, load_q_C = pipeline.PipelineTmaAsync.create( num_stages=self.input_stage, producer_group=make_thread_cooperative_group(1), - consumer_group=make_thread_cooperative_group(self.num_threads // 32), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), tx_count=self.tma_q_bytes, barrier_storage=storage.load_q_mbar.data_ptr(), ).make_participants() load_w_P, load_w_C = pipeline.PipelineTmaAsync.create( num_stages=self.input_stage, producer_group=make_thread_cooperative_group(1), - consumer_group=make_thread_cooperative_group(self.num_threads // 32), + 
consumer_group=make_thread_cooperative_group(self.num_compute_warps), tx_count=self.tma_w_bytes, barrier_storage=storage.load_w_mbar.data_ptr(), ).make_participants() - load_gk_P, load_gk_C = pipeline.PipelineTmaAsync.create( - num_stages=self.input_stage, - producer_group=make_thread_cooperative_group(1), - consumer_group=make_thread_cooperative_group(self.num_threads // 32), - tx_count=self.tma_gk_bytes, - barrier_storage=storage.load_gk_mbar.data_ptr(), - ).make_participants() + if cutlass.const_expr(self.use_gk): + load_gk_P, load_gk_C = pipeline.PipelineTmaAsync.create( + num_stages=self.input_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), + tx_count=self.tma_gk_bytes, + barrier_storage=storage.load_gk_mbar.data_ptr(), + ).make_participants() if cutlass.const_expr(self.is_varlen): tma_tensor_k_use = cute.domain_offset((seq_start, 0, (0, 0)), tma_tensor_k) @@ -558,22 +594,26 @@ def kernel( tma_tensor_do_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_do) tma_tensor_q_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_q) tma_tensor_w_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_w) - tma_tensor_gk_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_gk) + tma_tensor_dh_use = cute.domain_offset((0, 0, (chunk_base, 0, 0)), tma_tensor_dh) tma_tensor_dv2_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_dv2) + if cutlass.const_expr(self.use_gk): + tma_tensor_gk_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_gk) else: tma_tensor_k_use = tma_tensor_k tma_tensor_dv_use = tma_tensor_dv tma_tensor_do_use = tma_tensor_do tma_tensor_q_use = tma_tensor_q tma_tensor_w_use = tma_tensor_w - tma_tensor_gk_use = tma_tensor_gk + tma_tensor_dh_use = tma_tensor_dh tma_tensor_dv2_use = tma_tensor_dv2 + if cutlass.const_expr(self.use_gk): + tma_tensor_gk_use = tma_tensor_gk _, bSG_sK, bSG_gK = self._epilog_partition( tma_atom_k, 
tma_tensor_k_use[None, None, (i_h, data_b)], (self.BT, self.BK), sA ) _, bSG_sDv, bSG_gDv = self._epilog_partition( - tma_atom_dv, tma_tensor_dv_use[None, None, (i_h, data_b)], (self.BV, self.BT), sDv2T + tma_atom_dv, tma_tensor_dv_use[None, None, (i_h, data_b)], (self.BV, self.BT), sUA ) _, bSG_sDo, bSG_gDo = self._epilog_partition( tma_atom_do, tma_tensor_do_use[None, None, (i_h, data_b)], (self.BV, self.BT), sDo @@ -584,17 +624,31 @@ def kernel( _, bSG_sW, bSG_gW = self._epilog_partition( tma_atom_w, tma_tensor_w_use[None, None, (i_h, data_b)], (self.BK, self.BT), sW ) - _, bSG_sGK, bSG_gGK = self._epilog_partition( - tma_atom_gk, tma_tensor_gk_use[None, None, (i_h, data_b)], (self.BK, 1), sGK - ) + if cutlass.const_expr(self.use_gk): + _, bSG_sGK, bSG_gGK = self._epilog_partition( + tma_atom_gk, tma_tensor_gk_use[None, None, (i_h, data_b)], (self.BK, 1), sGK + ) + if cutlass.const_expr(self.transpose_state_layout): + _, bSG_sDh, bSG_gDh = self._epilog_partition( + tma_atom_dh, tma_tensor_dh_use[None, None, (None, i_h, state_b)], (self.BK, self.BV), sDh + ) + else: + _, bSG_sDh, bSG_gDh = self._epilog_partition( + tma_atom_dh, tma_tensor_dh_use[None, None, (None, i_h, state_b)], (self.BV, self.BK), sDh + ) _, bSG_sDv2, bSG_gDv2 = self._epilog_partition( - tma_atom_dv2, tma_tensor_dv2_use[None, None, (i_h, data_b)], (self.BV, self.BT), sDv2T + tma_atom_dv2, tma_tensor_dv2_use[None, None, (i_h, data_b)], (self.BV, self.BT), sUA ) - cute.arch.setmaxregister_increase(self.num_regs_compute) + is_compute_warp = warp_idx < self.num_compute_warps + local_tidx = tidx % self.num_compute_threads + if is_compute_warp: + cute.arch.setmaxregister_increase(self.num_regs_compute) + else: + cute.arch.setmaxregister_decrease(self.num_regs_other) - thr_mma = tiled_mma.get_slice(tidx) - update_thr_mma = update_tiled_mma.get_slice(tidx) + thr_mma = tiled_mma.get_slice(local_tidx) + update_thr_mma = update_tiled_mma.get_slice(local_tidx) tCsA = thr_mma.partition_A(sA) tCsB = 
thr_mma.partition_B(sB) @@ -631,22 +685,23 @@ def kernel( acc_wdv = update_thr_mma.make_fragment_C(state_shape) # Initialize carried dh state in register blocks. - for k_block in cutlass.range_constexpr(self.num_k_blocks): - k_base = k_block * self.BK - rState = rStates[k_block] - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - v_idx = v_base + v_rel - k_idx = k_base + k_rel - init = Float32(0.0) - if cutlass.const_expr(self.use_dht): - if cutlass.const_expr(self.transpose_state_layout): - init = dht[i_n, i_h, v_idx, k_idx].to(self.acc_dtype) - else: - init = dht[i_n, i_h, k_idx, v_idx].to(self.acc_dtype) - rState[ei] = init - - if warp_idx == 0 and NT > 0: + if is_compute_warp: + for k_block in cutlass.range_constexpr(self.num_k_blocks): + k_base = k_block * self.BK + rState = rStates[k_block] + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + v_idx = v_base + v_rel + k_idx = k_base + k_rel + init = Float32(0.0) + if cutlass.const_expr(self.use_dht): + if cutlass.const_expr(self.transpose_state_layout): + init = dht[i_n, i_h, v_idx, k_idx].to(self.acc_dtype) + else: + init = dht[i_n, i_h, k_idx, v_idx].to(self.acc_dtype) + rState[ei] = init + + if warp_idx == self.load_warp_id and NT > 0: first_chunk = NT - 1 k_h = load_k_P.acquire_and_advance() cute.copy(tma_atom_k, bSG_gK[(None, first_chunk, 0)], bSG_sK[None, k_h.index], tma_bar_ptr=k_h.barrier) @@ -682,19 +737,47 @@ def kernel( g_last_exp = cute.exp(g_last, fastmath=self.use_fast_math) # Store dh before applying this chunk's reverse update. 
- for k_block in cutlass.range_constexpr(self.num_k_blocks): - k_base = k_block * self.BK - rState = rStates[k_block] - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - v_idx = v_base + v_rel - k_idx = k_base + k_rel - if cutlass.const_expr(self.transpose_state_layout): - dh[state_b, chunk_base + i_t, i_h, v_idx, k_idx] = rState[ei].to(dh.element_type) - else: - dh[state_b, chunk_base + i_t, i_h, k_idx, v_idx] = rState[ei].to(dh.element_type) + if cutlass.const_expr(not self.is_varlen): + if chunk_rev > 0: + prev_chunk_start = (i_t + 1) * self.BT + prev_chunk_end = cutlass.min(prev_chunk_start + self.BT, seq_len) + prev_remaining = prev_chunk_end - prev_chunk_start + if prev_remaining < self.BT: + if warp_idx == self.store_warp_id: + cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.sync_threads() + else: + if warp_idx == self.store_warp_id: + cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.sync_threads() + if is_compute_warp: + for k_block in cutlass.range_constexpr(self.num_k_blocks): + k_base = k_block * self.BK + rState = rStates[k_block] + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + state_bf16 = rState[ei].to(self.io_dtype) + if cutlass.const_expr(self.transpose_state_layout): + sDh[k_rel, v_rel, 0] = state_bf16 + sB[v_rel, k_rel, 0] = state_bf16 + else: + dh[state_b, chunk_base + i_t, i_h, k_idx, v_idx] = state_bf16 + sB[v_rel, k_rel, 0] = state_bf16 + cute.arch.fence_proxy("async.shared", space="cta") + cute.arch.sync_threads() + if warp_idx == self.store_warp_id: + if cutlass.const_expr(self.transpose_state_layout): + cute.copy(tma_atom_dh, bSG_sDh[None, 0], bSG_gDh[(None, 0, i_v_tile, i_t)]) + cute.arch.cp_async_bulk_commit_group() + elif cutlass.const_expr(False): + cute.copy(tma_atom_dh, bSG_sDh[None, 0], bSG_gDh[(None, i_v_tile, 0, i_t)]) + cute.arch.cp_async_bulk_commit_group() + if cutlass.const_expr(self.transpose_state_layout and 
self.is_varlen): + cute.arch.cp_async_bulk_wait_group(0, read=True) + if cutlass.const_expr(self.is_varlen): + cute.arch.sync_threads() - if warp_idx == 0 and next_i_t >= 0: + if warp_idx == self.load_warp_id and next_i_t >= 0: k_h = load_k_P.acquire_and_advance() cute.copy(tma_atom_k, bSG_gK[(None, next_i_t, 0)], bSG_sK[None, k_h.index], tma_bar_ptr=k_h.barrier) dv_h = load_dv_P.acquire_and_advance() @@ -713,183 +796,185 @@ def kernel( bSG_sGK[None, gk_h.index], tma_bar_ptr=gk_h.barrier, ) - - # dv2 = dv + K @ dh. - acc_dv.fill(0.0) - for k_block in cutlass.range_constexpr(self.num_k_blocks): - rState = rStates[k_block] - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - sB[v_rel, k_rel, 0] = rState[ei].to(self.io_dtype) - - k_wait = load_k_C.wait_and_advance() - cute.arch.sync_threads() - - cute.nvgpu.warpgroup.fence() - for kp in cutlass.range(cute.size(tCrA, mode=[2]), unroll_full=True): - tiled_mma.set( - cute.nvgpu.warpgroup.Field.ACCUMULATE, - cutlass.Boolean((k_block != 0) or (kp != 0)), - ) - cute.gemm( - tiled_mma, - acc_dv, - tCrA[None, None, kp, k_wait.index], - tCrB[None, None, kp, 0], - acc_dv, - ) - cute.nvgpu.warpgroup.commit_group() - cute.nvgpu.warpgroup.wait_group(0) - k_wait.release() - - if warp_idx == 0: + if warp_idx == self.load_warp_id: do_h = load_do_P.acquire_and_advance() cute.copy(tma_atom_do, bSG_gDo[(None, i_v_tile, i_t)], bSG_sDo[None, do_h.index], tma_bar_ptr=do_h.barrier) q_h = load_q_P.acquire_and_advance() cute.copy(tma_atom_q, bSG_gQ[(None, 0, i_t)], bSG_sQ[None, q_h.index], tma_bar_ptr=q_h.barrier) w_h = load_w_P.acquire_and_advance() cute.copy(tma_atom_w, bSG_gW[(None, 0, i_t)], bSG_sW[None, w_h.index], tma_bar_ptr=w_h.barrier) - - dv_wait = load_dv_C.wait_and_advance() - cute.arch.sync_threads() - - for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): - t_rel, v_rel = tCcDV[ei] - t_idx = chunk_start + t_rel - v_idx = v_base + v_rel - out = Float32(0.0) - if t_idx < seq_len: 
- out = acc_dv[ei] - if cutlass.const_expr(self.use_g): - g_cur = g[data_b, seq_start + t_idx, i_h].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - g_decay = cute.exp2(g_last - g_cur, fastmath=self.use_fast_math) - else: - g_decay = cute.exp(g_last - g_cur, fastmath=self.use_fast_math) - out = out * g_decay - out = out + sDv2T[v_rel, t_rel, dv_wait.index].to(self.acc_dtype) - sDv2T[v_rel, t_rel, dv_wait.index] = out.to(self.io_dtype) + # dv2 = dv + K @ dh. + if is_compute_warp: + acc_dv.fill(0.0) + for k_block in cutlass.range_constexpr(self.num_k_blocks): + k_wait = load_k_C.wait_and_advance() + cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tCrA, mode=[2]), unroll_full=True): + tiled_mma.set( + cute.nvgpu.warpgroup.Field.ACCUMULATE, + cutlass.Boolean((k_block != 0) or (kp != 0)), + ) + cute.gemm( + tiled_mma, + acc_dv, + tCrA[None, None, kp, k_wait.index], + tCrB[None, None, kp, 0], + acc_dv, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + k_wait.release() + + dv_wait = load_dv_C.wait_and_advance() + cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) + for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): + t_rel, v_rel = tCcDV[ei] + t_idx = chunk_start + t_rel + out = Float32(0.0) + if t_idx < seq_len: + out = acc_dv[ei] + if cutlass.const_expr(self.use_g): + g_cur = g[data_b, seq_start + t_idx, i_h].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + g_decay = cute.exp2(g_last - g_cur, fastmath=self.use_fast_math) + else: + g_decay = cute.exp(g_last - g_cur, fastmath=self.use_fast_math) + out = out * g_decay + out = out + sUA[v_rel, t_rel, dv_wait.index].to(self.acc_dtype) + out_bf16 = out.to(self.io_dtype) + sUA[v_rel, t_rel, dv_wait.index] = out_bf16 + dv_wait.release() cute.arch.fence_proxy("async.shared", space="cta") cute.arch.sync_threads() remaining = chunk_end - 
chunk_start if remaining < self.BT: - linear_store = tidx - while linear_store < self.BV * self.BT: - v_rel = linear_store // self.BT - t_rel = linear_store - v_rel * self.BT - if t_rel < remaining: - dv2[data_b, seq_start + chunk_start + t_rel, i_h, v_base + v_rel] = sDv2T[v_rel, t_rel, dv_wait.index] - linear_store += self.num_threads + if is_compute_warp: + linear_store = local_tidx + while linear_store < self.BV * self.BT: + v_rel = linear_store // self.BT + t_rel = linear_store - v_rel * self.BT + if t_rel < remaining: + dv2[data_b, seq_start + chunk_start + t_rel, i_h, v_base + v_rel] = sUA[ + v_rel, + t_rel, + chunk_rev % self.input_stage, + ] + linear_store += self.num_compute_threads else: - if warp_idx == 0: - cute.copy(tma_atom_dv2, bSG_sDv2[None, dv_wait.index], bSG_gDv2[(None, i_v_tile, i_t)]) + if warp_idx == self.store_warp_id: + cute.copy( + tma_atom_dv2, + bSG_sDv2[None, chunk_rev % self.input_stage], + bSG_gDv2[(None, i_v_tile, i_t)], + ) cute.arch.cp_async_bulk_commit_group() # dh += scale * do^T @ q - dv2^T @ w. 
- linear = tidx - while linear < self.BV * self.BT: - v_rel = linear // self.BT - t_rel = linear - v_rel * self.BT - sUA[v_rel, t_rel, 0] = sDv2T[v_rel, t_rel, dv_wait.index] - linear += self.num_threads - do_wait = load_do_C.wait_and_advance() - - if cutlass.const_expr(self.use_g or self.is_varlen): - linear_do = tidx - while linear_do < self.BV * self.BT: - v_rel = linear_do // self.BT - t_rel = linear_do - v_rel * self.BT - t_idx = chunk_start + t_rel - do_scaled = Float32(0.0) - if t_idx < seq_len: - do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) - if cutlass.const_expr(self.use_g): - g_cur = g[data_b, seq_start + t_idx, i_h].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - g_exp = cute.exp2(g_cur, fastmath=self.use_fast_math) - else: - g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) - do_scaled = do_scaled * g_exp - sDo[v_rel, t_rel, do_wait.index] = do_scaled.to(self.io_dtype) - linear_do += self.num_threads - cute.arch.sync_threads() - - for k_block in cutlass.range_constexpr(self.num_k_blocks): - rState = rStates[k_block] - q_wait = load_q_C.wait_and_advance() - - acc_qdo.fill(0.0) - cute.nvgpu.warpgroup.fence() - for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): - update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) - cute.gemm( - update_tiled_mma, - acc_qdo, - tUrDo[None, None, kp, do_wait.index], - tUrB[None, None, kp, q_wait.index], - acc_qdo, - ) - cute.nvgpu.warpgroup.commit_group() + if is_compute_warp: + do_wait = load_do_C.wait_and_advance() + if cutlass.const_expr(self.use_g or self.is_varlen): + linear_do = local_tidx + while linear_do < self.BV * self.BT: + v_rel = linear_do // self.BT + t_rel = linear_do - v_rel * self.BT + t_idx = chunk_start + t_rel + do_scaled = Float32(0.0) + if t_idx < seq_len: + do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) + if cutlass.const_expr(self.use_g): + g_cur = g[data_b, seq_start + t_idx, 
i_h].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + g_exp = cute.exp2(g_cur, fastmath=self.use_fast_math) + else: + g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) + do_scaled = do_scaled * g_exp + sDo[v_rel, t_rel, do_wait.index] = do_scaled.to(self.io_dtype) + linear_do += self.num_compute_threads + cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) + + for k_block in cutlass.range_constexpr(self.num_k_blocks): + rState = rStates[k_block] + q_wait = load_q_C.wait_and_advance() + if cutlass.const_expr(self.use_gk): + gk_wait = load_gk_C.wait_and_advance() + acc_qdo.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_qdo, + tUrDo[None, None, kp, do_wait.index], + tUrB[None, None, kp, q_wait.index], + acc_qdo, + ) + cute.nvgpu.warpgroup.commit_group() + + # QDO does not consume rState, so hide g/gk state decay under its WGMMA latency. 
+ if cutlass.const_expr(self.use_g): + for ei in cutlass.range(cute.size(rState), unroll_full=True): + rState[ei] = rState[ei] * g_last_exp + if cutlass.const_expr(self.use_gk): + gk_last = sGK[local_tidx, 0, gk_wait.index].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) + else: + k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) + sGK[local_tidx, 0, gk_wait.index] = k_decay + self.gk_precompute_bar.arrive_and_wait() + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait.index] + + cute.nvgpu.warpgroup.wait_group(0) + q_wait.release() + if cutlass.const_expr(self.use_gk): + gk_wait.release() + + w_wait = load_w_C.wait_and_advance() + acc_wdv.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_wdv, + tUrA[None, None, kp, chunk_rev % self.input_stage], + tWrB[None, None, kp, w_wait.index], + acc_wdv, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) - # QDO does not consume rState, so hide g/gk state decay under its WGMMA latency. 
- if cutlass.const_expr(self.use_g): for ei in cutlass.range(cute.size(rState), unroll_full=True): - rState[ei] = rState[ei] * g_last_exp - if cutlass.const_expr(self.use_gk): - gk_wait = load_gk_C.wait_and_advance() - gk_last = sGK[tidx, 0, gk_wait.index].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) - else: - k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) - sGK[tidx, 0, gk_wait.index] = k_decay - self.gk_precompute_bar.arrive_and_wait() - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait.index] - gk_wait.release() - - cute.nvgpu.warpgroup.wait_group(0) - q_wait.release() - - w_wait = load_w_C.wait_and_advance() - - acc_wdv.fill(0.0) - cute.nvgpu.warpgroup.fence() - for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): - update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) - cute.gemm( - update_tiled_mma, - acc_wdv, - tUrA[None, None, kp, 0], - tWrB[None, None, kp, w_wait.index], - acc_wdv, - ) - cute.nvgpu.warpgroup.commit_group() - cute.nvgpu.warpgroup.wait_group(0) - w_wait.release() - - for ei in cutlass.range(cute.size(rState), unroll_full=True): - update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] - rState[ei] = rState[ei] + update - do_wait.release() + update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] + rState[ei] = rState[ei] + update + w_wait.release() + do_wait.release() if remaining >= self.BT: - if warp_idx == 0: + if warp_idx == self.store_warp_id: + cute.arch.cp_async_bulk_wait_group(0, read=True) + + if cutlass.const_expr(not self.is_varlen): + final_remaining = cutlass.min(self.BT, seq_len) + if final_remaining < self.BT: + if warp_idx == self.store_warp_id: cute.arch.cp_async_bulk_wait_group(0, read=True) - dv_wait.release() if cutlass.const_expr(self.use_dh0): - for k_block in 
cutlass.range_constexpr(self.num_k_blocks): - k_base = k_block * self.BK - rState = rStates[k_block] - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - v_idx = v_base + v_rel - k_idx = k_base + k_rel - if cutlass.const_expr(self.transpose_state_layout): - dh0[i_n, i_h, v_idx, k_idx] = rState[ei] - else: - dh0[i_n, i_h, k_idx, v_idx] = rState[ei] + if is_compute_warp: + for k_block in cutlass.range_constexpr(self.num_k_blocks): + k_base = k_block * self.BK + rState = rStates[k_block] + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + v_idx = v_base + v_rel + k_idx = k_base + k_rel + if cutlass.const_expr(self.transpose_state_layout): + dh0[i_n, i_h, v_idx, k_idx] = rState[ei] + else: + dh0[i_n, i_h, k_idx, v_idx] = rState[ei] @cute.jit def _epilog_partition(self, atom, gC_mnl, epi_tile, sC): From 5b83ffe43548bd26b9b785296d006b45c26a59a9 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Wed, 13 May 2026 20:59:01 +0800 Subject: [PATCH 06/28] rollback --- cula/ops/chunk_delta_h_bwd.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index d3fb552..71dd295 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -761,18 +761,17 @@ def kernel( sDh[k_rel, v_rel, 0] = state_bf16 sB[v_rel, k_rel, 0] = state_bf16 else: - dh[state_b, chunk_base + i_t, i_h, k_idx, v_idx] = state_bf16 + sDh[v_rel, k_rel, 0] = state_bf16 sB[v_rel, k_rel, 0] = state_bf16 cute.arch.fence_proxy("async.shared", space="cta") cute.arch.sync_threads() if warp_idx == self.store_warp_id: if cutlass.const_expr(self.transpose_state_layout): cute.copy(tma_atom_dh, bSG_sDh[None, 0], bSG_gDh[(None, 0, i_v_tile, i_t)]) - cute.arch.cp_async_bulk_commit_group() - elif cutlass.const_expr(False): + else: cute.copy(tma_atom_dh, bSG_sDh[None, 0], bSG_gDh[(None, i_v_tile, 0, i_t)]) - 
cute.arch.cp_async_bulk_commit_group() - if cutlass.const_expr(self.transpose_state_layout and self.is_varlen): + cute.arch.cp_async_bulk_commit_group() + if cutlass.const_expr(self.is_varlen): cute.arch.cp_async_bulk_wait_group(0, read=True) if cutlass.const_expr(self.is_varlen): cute.arch.sync_threads() From 4d74dda0e3d8d2cd20030f527a23f0718bc99d3f Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Thu, 14 May 2026 11:30:08 +0800 Subject: [PATCH 07/28] re organize pipelines --- cula/ops/chunk_delta_h_bwd.py | 123 ++++++++++++++++++++-------------- 1 file changed, 71 insertions(+), 52 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 71dd295..197a2fa 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -114,6 +114,8 @@ def __init__( self.num_regs_compute = 232 self.num_regs_other = 40 self.input_stage = 2 + self.dh_store_stage = 2 + self.dv2_store_stage = 2 self.io_dtype = cutlass.BFloat16 self.acc_dtype = cutlass.Float32 self.buffer_align_bytes = 1024 @@ -312,7 +314,7 @@ def __call__( self.io_dtype, utils.LayoutEnum.COL_MAJOR, dh_tma_tile, - 1, + self.dh_store_stage, ) tma_load_op = cpasync.CopyBulkTensorTileG2SOp() tma_store_op = cpasync.CopyBulkTensorTileS2GOp() @@ -380,6 +382,10 @@ class SharedStorage: load_q_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] load_w_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] load_gk_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] + store_dh_mbar: cute.struct.MemRange[Int64, self.dh_store_stage * 2] + store_dv2_mbar: cute.struct.MemRange[Int64, self.dv2_store_stage * 2] + store_dv2_done_mbar: cute.struct.MemRange[Int64, self.dv2_store_stage * 2] + sDv2Stage: cute.struct.MemRange[Int32, self.dv2_store_stage] sA: cute.struct.Align[ cute.struct.MemRange[self.io_dtype, cute.cosize(a_smem_layout_staged)], self.buffer_align_bytes, @@ -526,6 +532,7 @@ def kernel( sA = storage.sA.get_tensor(a_smem_layout_staged.outer, 
swizzle=a_smem_layout_staged.inner) sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) sUA = storage.sUA.get_tensor(update_a_smem_layout_staged.outer, swizzle=update_a_smem_layout_staged.inner) + sDv2Stage = storage.sDv2Stage.get_tensor(cute.make_layout((self.dv2_store_stage,))) sDo = storage.sDo.get_tensor(update_a_smem_layout_staged.outer, swizzle=update_a_smem_layout_staged.inner) sGK = storage.sGK.get_tensor(cute.make_layout((BK, 1, self.input_stage), stride=(1, BK, BK))) sUB = storage.sUB.get_tensor(update_b_smem_layout_staged.outer, swizzle=update_b_smem_layout_staged.inner) @@ -587,6 +594,24 @@ def kernel( tx_count=self.tma_gk_bytes, barrier_storage=storage.load_gk_mbar.data_ptr(), ).make_participants() + store_dh_P, store_dh_C = pipeline.PipelineAsync.create( + num_stages=self.dh_store_stage, + producer_group=make_thread_cooperative_group(self.num_compute_threads), + consumer_group=make_thread_cooperative_group(self.threads_per_warp), + barrier_storage=storage.store_dh_mbar.data_ptr(), + ).make_participants() + store_dv2_P, store_dv2_C = pipeline.PipelineAsync.create( + num_stages=self.dv2_store_stage, + producer_group=make_thread_cooperative_group(self.num_compute_threads), + consumer_group=make_thread_cooperative_group(self.threads_per_warp), + barrier_storage=storage.store_dv2_mbar.data_ptr(), + ).make_participants() + store_dv2_done_P, store_dv2_done_C = pipeline.PipelineAsync.create( + num_stages=self.dv2_store_stage, + producer_group=make_thread_cooperative_group(self.threads_per_warp), + consumer_group=make_thread_cooperative_group(self.num_compute_threads), + barrier_storage=storage.store_dv2_done_mbar.data_ptr(), + ).make_participants() if cutlass.const_expr(self.is_varlen): tma_tensor_k_use = cute.domain_offset((seq_start, 0, (0, 0)), tma_tensor_k) @@ -726,6 +751,7 @@ def kernel( next_i_t = i_t - 1 chunk_start = i_t * self.BT chunk_end = cutlass.min(chunk_start + self.BT, seq_len) + remaining = chunk_end - 
chunk_start last_idx = chunk_end - 1 g_last = Float32(0.0) g_last_exp = Float32(1.0) @@ -736,21 +762,11 @@ def kernel( else: g_last_exp = cute.exp(g_last, fastmath=self.use_fast_math) - # Store dh before applying this chunk's reverse update. - if cutlass.const_expr(not self.is_varlen): - if chunk_rev > 0: - prev_chunk_start = (i_t + 1) * self.BT - prev_chunk_end = cutlass.min(prev_chunk_start + self.BT, seq_len) - prev_remaining = prev_chunk_end - prev_chunk_start - if prev_remaining < self.BT: - if warp_idx == self.store_warp_id: - cute.arch.cp_async_bulk_wait_group(0, read=True) - cute.arch.sync_threads() - else: - if warp_idx == self.store_warp_id: - cute.arch.cp_async_bulk_wait_group(0, read=True) - cute.arch.sync_threads() + # Publish the current reverse state both as dh output and as the + # K@dH WGMMA B operand. The dh path is staged for the store warp; + # sB remains single-buffered because only the compute warpgroup uses it. if is_compute_warp: + dh_h = store_dh_P.acquire_and_advance() for k_block in cutlass.range_constexpr(self.num_k_blocks): k_base = k_block * self.BK rState = rStates[k_block] @@ -758,23 +774,23 @@ def kernel( v_rel, k_rel = tUcState[ei] state_bf16 = rState[ei].to(self.io_dtype) if cutlass.const_expr(self.transpose_state_layout): - sDh[k_rel, v_rel, 0] = state_bf16 + sDh[k_rel, v_rel, dh_h.index] = state_bf16 sB[v_rel, k_rel, 0] = state_bf16 else: - sDh[v_rel, k_rel, 0] = state_bf16 + sDh[v_rel, k_rel, dh_h.index] = state_bf16 sB[v_rel, k_rel, 0] = state_bf16 - cute.arch.fence_proxy("async.shared", space="cta") - cute.arch.sync_threads() + cute.arch.fence_proxy("async.shared", space="cta") + dh_h.commit() + if warp_idx == self.store_warp_id: + dh_h = store_dh_C.wait_and_advance() if cutlass.const_expr(self.transpose_state_layout): - cute.copy(tma_atom_dh, bSG_sDh[None, 0], bSG_gDh[(None, 0, i_v_tile, i_t)]) + cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, 0, i_v_tile, i_t)]) else: - cute.copy(tma_atom_dh, bSG_sDh[None, 
0], bSG_gDh[(None, i_v_tile, 0, i_t)]) + cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) cute.arch.cp_async_bulk_commit_group() - if cutlass.const_expr(self.is_varlen): - cute.arch.cp_async_bulk_wait_group(0, read=True) - if cutlass.const_expr(self.is_varlen): - cute.arch.sync_threads() + cute.arch.cp_async_bulk_wait_group(0, read=True) + dh_h.release() if warp_idx == self.load_warp_id and next_i_t >= 0: k_h = load_k_P.acquire_and_advance() @@ -826,6 +842,10 @@ def kernel( k_wait.release() dv_wait = load_dv_C.wait_and_advance() + dv_stage = dv_wait.index + dv2_store_h = store_dv2_P.acquire_and_advance() + if local_tidx == 0: + sDv2Stage[dv2_store_h.index] = dv_stage cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): t_rel, v_rel = tCcDV[ei] @@ -840,15 +860,14 @@ def kernel( else: g_decay = cute.exp(g_last - g_cur, fastmath=self.use_fast_math) out = out * g_decay - out = out + sUA[v_rel, t_rel, dv_wait.index].to(self.acc_dtype) + out = out + sUA[v_rel, t_rel, dv_stage].to(self.acc_dtype) out_bf16 = out.to(self.io_dtype) - sUA[v_rel, t_rel, dv_wait.index] = out_bf16 - dv_wait.release() - cute.arch.fence_proxy("async.shared", space="cta") - cute.arch.sync_threads() - remaining = chunk_end - chunk_start - if remaining < self.BT: - if is_compute_warp: + sUA[v_rel, t_rel, dv_stage] = out_bf16 + cute.arch.fence_proxy("async.shared", space="cta") + dv2_store_h.commit() + cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) + + if remaining < self.BT: linear_store = local_tidx while linear_store < self.BV * self.BT: v_rel = linear_store // self.BT @@ -857,20 +876,11 @@ def kernel( dv2[data_b, seq_start + chunk_start + t_rel, i_h, v_base + v_rel] = sUA[ v_rel, t_rel, - chunk_rev % self.input_stage, + dv_stage, ] linear_store += self.num_compute_threads - else: - if warp_idx == self.store_warp_id: - cute.copy( - tma_atom_dv2, - 
bSG_sDv2[None, chunk_rev % self.input_stage], - bSG_gDv2[(None, i_v_tile, i_t)], - ) - cute.arch.cp_async_bulk_commit_group() - # dh += scale * do^T @ q - dv2^T @ w. - if is_compute_warp: + # dh += scale * do^T @ q - dv2^T @ w. do_wait = load_do_C.wait_and_advance() if cutlass.const_expr(self.use_g or self.is_varlen): linear_do = local_tidx @@ -939,7 +949,7 @@ def kernel( cute.gemm( update_tiled_mma, acc_wdv, - tUrA[None, None, kp, chunk_rev % self.input_stage], + tUrA[None, None, kp, dv_stage], tWrB[None, None, kp, w_wait.index], acc_wdv, ) @@ -951,15 +961,24 @@ def kernel( rState[ei] = rState[ei] + update w_wait.release() do_wait.release() - if remaining >= self.BT: - if warp_idx == self.store_warp_id: - cute.arch.cp_async_bulk_wait_group(0, read=True) + dv2_done_h = store_dv2_done_C.wait_and_advance() + dv2_done_h.release() + dv_wait.release() - if cutlass.const_expr(not self.is_varlen): - final_remaining = cutlass.min(self.BT, seq_len) - if final_remaining < self.BT: - if warp_idx == self.store_warp_id: + if warp_idx == self.store_warp_id: + dv2_store_h = store_dv2_C.wait_and_advance() + dv2_done_h = store_dv2_done_P.acquire_and_advance() + if remaining >= self.BT: + dv_stage = sDv2Stage[dv2_store_h.index] + cute.copy( + tma_atom_dv2, + bSG_sDv2[None, dv_stage], + bSG_gDv2[(None, i_v_tile, i_t)], + ) + cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) + dv2_done_h.commit() + dv2_store_h.release() if cutlass.const_expr(self.use_dh0): if is_compute_warp: From b3e211ee4bb8d03af609556f0e5399245753b8dc Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Thu, 14 May 2026 12:17:44 +0800 Subject: [PATCH 08/28] tests --- cula/ops/chunk_delta_h_bwd.py | 10 ++- tests/test_chunk_delta_h_bwd_sm90.py | 124 ++++++++++++++++++++++++++- 2 files changed, 131 insertions(+), 3 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 197a2fa..201eb10 100644 --- a/cula/ops/chunk_delta_h_bwd.py 
+++ b/cula/ops/chunk_delta_h_bwd.py @@ -763,7 +763,7 @@ def kernel( g_last_exp = cute.exp(g_last, fastmath=self.use_fast_math) # Publish the current reverse state both as dh output and as the - # K@dH WGMMA B operand. The dh path is staged for the store warp; + # K@dH WGMMA B operand. The dh path is staged for the store warp; # sB remains single-buffered because only the compute warpgroup uses it. if is_compute_warp: dh_h = store_dh_P.acquire_and_advance() @@ -961,6 +961,11 @@ def kernel( rState[ei] = rState[ei] + update w_wait.release() do_wait.release() + # sUA[dv_stage] has three consumers after dv2 is written: + # direct tail stores (above), W^T@dv2 (completed here), and + # full-tile TMA stores by the store warp. Keep the load_dv stage + # owned until the store warp signals done, otherwise the load + # warp could refill this stage while a TMA store still reads it. dv2_done_h = store_dv2_done_C.wait_and_advance() dv2_done_h.release() dv_wait.release() @@ -968,6 +973,9 @@ def kernel( if warp_idx == self.store_warp_id: dv2_store_h = store_dv2_C.wait_and_advance() dv2_done_h = store_dv2_done_P.acquire_and_advance() + # One done token is committed per chunk. Tail chunks skip TMA + # because the tile would cross sequence bounds, but still + # publish done so compute and store pipeline phases stay paired. 
if remaining >= self.BT: dv_stage = sDv2Stage[dv2_store_h.index] cute.copy( diff --git a/tests/test_chunk_delta_h_bwd_sm90.py b/tests/test_chunk_delta_h_bwd_sm90.py index cd93a8d..d731161 100644 --- a/tests/test_chunk_delta_h_bwd_sm90.py +++ b/tests/test_chunk_delta_h_bwd_sm90.py @@ -146,7 +146,17 @@ def _make_inputs( return q, k, w, do, dv, g, gk, dht, dh0 -def _make_varlen_inputs(seq_lens, H, K, V, use_g=False, use_gk=False, use_state=False, seed=42): +def _make_varlen_inputs( + seq_lens, + H, + K, + V, + use_g=False, + use_gk=False, + use_state=False, + seed=42, + transpose_state_layout=False, +): T_total = sum(seq_lens) num_seqs = len(seq_lens) cu = [0] @@ -176,7 +186,7 @@ def _make_varlen_inputs(seq_lens, H, K, V, use_g=False, use_gk=False, use_state= seg = torch.randn(1, eos - bos, H, K, dtype=torch.float32, device=device) * 0.01 gk[:, bos:eos] = -torch.abs(seg).cumsum(dim=1) - state_shape = (num_seqs, H, K, V) + state_shape = (num_seqs, H, V, K) if transpose_state_layout else (num_seqs, H, K, V) dht = torch.randn(state_shape, dtype=torch.float32, device=device) * 0.01 if use_state else None dh0 = torch.empty(state_shape, dtype=torch.float32, device=device) if use_state else None cu_seqlens = torch.tensor(cu, dtype=torch.int32, device=device) @@ -271,6 +281,116 @@ def test_scalar_g_features(use_g, use_gk): _assert_bwd_close(got, ref, True, f"scalar-g g={use_g} gk={use_gk}") +@pytest.mark.parametrize( + "T,use_g,use_gk,transpose_state_layout", + [ + (65, False, False, False), + (127, True, False, False), + (129, False, True, True), + (191, True, True, True), + ], + ids=["t65-plain", "t127-g", "t129-gk-trans", "t191-g-gk-trans"], +) +def test_tail_chunk_sizes(T, use_g, use_gk, transpose_state_layout): + q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs( + B=1, + T=T, + H=2, + K=128, + V=128, + use_g=use_g, + use_gk=use_gk, + use_state=True, + seed=1000 + T, + transpose_state_layout=transpose_state_layout, + ) + ref = run_fla_ref( + q, + k, + w, + do, + dv, + g=g, 
+ gk=gk, + dht=dht, + dh0=dh0, + transpose_state_layout=transpose_state_layout, + ) + got = run_cute_dsl( + q, + k, + w, + do, + dv, + g=g, + gk=gk, + dht=dht, + dh0=dh0, + transpose_state_layout=transpose_state_layout, + ) + _assert_bwd_close(got, ref, True, f"T={T} g={use_g} gk={use_gk} trans={transpose_state_layout}") + + +@pytest.mark.parametrize( + "use_g,use_gk,transpose_state_layout", + [ + (False, False, False), + (True, False, False), + (False, True, False), + (True, True, False), + (False, False, True), + (True, False, True), + (False, True, True), + (True, True, True), + ], +) +def test_varlen_tail_chunk_sizes(use_g, use_gk, transpose_state_layout): + seq_lens = [1, 63, 64, 65, 127, 128, 129] + q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens = _make_varlen_inputs( + seq_lens, + H=1, + K=128, + V=128, + use_g=use_g, + use_gk=use_gk, + use_state=True, + seed=2000 + int(use_g) * 10 + int(use_gk) * 20 + int(transpose_state_layout) * 40, + transpose_state_layout=transpose_state_layout, + ) + ref = run_fla_ref( + q, + k, + w, + do, + dv, + g=g, + gk=gk, + dht=dht, + dh0=dh0, + cu_seqlens=cu_seqlens, + transpose_state_layout=transpose_state_layout, + ) + got = run_cute_dsl( + q, + k, + w, + do, + dv, + g=g, + gk=gk, + dht=dht, + dh0=dh0, + cu_seqlens=cu_seqlens, + transpose_state_layout=transpose_state_layout, + ) + _assert_bwd_close( + got, + ref, + True, + f"varlen tails g={use_g} gk={use_gk} trans={transpose_state_layout}", + ) + + def test_transpose_state_layout(): q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs( B=1, From 34a6e8ffdd4a0000fabaa01a70d9da847fa2a026 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Thu, 14 May 2026 18:55:10 +0800 Subject: [PATCH 09/28] del sB --- cula/ops/chunk_delta_h_bwd.py | 73 +++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 33 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 201eb10..67d6d30 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ 
b/cula/ops/chunk_delta_h_bwd.py @@ -39,6 +39,7 @@ import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute +import cutlass.cute.nvgpu.warpgroup as warpgroup import cutlass.pipeline as pipeline import cutlass.utils as utils import cutlass.utils.hopper_helpers as sm90_utils @@ -121,6 +122,7 @@ def __init__( self.buffer_align_bytes = 1024 self.mma_tiler = (BT, BV, self.BK) + self.kdh_mma_tiler = (BV, BT, self.BK) self.update_mma_tiler = (BV, self.BK, BT) self.atom_layout_mnk = (1, 1, 1) self.cluster_shape_mnk = (1, 1, 1) @@ -269,7 +271,8 @@ def __call__( utils.LayoutEnum.ROW_MAJOR.sm90_mma_major_mode(), self.acc_dtype, self.atom_layout_mnk, - self.mma_tiler[:2], + self.kdh_mma_tiler[:2], + warpgroup.OperandSource.RMEM, ) update_tiled_mma = sm90_utils.make_trivial_tiled_mma( @@ -282,18 +285,12 @@ def __call__( self.update_mma_tiler[:2], ) - a_smem_layout_staged = sm90_utils.make_smem_layout_a( + a_smem_layout_staged = sm90_utils.make_smem_layout_b( utils.LayoutEnum.ROW_MAJOR, - self.mma_tiler, + self.kdh_mma_tiler, self.io_dtype, self.input_stage, ) - b_smem_layout_staged = sm90_utils.make_smem_layout_b( - utils.LayoutEnum.ROW_MAJOR, - self.mma_tiler, - self.io_dtype, - 1, - ) update_a_smem_layout_staged = sm90_utils.make_smem_layout_a( utils.LayoutEnum.COL_MAJOR, self.update_mma_tiler, @@ -390,10 +387,6 @@ class SharedStorage: cute.struct.MemRange[self.io_dtype, cute.cosize(a_smem_layout_staged)], self.buffer_align_bytes, ] - sB: cute.struct.Align[ - cute.struct.MemRange[self.io_dtype, cute.cosize(b_smem_layout_staged)], - self.buffer_align_bytes, - ] sUA: cute.struct.Align[ cute.struct.MemRange[self.io_dtype, cute.cosize(update_a_smem_layout_staged)], self.buffer_align_bytes, @@ -438,7 +431,6 @@ class SharedStorage: tiled_mma, update_tiled_mma, a_smem_layout_staged, - b_smem_layout_staged, update_a_smem_layout_staged, update_b_smem_layout_staged, tma_atom_k, @@ -485,7 +477,6 @@ def kernel( tiled_mma: cute.TiledMma, update_tiled_mma: 
cute.TiledMma, a_smem_layout_staged: cute.ComposedLayout, - b_smem_layout_staged: cute.ComposedLayout, update_a_smem_layout_staged: cute.ComposedLayout, update_b_smem_layout_staged: cute.ComposedLayout, tma_atom_k: cute.CopyAtom, @@ -530,7 +521,6 @@ def kernel( storage = smem.allocate(self.shared_storage) sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) - sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) sUA = storage.sUA.get_tensor(update_a_smem_layout_staged.outer, swizzle=update_a_smem_layout_staged.inner) sDv2Stage = storage.sDv2Stage.get_tensor(cute.make_layout((self.dv2_store_stage,))) sDo = storage.sDo.get_tensor(update_a_smem_layout_staged.outer, swizzle=update_a_smem_layout_staged.inner) @@ -675,10 +665,8 @@ def kernel( thr_mma = tiled_mma.get_slice(local_tidx) update_thr_mma = update_tiled_mma.get_slice(local_tidx) - tCsA = thr_mma.partition_A(sA) - tCsB = thr_mma.partition_B(sB) - tCrA = thr_mma.make_fragment_A(tCsA) - tCrB = thr_mma.make_fragment_B(tCsB) + tKsB = thr_mma.partition_B(sA) + tKrB = thr_mma.make_fragment_B(tKsB) tUsA = update_thr_mma.partition_A(sUA) tUsDo = update_thr_mma.partition_A(sDo) tUsB = update_thr_mma.partition_B(sUB) @@ -688,9 +676,9 @@ def kernel( tUrB = update_thr_mma.make_fragment_B(tUsB) tWrB = update_thr_mma.make_fragment_B(tWsB) - cDV = cute.make_identity_tensor((BT, BV)) + cDV = cute.make_identity_tensor((BV, BT)) tCcDV = thr_mma.partition_C(cDV) - acc_dv = thr_mma.make_fragment_C(thr_mma.partition_shape_C((BT, BV))) + acc_dv = thr_mma.make_fragment_C(thr_mma.partition_shape_C((BV, BT))) cState = cute.make_identity_tensor((BV, self.BK)) tUcState = update_thr_mma.partition_C(cState) @@ -762,23 +750,19 @@ def kernel( else: g_last_exp = cute.exp(g_last, fastmath=self.use_fast_math) - # Publish the current reverse state both as dh output and as the - # K@dH WGMMA B operand. 
The dh path is staged for the store warp; - # sB remains single-buffered because only the compute warpgroup uses it. + # Publish the current reverse state as dh output. K@dH consumes the + # same register-carried state through the RMEM WGMMA path below. if is_compute_warp: dh_h = store_dh_P.acquire_and_advance() for k_block in cutlass.range_constexpr(self.num_k_blocks): - k_base = k_block * self.BK rState = rStates[k_block] for ei in cutlass.range(cute.size(rState), unroll_full=True): v_rel, k_rel = tUcState[ei] state_bf16 = rState[ei].to(self.io_dtype) if cutlass.const_expr(self.transpose_state_layout): sDh[k_rel, v_rel, dh_h.index] = state_bf16 - sB[v_rel, k_rel, 0] = state_bf16 else: sDh[v_rel, k_rel, dh_h.index] = state_bf16 - sB[v_rel, k_rel, 0] = state_bf16 cute.arch.fence_proxy("async.shared", space="cta") dh_h.commit() @@ -818,14 +802,16 @@ def kernel( cute.copy(tma_atom_q, bSG_gQ[(None, 0, i_t)], bSG_sQ[None, q_h.index], tma_bar_ptr=q_h.barrier) w_h = load_w_P.acquire_and_advance() cute.copy(tma_atom_w, bSG_gW[(None, 0, i_t)], bSG_sW[None, w_h.index], tma_bar_ptr=w_h.barrier) - # dv2 = dv + K @ dh. + # dv2 = dv + K @ dh. Compute the equivalent (dh @ K^T) tile so the + # register-carried state can feed WGMMA as an RMEM A operand. 
if is_compute_warp: acc_dv.fill(0.0) for k_block in cutlass.range_constexpr(self.num_k_blocks): k_wait = load_k_C.wait_and_advance() cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) + rState_op = self.make_acc_into_op(rStates[k_block], tiled_mma.tv_layout_A, self.io_dtype) cute.nvgpu.warpgroup.fence() - for kp in cutlass.range(cute.size(tCrA, mode=[2]), unroll_full=True): + for kp in cutlass.range(cute.size(tKrB, mode=[2]), unroll_full=True): tiled_mma.set( cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean((k_block != 0) or (kp != 0)), @@ -833,8 +819,8 @@ def kernel( cute.gemm( tiled_mma, acc_dv, - tCrA[None, None, kp, k_wait.index], - tCrB[None, None, kp, 0], + rState_op[None, None, kp], + tKrB[None, None, kp, k_wait.index], acc_dv, ) cute.nvgpu.warpgroup.commit_group() @@ -848,7 +834,7 @@ def kernel( sDv2Stage[dv2_store_h.index] = dv_stage cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): - t_rel, v_rel = tCcDV[ei] + v_rel, t_rel = tCcDV[ei] t_idx = chunk_start + t_rel out = Float32(0.0) if t_idx < seq_len: @@ -1016,6 +1002,27 @@ def _epilog_partition(self, atom, gC_mnl, epi_tile, sC): ) return atom, bSG_sC, bSG_gC + @staticmethod + def _convert_c_layout_to_a_layout(c, a): + return cute.make_layout( + (a, c.shape[1], (c.shape[2], cute.size(c, mode=[0]) // cute.size(a))), + stride=( + c.stride[0], + c.stride[1], + (c.stride[2], cute.size(a, mode=[2]) * c.stride[0][2]), + ), + ) + + @cute.jit + def make_acc_into_op(self, acc, operand_layout_tv, element_type): + operand = cute.make_rmem_tensor_like( + self._convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]), + element_type, + ) + operand_as_acc = cute.make_tensor(operand.iterator, acc.layout) + operand_as_acc.store(acc.load().to(element_type)) + return operand + def _as_cute(tensor: torch.Tensor): return from_dlpack(tensor, assumed_align=16) From 
7bda1289cc8e6537b359efca6b2089fc73b99ead Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Thu, 14 May 2026 19:05:59 +0800 Subject: [PATCH 10/28] adjust dh store side path --- cula/ops/chunk_delta_h_bwd.py | 52 +++++++++++++++++------------------ 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 67d6d30..182c75f 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -750,32 +750,6 @@ def kernel( else: g_last_exp = cute.exp(g_last, fastmath=self.use_fast_math) - # Publish the current reverse state as dh output. K@dH consumes the - # same register-carried state through the RMEM WGMMA path below. - if is_compute_warp: - dh_h = store_dh_P.acquire_and_advance() - for k_block in cutlass.range_constexpr(self.num_k_blocks): - rState = rStates[k_block] - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - state_bf16 = rState[ei].to(self.io_dtype) - if cutlass.const_expr(self.transpose_state_layout): - sDh[k_rel, v_rel, dh_h.index] = state_bf16 - else: - sDh[v_rel, k_rel, dh_h.index] = state_bf16 - cute.arch.fence_proxy("async.shared", space="cta") - dh_h.commit() - - if warp_idx == self.store_warp_id: - dh_h = store_dh_C.wait_and_advance() - if cutlass.const_expr(self.transpose_state_layout): - cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, 0, i_v_tile, i_t)]) - else: - cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) - dh_h.release() - if warp_idx == self.load_warp_id and next_i_t >= 0: k_h = load_k_P.acquire_and_advance() cute.copy(tma_atom_k, bSG_gK[(None, next_i_t, 0)], bSG_sK[None, k_h.index], tma_bar_ptr=k_h.barrier) @@ -808,7 +782,6 @@ def kernel( acc_dv.fill(0.0) for k_block in cutlass.range_constexpr(self.num_k_blocks): k_wait = load_k_C.wait_and_advance() - 
cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) rState_op = self.make_acc_into_op(rStates[k_block], tiled_mma.tv_layout_A, self.io_dtype) cute.nvgpu.warpgroup.fence() for kp in cutlass.range(cute.size(tKrB, mode=[2]), unroll_full=True): @@ -906,6 +879,22 @@ def kernel( ) cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(k_block == 0): + # QDO does not read rState, so publish the old reverse + # state to the dh store side path while QDO is in flight. + dh_h = store_dh_P.acquire_and_advance() + for state_block in cutlass.range_constexpr(self.num_k_blocks): + state = rStates[state_block] + for ei in cutlass.range(cute.size(state), unroll_full=True): + v_rel, k_rel = tUcState[ei] + state_bf16 = state[ei].to(self.io_dtype) + if cutlass.const_expr(self.transpose_state_layout): + sDh[k_rel, v_rel, dh_h.index] = state_bf16 + else: + sDh[v_rel, k_rel, dh_h.index] = state_bf16 + cute.arch.fence_proxy("async.shared", space="cta") + dh_h.commit() + # QDO does not consume rState, so hide g/gk state decay under its WGMMA latency. 
if cutlass.const_expr(self.use_g): for ei in cutlass.range(cute.size(rState), unroll_full=True): @@ -974,6 +963,15 @@ def kernel( dv2_done_h.commit() dv2_store_h.release() + dh_h = store_dh_C.wait_and_advance() + if cutlass.const_expr(self.transpose_state_layout): + cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, 0, i_v_tile, i_t)]) + else: + cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + dh_h.release() + if cutlass.const_expr(self.use_dh0): if is_compute_warp: for k_block in cutlass.range_constexpr(self.num_k_blocks): From 87336c71e973c58bc47dc42b34f47c5cac504141 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Thu, 14 May 2026 21:56:30 +0800 Subject: [PATCH 11/28] trans and not trans tRS separately --- cula/ops/chunk_delta_h_bwd.py | 43 +++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 182c75f..4d7ffbd 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -696,6 +696,26 @@ def kernel( rStates = (rState0, rState1, rState2, rState3) acc_qdo = update_thr_mma.make_fragment_C(state_shape) acc_wdv = update_thr_mma.make_fragment_C(state_shape) + if cutlass.const_expr(not self.transpose_state_layout): + dh_copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + utils.LayoutEnum.COL_MAJOR, + elem_ty_d=self.io_dtype, + elem_ty_acc=self.acc_dtype, + ) + dh_copy_atom = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp( + utils.LayoutEnum.COL_MAJOR.is_m_major_c(), + 4, + ), + self.io_dtype, + ) + tiled_copy_dh_atom = cute.make_tiled_copy_C_atom(dh_copy_atom, update_tiled_mma) + tiled_copy_dh_r2s = cute.make_tiled_copy_S(dh_copy_atom_r2s, tiled_copy_dh_atom) + thr_copy_dh_r2s = tiled_copy_dh_r2s.get_slice(local_tidx) + tRS_sDh = thr_copy_dh_r2s.partition_D(sDh) +
rDh_shape = cute.shape(thr_copy_dh_r2s.partition_S(sDh)) + tRS_rDh_layout = cute.make_layout(rDh_shape[:3]) + tRS_rDh_out = cute.make_rmem_tensor_like(tRS_rDh_layout, self.io_dtype) # Initialize carried dh state in register blocks. if is_compute_warp: @@ -883,15 +903,20 @@ def kernel( # QDO does not read rState, so publish the old reverse # state to the dh store side path while QDO is in flight. dh_h = store_dh_P.acquire_and_advance() - for state_block in cutlass.range_constexpr(self.num_k_blocks): - state = rStates[state_block] - for ei in cutlass.range(cute.size(state), unroll_full=True): - v_rel, k_rel = tUcState[ei] - state_bf16 = state[ei].to(self.io_dtype) - if cutlass.const_expr(self.transpose_state_layout): - sDh[k_rel, v_rel, dh_h.index] = state_bf16 - else: - sDh[v_rel, k_rel, dh_h.index] = state_bf16 + if cutlass.const_expr(self.transpose_state_layout): + for state_block in cutlass.range_constexpr(self.num_k_blocks): + state = rStates[state_block] + for ei in cutlass.range(cute.size(state), unroll_full=True): + v_rel, k_rel = tUcState[ei] + sDh[k_rel, v_rel, dh_h.index] = state[ei].to(self.io_dtype) + else: + tRS_rState = tiled_copy_dh_r2s.retile(rState) + tRS_rDh_out.store(tRS_rState.load().to(self.io_dtype)) + cute.copy( + tiled_copy_dh_r2s, + tRS_rDh_out, + tRS_sDh[(None, None, None, dh_h.index)], + ) cute.arch.fence_proxy("async.shared", space="cta") dh_h.commit() From 11bb357eb34f73526086b5b43ee335d3f9ee9a30 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Thu, 14 May 2026 22:03:52 +0800 Subject: [PATCH 12/28] trans and not trans use same tRS --- cula/ops/chunk_delta_h_bwd.py | 91 +++++++++++++++-------------------- 1 file changed, 40 insertions(+), 51 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 4d7ffbd..0a55f2f 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -220,16 +220,18 @@ def __call__( dh = cute.make_tensor(dh_ptr, state_layout) if 
cutlass.const_expr(self.transpose_state_layout): dh_tma_layout = cute.make_layout( - (self.K, self.V, (NT_total, self.H, self.B)), - stride=(1, self.K, (self.H * self.K * self.V, self.K * self.V, NT_total * self.H * self.K * self.V)), + (self.V, self.K, (NT_total, self.H, self.B)), + stride=(self.K, 1, (self.H * self.K * self.V, self.K * self.V, NT_total * self.H * self.K * self.V)), ) - dh_tma_tile = (self.BK, self.BV) else: dh_tma_layout = cute.make_layout( (self.V, self.K, (NT_total, self.H, self.B)), stride=(1, self.V, (self.H * self.K * self.V, self.K * self.V, NT_total * self.H * self.K * self.V)), ) - dh_tma_tile = (self.BV, self.BK) + dh_tma_tile = (self.BV, self.BK) + dh_smem_layout_enum = ( + utils.LayoutEnum.ROW_MAJOR if cutlass.const_expr(self.transpose_state_layout) else utils.LayoutEnum.COL_MAJOR + ) dh_tma = cute.make_tensor(dh_ptr, dh_tma_layout) if cutlass.const_expr(self.transpose_state_layout): @@ -309,7 +311,7 @@ def __call__( ) dh_smem_layout_staged = sm90_utils.make_smem_layout_epi( self.io_dtype, - utils.LayoutEnum.COL_MAJOR, + dh_smem_layout_enum, dh_tma_tile, self.dh_store_stage, ) @@ -643,14 +645,9 @@ def kernel( _, bSG_sGK, bSG_gGK = self._epilog_partition( tma_atom_gk, tma_tensor_gk_use[None, None, (i_h, data_b)], (self.BK, 1), sGK ) - if cutlass.const_expr(self.transpose_state_layout): - _, bSG_sDh, bSG_gDh = self._epilog_partition( - tma_atom_dh, tma_tensor_dh_use[None, None, (None, i_h, state_b)], (self.BK, self.BV), sDh - ) - else: - _, bSG_sDh, bSG_gDh = self._epilog_partition( - tma_atom_dh, tma_tensor_dh_use[None, None, (None, i_h, state_b)], (self.BV, self.BK), sDh - ) + _, bSG_sDh, bSG_gDh = self._epilog_partition( + tma_atom_dh, tma_tensor_dh_use[None, None, (None, i_h, state_b)], (self.BV, self.BK), sDh + ) _, bSG_sDv2, bSG_gDv2 = self._epilog_partition( tma_atom_dv2, tma_tensor_dv2_use[None, None, (i_h, data_b)], (self.BV, self.BT), sUA ) @@ -696,26 +693,28 @@ def kernel( rStates = (rState0, rState1, rState2, rState3) acc_qdo 
= update_thr_mma.make_fragment_C(state_shape) acc_wdv = update_thr_mma.make_fragment_C(state_shape) - if cutlass.const_expr(not self.transpose_state_layout): - dh_copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( - utils.LayoutEnum.COL_MAJOR, - elem_ty_d=self.io_dtype, - elem_ty_acc=self.acc_dtype, - ) - dh_copy_atom = cute.make_copy_atom( - cute.nvgpu.warp.StMatrix8x8x16bOp( - utils.LayoutEnum.COL_MAJOR.is_m_major_c(), - 4, - ), - self.io_dtype, - ) - tiled_copy_dh_atom = cute.make_tiled_copy_C_atom(dh_copy_atom, update_tiled_mma) - tiled_copy_dh_r2s = cute.make_tiled_copy_S(dh_copy_atom_r2s, tiled_copy_dh_atom) - thr_copy_dh_r2s = tiled_copy_dh_r2s.get_slice(local_tidx) - tRS_sDh = thr_copy_dh_r2s.partition_D(sDh) - rDh_shape = cute.shape(thr_copy_dh_r2s.partition_S(sDh)) - tRS_rDh_layout = cute.make_layout(rDh_shape[:3]) - tRS_rDh_out = cute.make_rmem_tensor_like(tRS_rDh_layout, self.io_dtype) + dh_smem_layout_enum = ( + utils.LayoutEnum.ROW_MAJOR if cutlass.const_expr(self.transpose_state_layout) else utils.LayoutEnum.COL_MAJOR + ) + dh_copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + dh_smem_layout_enum, + elem_ty_d=self.io_dtype, + elem_ty_acc=self.acc_dtype, + ) + dh_copy_atom = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp( + dh_smem_layout_enum.is_m_major_c(), + 4, + ), + self.io_dtype, + ) + tiled_copy_dh_atom = cute.make_tiled_copy_C_atom(dh_copy_atom, update_tiled_mma) + tiled_copy_dh_r2s = cute.make_tiled_copy_S(dh_copy_atom_r2s, tiled_copy_dh_atom) + thr_copy_dh_r2s = tiled_copy_dh_r2s.get_slice(local_tidx) + tRS_sDh = thr_copy_dh_r2s.partition_D(sDh) + rDh_shape = cute.shape(thr_copy_dh_r2s.partition_S(sDh)) + tRS_rDh_layout = cute.make_layout(rDh_shape[:3]) + tRS_rDh_out = cute.make_rmem_tensor_like(tRS_rDh_layout, self.io_dtype) # Initialize carried dh state in register blocks. 
if is_compute_warp: @@ -903,20 +902,13 @@ def kernel( # QDO does not read rState, so publish the old reverse # state to the dh store side path while QDO is in flight. dh_h = store_dh_P.acquire_and_advance() - if cutlass.const_expr(self.transpose_state_layout): - for state_block in cutlass.range_constexpr(self.num_k_blocks): - state = rStates[state_block] - for ei in cutlass.range(cute.size(state), unroll_full=True): - v_rel, k_rel = tUcState[ei] - sDh[k_rel, v_rel, dh_h.index] = state[ei].to(self.io_dtype) - else: - tRS_rState = tiled_copy_dh_r2s.retile(rState) - tRS_rDh_out.store(tRS_rState.load().to(self.io_dtype)) - cute.copy( - tiled_copy_dh_r2s, - tRS_rDh_out, - tRS_sDh[(None, None, None, dh_h.index)], - ) + tRS_rState = tiled_copy_dh_r2s.retile(rState) + tRS_rDh_out.store(tRS_rState.load().to(self.io_dtype)) + cute.copy( + tiled_copy_dh_r2s, + tRS_rDh_out, + tRS_sDh[(None, None, None, dh_h.index)], + ) cute.arch.fence_proxy("async.shared", space="cta") dh_h.commit() @@ -989,10 +981,7 @@ def kernel( dv2_store_h.release() dh_h = store_dh_C.wait_and_advance() - if cutlass.const_expr(self.transpose_state_layout): - cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, 0, i_v_tile, i_t)]) - else: - cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) + cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) dh_h.release() From 700b1a48447f52b07d31d268367342d0cd0a1e2f Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Fri, 15 May 2026 16:31:12 +0800 Subject: [PATCH 13/28] format --- benchmarks/bench_chunk_delta_h_bwd_sm90.py | 9 ++-- tests/test_chunk_delta_h_bwd_sm90.py | 61 ++++++++++------------ 2 files changed, 33 insertions(+), 37 deletions(-) diff --git a/benchmarks/bench_chunk_delta_h_bwd_sm90.py b/benchmarks/bench_chunk_delta_h_bwd_sm90.py index 33d1ceb..ccd4ae9 100644 --- 
a/benchmarks/bench_chunk_delta_h_bwd_sm90.py +++ b/benchmarks/bench_chunk_delta_h_bwd_sm90.py @@ -15,7 +15,6 @@ Usage: python benchmarks/bench_chunk_delta_h_bwd_sm90.py --mode both - python benchmarks/bench_chunk_delta_h_bwd_sm90.py --preset fwd --mode non-varlen python benchmarks/bench_chunk_delta_h_bwd_sm90.py --preset focused --mode non-varlen """ @@ -47,8 +46,8 @@ dtype = torch.bfloat16 device = "cuda" -WARMUP = 5 -N_ITERS = 30 +WARMUP = 10 +N_ITERS = 100 NCU_MODE = False @@ -493,9 +492,9 @@ def main(): parser.add_argument( "--preset", type=str, - default="representative", + default="fwd", choices=["representative", "fwd", "focused"], - help="representative runs a short subset; fwd mirrors bench_chunk_delta_h.py; focused runs the long non-varlen matrix", + help="fwd mirrors bench_chunk_delta_h.py; representative runs a short subset; focused runs the long non-varlen matrix", ) parser.add_argument( "--feature-mode", diff --git a/tests/test_chunk_delta_h_bwd_sm90.py b/tests/test_chunk_delta_h_bwd_sm90.py index d731161..658168d 100644 --- a/tests/test_chunk_delta_h_bwd_sm90.py +++ b/tests/test_chunk_delta_h_bwd_sm90.py @@ -2,11 +2,10 @@ # Copyright 2025-2026 Ant Group Co., Ltd. # SPDX-License-Identifier: Apache-2.0 -"""Representative correctness tests for the SM90 CuTe DSL WGMMA bwd_dhu path. +"""Correctness tests for the SM90 CuTe DSL WGMMA bwd_dhu path. -These cases follow the same logic as tests/test_chunk_delta_h.py but avoid the -full Cartesian sweep during kernel iteration. For bwd_dhu, fwd's -initial_state/output_final_state pair maps to dht/dh0. +These cases follow tests/test_chunk_delta_h.py where the backward API permits. +For bwd_dhu, fwd's initial_state/output_final_state pair maps to dht/dh0. 
""" import os @@ -193,27 +192,29 @@ def _make_varlen_inputs( return q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens -@pytest.mark.parametrize( - "case", - [ - dict(B=1, T=64, H=1, K=128, V=128, use_gk=False, use_state=False), - dict(B=1, T=128, H=4, K=128, V=128, use_gk=True, use_state=False), - dict(B=2, T=256, H=4, K=128, V=128, use_gk=True, use_state=True), - dict(B=1, T=1024, H=64, K=128, V=128, use_gk=True, use_state=False), - ], - ids=["minimal", "multihead-gk", "batch-state", "long-h64"], -) -def test_dhu_against_fla(case): - B, T, H, K, V = case["B"], case["T"], case["H"], case["K"], case["V"] - use_gk, use_state = case["use_gk"], case["use_state"] +@pytest.mark.parametrize("B", [1, 2]) +@pytest.mark.parametrize("H", [1, 4]) +@pytest.mark.parametrize("T", [64, 128, 256]) +@pytest.mark.parametrize("K", [128]) +@pytest.mark.parametrize("V", [128]) +@pytest.mark.parametrize("use_gk", [False, True]) +@pytest.mark.parametrize("use_state", [False, True]) +def test_dhu_against_fla(B, H, T, K, V, use_gk, use_state): q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs(B, T, H, K, V, use_gk=use_gk, use_state=use_state) ref = run_fla_ref(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0) got = run_cute_dsl(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0) _assert_bwd_close(got, ref, use_state, f"B={B} H={H} T={T} gk={use_gk} state={use_state}") -def test_dv2_no_gating(): - B, T, H, K, V = 4, 512, 4, 128, 128 +@pytest.mark.parametrize( + "B,T,H,K,V", + [ + (1, 64, 1, 128, 128), + (2, 128, 4, 128, 128), + (4, 512, 4, 128, 128), + ], +) +def test_dv2_no_gating(B, T, H, K, V): q, k, w, do, dv, g, gk, dht, dh0 = _make_inputs(B, T, H, K, V) ref = run_fla_ref(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0) got = run_cute_dsl(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0) @@ -221,23 +222,19 @@ def test_dv2_no_gating(): @pytest.mark.parametrize( - "case", + "seq_lens", [ - dict(seq_lens=[50, 192, 100], H=2, use_g=False, use_gk=True, use_state=False), - dict(seq_lens=[33, 128, 200, 95], H=1, 
use_g=True, use_gk=False, use_state=True), + [128, 128], + [50, 192, 100], + [33, 128, 200, 95], ], - ids=["gk-dht", "g-dh0"], ) -def test_varlen_against_fla(case): +@pytest.mark.parametrize("H", [1, 4]) +@pytest.mark.parametrize("use_gk", [False, True]) +@pytest.mark.parametrize("use_state", [False, True]) +def test_varlen_against_fla(seq_lens, H, use_gk, use_state): K, V = 128, 128 - seq_lens = case["seq_lens"] - H = case["H"] - use_g = case["use_g"] - use_gk = case["use_gk"] - use_state = case["use_state"] - q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens = _make_varlen_inputs( - seq_lens, H, K, V, use_g=use_g, use_gk=use_gk, use_state=use_state - ) + q, k, w, do, dv, g, gk, dht, dh0, cu_seqlens = _make_varlen_inputs(seq_lens, H, K, V, use_gk=use_gk, use_state=use_state) ref = run_fla_ref(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0, cu_seqlens=cu_seqlens) got = run_cute_dsl(q, k, w, do, dv, g=g, gk=gk, dht=dht, dh0=dh0, cu_seqlens=cu_seqlens) _assert_bwd_close(got, ref, use_state, f"varlen seqs={seq_lens} H={H} gk={use_gk} state={use_state}") From ddf074f37e3c108c522ff45a69c7ee11039642be Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Sat, 16 May 2026 08:00:57 +0800 Subject: [PATCH 14/28] fix atol and rtol to 1e-2 --- tests/test_chunk_delta_h_bwd_sm90.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_chunk_delta_h_bwd_sm90.py b/tests/test_chunk_delta_h_bwd_sm90.py index 658168d..7324ef2 100644 --- a/tests/test_chunk_delta_h_bwd_sm90.py +++ b/tests/test_chunk_delta_h_bwd_sm90.py @@ -21,8 +21,8 @@ from cula.ops.chunk_delta_h_bwd import chunk_gated_delta_rule_bwd_dhu_sm90 BT = 64 -ATOL = 3e-2 -RTOL = 3e-2 +ATOL = 1e-2 +RTOL = 1e-2 device = "cuda" From 192d3b38b64a90bb70a8a2e9ff8408a08d253f96 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Sat, 16 May 2026 18:45:52 +0800 Subject: [PATCH 15/28] dh publish earlier --- cula/ops/chunk_delta_h_bwd.py | 42 +++++++++++++++++------------------ 1 file 
changed, 21 insertions(+), 21 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 0a55f2f..c3469d0 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -801,7 +801,8 @@ def kernel( acc_dv.fill(0.0) for k_block in cutlass.range_constexpr(self.num_k_blocks): k_wait = load_k_C.wait_and_advance() - rState_op = self.make_acc_into_op(rStates[k_block], tiled_mma.tv_layout_A, self.io_dtype) + rState = rStates[k_block] + rState_op = self.make_acc_into_op(rState, tiled_mma.tv_layout_A, self.io_dtype) cute.nvgpu.warpgroup.fence() for kp in cutlass.range(cute.size(tKrB, mode=[2]), unroll_full=True): tiled_mma.set( @@ -816,6 +817,19 @@ def kernel( acc_dv, ) cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(k_block == 0): + # KdH consumes a copied RMEM operand, so publish the old + # reverse state while KdH is still in flight. + dh_h = store_dh_P.acquire_and_advance() + tRS_rState = tiled_copy_dh_r2s.retile(rState) + tRS_rDh_out.store(tRS_rState.load().to(self.io_dtype)) + cute.copy( + tiled_copy_dh_r2s, + tRS_rDh_out, + tRS_sDh[(None, None, None, dh_h.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + dh_h.commit() cute.nvgpu.warpgroup.wait_group(0) k_wait.release() @@ -898,20 +912,6 @@ def kernel( ) cute.nvgpu.warpgroup.commit_group() - if cutlass.const_expr(k_block == 0): - # QDO does not read rState, so publish the old reverse - # state to the dh store side path while QDO is in flight. - dh_h = store_dh_P.acquire_and_advance() - tRS_rState = tiled_copy_dh_r2s.retile(rState) - tRS_rDh_out.store(tRS_rState.load().to(self.io_dtype)) - cute.copy( - tiled_copy_dh_r2s, - tRS_rDh_out, - tRS_sDh[(None, None, None, dh_h.index)], - ) - cute.arch.fence_proxy("async.shared", space="cta") - dh_h.commit() - # QDO does not consume rState, so hide g/gk state decay under its WGMMA latency. 
if cutlass.const_expr(self.use_g): for ei in cutlass.range(cute.size(rState), unroll_full=True): @@ -963,6 +963,12 @@ def kernel( dv_wait.release() if warp_idx == self.store_warp_id: + dh_h = store_dh_C.wait_and_advance() + cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + dh_h.release() + dv2_store_h = store_dv2_C.wait_and_advance() dv2_done_h = store_dv2_done_P.acquire_and_advance() # One done token is committed per chunk. Tail chunks skip TMA @@ -980,12 +986,6 @@ def kernel( dv2_done_h.commit() dv2_store_h.release() - dh_h = store_dh_C.wait_and_advance() - cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) - dh_h.release() - if cutlass.const_expr(self.use_dh0): if is_compute_warp: for k_block in cutlass.range_constexpr(self.num_k_blocks): From dfbfdee86b89c5e6cb6fd8900a97309313d76002 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Sun, 17 May 2026 08:51:03 +0800 Subject: [PATCH 16/28] add input stage=3 --- cula/ops/chunk_delta_h_bwd.py | 39 +++++++++++++++++------------------ 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index c3469d0..e853cab 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -114,7 +114,7 @@ def __init__( self.num_threads = NUM_THREADS self.num_regs_compute = 232 self.num_regs_other = 40 - self.input_stage = 2 + self.input_stage = 3 self.dh_store_stage = 2 self.dv2_store_stage = 2 self.io_dtype = cutlass.BFloat16 @@ -817,19 +817,6 @@ def kernel( acc_dv, ) cute.nvgpu.warpgroup.commit_group() - if cutlass.const_expr(k_block == 0): - # KdH consumes a copied RMEM operand, so publish the old - # reverse state while KdH is still in flight. 
- dh_h = store_dh_P.acquire_and_advance() - tRS_rState = tiled_copy_dh_r2s.retile(rState) - tRS_rDh_out.store(tRS_rState.load().to(self.io_dtype)) - cute.copy( - tiled_copy_dh_r2s, - tRS_rDh_out, - tRS_sDh[(None, None, None, dh_h.index)], - ) - cute.arch.fence_proxy("async.shared", space="cta") - dh_h.commit() cute.nvgpu.warpgroup.wait_group(0) k_wait.release() @@ -912,6 +899,18 @@ def kernel( ) cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(k_block == 0): + dh_h = store_dh_P.acquire_and_advance() + tRS_rState = tiled_copy_dh_r2s.retile(rState) + tRS_rDh_out.store(tRS_rState.load().to(self.io_dtype)) + cute.copy( + tiled_copy_dh_r2s, + tRS_rDh_out, + tRS_sDh[(None, None, None, dh_h.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + dh_h.commit() + # QDO does not consume rState, so hide g/gk state decay under its WGMMA latency. if cutlass.const_expr(self.use_g): for ei in cutlass.range(cute.size(rState), unroll_full=True): @@ -963,12 +962,6 @@ def kernel( dv_wait.release() if warp_idx == self.store_warp_id: - dh_h = store_dh_C.wait_and_advance() - cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) - dh_h.release() - dv2_store_h = store_dv2_C.wait_and_advance() dv2_done_h = store_dv2_done_P.acquire_and_advance() # One done token is committed per chunk. 
Tail chunks skip TMA @@ -986,6 +979,12 @@ def kernel( dv2_done_h.commit() dv2_store_h.release() + dh_h = store_dh_C.wait_and_advance() + cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + dh_h.release() + if cutlass.const_expr(self.use_dh0): if is_compute_warp: for k_block in cutlass.range_constexpr(self.num_k_blocks): From 94e4783aa05172a70d29ef5b10cb4f2fdf684f7f Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Sun, 17 May 2026 09:10:58 +0800 Subject: [PATCH 17/28] optimize g --- cula/ops/chunk_delta_h_bwd.py | 90 ++++++++++++++++++++++++++--------- 1 file changed, 67 insertions(+), 23 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index e853cab..7b3eef5 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -286,6 +286,16 @@ def __call__( self.atom_layout_mnk, self.update_mma_tiler[:2], ) + qdo_tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.io_dtype, + self.io_dtype, + utils.LayoutEnum.COL_MAJOR.sm90_mma_major_mode(), + utils.LayoutEnum.COL_MAJOR.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + self.update_mma_tiler[:2], + warpgroup.OperandSource.RMEM, + ) a_smem_layout_staged = sm90_utils.make_smem_layout_b( utils.LayoutEnum.ROW_MAJOR, @@ -401,6 +411,10 @@ class SharedStorage: cute.struct.MemRange[cutlass.Float32, BK * self.input_stage], 128, ] + sG: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, BT * 2], + 128, + ] sUB: cute.struct.Align[ cute.struct.MemRange[self.io_dtype, cute.cosize(update_b_smem_layout_staged)], self.buffer_align_bytes, @@ -432,6 +446,7 @@ class SharedStorage: chunk_offsets, tiled_mma, update_tiled_mma, + qdo_tiled_mma, a_smem_layout_staged, update_a_smem_layout_staged, update_b_smem_layout_staged, @@ -478,6 +493,7 @@ def kernel( chunk_offsets: cute.Tensor, tiled_mma: cute.TiledMma, 
update_tiled_mma: cute.TiledMma, + qdo_tiled_mma: cute.TiledMma, a_smem_layout_staged: cute.ComposedLayout, update_a_smem_layout_staged: cute.ComposedLayout, update_b_smem_layout_staged: cute.ComposedLayout, @@ -527,6 +543,7 @@ def kernel( sDv2Stage = storage.sDv2Stage.get_tensor(cute.make_layout((self.dv2_store_stage,))) sDo = storage.sDo.get_tensor(update_a_smem_layout_staged.outer, swizzle=update_a_smem_layout_staged.inner) sGK = storage.sGK.get_tensor(cute.make_layout((BK, 1, self.input_stage), stride=(1, BK, BK))) + sG = storage.sG.get_tensor(cute.make_layout((BT, 2), stride=(1, BT))) sUB = storage.sUB.get_tensor(update_b_smem_layout_staged.outer, swizzle=update_b_smem_layout_staged.inner) sW = storage.sW.get_tensor(update_b_smem_layout_staged.outer, swizzle=update_b_smem_layout_staged.inner) sDh = storage.sDh.get_tensor(dh_smem_layout_staged.outer, swizzle=dh_smem_layout_staged.inner) @@ -661,6 +678,7 @@ def kernel( thr_mma = tiled_mma.get_slice(local_tidx) update_thr_mma = update_tiled_mma.get_slice(local_tidx) + qdo_thr_mma = qdo_tiled_mma.get_slice(local_tidx) tKsB = thr_mma.partition_B(sA) tKrB = thr_mma.make_fragment_B(tKsB) @@ -668,10 +686,12 @@ def kernel( tUsDo = update_thr_mma.partition_A(sDo) tUsB = update_thr_mma.partition_B(sUB) tWsB = update_thr_mma.partition_B(sW) + qdo_tUsB = qdo_thr_mma.partition_B(sUB) tUrA = update_thr_mma.make_fragment_A(tUsA) tUrDo = update_thr_mma.make_fragment_A(tUsDo) tUrB = update_thr_mma.make_fragment_B(tUsB) tWrB = update_thr_mma.make_fragment_B(tWsB) + qdo_tUrB = qdo_thr_mma.make_fragment_B(qdo_tUsB) cDV = cute.make_identity_tensor((BV, BT)) tCcDV = thr_mma.partition_C(cDV) @@ -817,6 +837,21 @@ def kernel( acc_dv, ) cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(self.use_g): + if local_tidx < self.BT: + t_idx = chunk_start + local_tidx + g_decay = Float32(0.0) + g_exp = Float32(0.0) + if t_idx < seq_len: + g_cur = g[data_b, seq_start + t_idx, i_h].to(self.acc_dtype) + if 
cutlass.const_expr(self.use_exp2): + g_decay = cute.exp2(g_last - g_cur, fastmath=self.use_fast_math) + g_exp = cute.exp2(g_cur, fastmath=self.use_fast_math) + else: + g_decay = cute.exp(g_last - g_cur, fastmath=self.use_fast_math) + g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) + sG[local_tidx, 0] = g_decay + sG[local_tidx, 1] = g_exp cute.nvgpu.warpgroup.wait_group(0) k_wait.release() @@ -833,12 +868,7 @@ def kernel( if t_idx < seq_len: out = acc_dv[ei] if cutlass.const_expr(self.use_g): - g_cur = g[data_b, seq_start + t_idx, i_h].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - g_decay = cute.exp2(g_last - g_cur, fastmath=self.use_fast_math) - else: - g_decay = cute.exp(g_last - g_cur, fastmath=self.use_fast_math) - out = out * g_decay + out = out * sG[t_rel, 0] out = out + sUA[v_rel, t_rel, dv_stage].to(self.acc_dtype) out_bf16 = out.to(self.io_dtype) sUA[v_rel, t_rel, dv_stage] = out_bf16 @@ -861,7 +891,17 @@ def kernel( # dh += scale * do^T @ q - dv2^T @ w. do_wait = load_do_C.wait_and_advance() - if cutlass.const_expr(self.use_g or self.is_varlen): + if cutlass.const_expr(self.use_g): + for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): + v_rel, t_rel = tCcDV[ei] + t_idx = chunk_start + t_rel + do_scaled = Float32(0.0) + if t_idx < seq_len: + do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) * sG[t_rel, 1] + acc_dv[ei] = do_scaled + rDo_op = self.make_acc_into_op(acc_dv, qdo_tiled_mma.tv_layout_A, self.io_dtype) + do_wait.release() + if cutlass.const_expr((not self.use_g) and self.is_varlen): linear_do = local_tidx while linear_do < self.BV * self.BT: v_rel = linear_do // self.BT @@ -870,13 +910,6 @@ def kernel( do_scaled = Float32(0.0) if t_idx < seq_len: do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) - if cutlass.const_expr(self.use_g): - g_cur = g[data_b, seq_start + t_idx, i_h].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - g_exp = cute.exp2(g_cur, fastmath=self.use_fast_math) - 
else: - g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) - do_scaled = do_scaled * g_exp sDo[v_rel, t_rel, do_wait.index] = do_scaled.to(self.io_dtype) linear_do += self.num_compute_threads cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) @@ -889,14 +922,24 @@ def kernel( acc_qdo.fill(0.0) cute.nvgpu.warpgroup.fence() for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): - update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) - cute.gemm( - update_tiled_mma, - acc_qdo, - tUrDo[None, None, kp, do_wait.index], - tUrB[None, None, kp, q_wait.index], - acc_qdo, - ) + if cutlass.const_expr(self.use_g): + qdo_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + qdo_tiled_mma, + acc_qdo, + rDo_op[None, None, kp], + qdo_tUrB[None, None, kp, q_wait.index], + acc_qdo, + ) + else: + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_qdo, + tUrDo[None, None, kp, do_wait.index], + tUrB[None, None, kp, q_wait.index], + acc_qdo, + ) cute.nvgpu.warpgroup.commit_group() if cutlass.const_expr(k_block == 0): @@ -951,7 +994,8 @@ def kernel( update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] rState[ei] = rState[ei] + update w_wait.release() - do_wait.release() + if cutlass.const_expr(not self.use_g): + do_wait.release() # sUA[dv_stage] has three consumers after dv2 is written: # direct tail stores (above), W^T@dv2 (completed here), and # full-tile TMA stores by the store warp. 
Keep the load_dv stage From 5a10d44d12d2086495fc1e7c459c197fe8e7f9ed Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Sun, 17 May 2026 18:16:33 +0800 Subject: [PATCH 18/28] fix regs --- cula/ops/chunk_delta_h_bwd.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 7b3eef5..e3bfe6d 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -678,20 +678,22 @@ def kernel( thr_mma = tiled_mma.get_slice(local_tidx) update_thr_mma = update_tiled_mma.get_slice(local_tidx) - qdo_thr_mma = qdo_tiled_mma.get_slice(local_tidx) tKsB = thr_mma.partition_B(sA) tKrB = thr_mma.make_fragment_B(tKsB) tUsA = update_thr_mma.partition_A(sUA) - tUsDo = update_thr_mma.partition_A(sDo) tUsB = update_thr_mma.partition_B(sUB) tWsB = update_thr_mma.partition_B(sW) - qdo_tUsB = qdo_thr_mma.partition_B(sUB) tUrA = update_thr_mma.make_fragment_A(tUsA) - tUrDo = update_thr_mma.make_fragment_A(tUsDo) tUrB = update_thr_mma.make_fragment_B(tUsB) tWrB = update_thr_mma.make_fragment_B(tWsB) - qdo_tUrB = qdo_thr_mma.make_fragment_B(qdo_tUsB) + if cutlass.const_expr(self.use_g): + qdo_thr_mma = qdo_tiled_mma.get_slice(local_tidx) + qdo_tUsB = qdo_thr_mma.partition_B(sUB) + qdo_tUrB = qdo_thr_mma.make_fragment_B(qdo_tUsB) + else: + tUsDo = update_thr_mma.partition_A(sDo) + tUrDo = update_thr_mma.make_fragment_A(tUsDo) cDV = cute.make_identity_tensor((BV, BT)) tCcDV = thr_mma.partition_C(cDV) @@ -921,8 +923,8 @@ def kernel( gk_wait = load_gk_C.wait_and_advance() acc_qdo.fill(0.0) cute.nvgpu.warpgroup.fence() - for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): - if cutlass.const_expr(self.use_g): + if cutlass.const_expr(self.use_g): + for kp in cutlass.range(cute.size(qdo_tUrB, mode=[2]), unroll_full=True): qdo_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) cute.gemm( qdo_tiled_mma, @@ -931,7 +933,8 @@ def 
kernel( qdo_tUrB[None, None, kp, q_wait.index], acc_qdo, ) - else: + else: + for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) cute.gemm( update_tiled_mma, From 35866e18f9239b5d4c415fc31e0d6ad1e7944bee Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Mon, 18 May 2026 09:01:28 +0800 Subject: [PATCH 19/28] mv dh to front --- cula/ops/chunk_delta_h_bwd.py | 37 ++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index e3bfe6d..cb435ab 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -820,6 +820,19 @@ def kernel( # dv2 = dv + K @ dh. Compute the equivalent (dh @ K^T) tile so the # register-carried state can feed WGMMA as an RMEM A operand. if is_compute_warp: + # Match chunk_delta_h.py's h_out overlap: publish the carried + # state to the store pipeline before the chunk GEMM chain. + dh_h = store_dh_P.acquire_and_advance() + tRS_rState = tiled_copy_dh_r2s.retile(rStates[0]) + tRS_rDh_out.store(tRS_rState.load().to(self.io_dtype)) + cute.copy( + tiled_copy_dh_r2s, + tRS_rDh_out, + tRS_sDh[(None, None, None, dh_h.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + dh_h.commit() + acc_dv.fill(0.0) for k_block in cutlass.range_constexpr(self.num_k_blocks): k_wait = load_k_C.wait_and_advance() @@ -945,18 +958,6 @@ def kernel( ) cute.nvgpu.warpgroup.commit_group() - if cutlass.const_expr(k_block == 0): - dh_h = store_dh_P.acquire_and_advance() - tRS_rState = tiled_copy_dh_r2s.retile(rState) - tRS_rDh_out.store(tRS_rState.load().to(self.io_dtype)) - cute.copy( - tiled_copy_dh_r2s, - tRS_rDh_out, - tRS_sDh[(None, None, None, dh_h.index)], - ) - cute.arch.fence_proxy("async.shared", space="cta") - dh_h.commit() - # QDO does not consume rState, so hide g/gk state decay under its WGMMA latency. 
if cutlass.const_expr(self.use_g): for ei in cutlass.range(cute.size(rState), unroll_full=True): @@ -1009,6 +1010,12 @@ def kernel( dv_wait.release() if warp_idx == self.store_warp_id: + dh_h = store_dh_C.wait_and_advance() + cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + dh_h.release() + dv2_store_h = store_dv2_C.wait_and_advance() dv2_done_h = store_dv2_done_P.acquire_and_advance() # One done token is committed per chunk. Tail chunks skip TMA @@ -1026,12 +1033,6 @@ def kernel( dv2_done_h.commit() dv2_store_h.release() - dh_h = store_dh_C.wait_and_advance() - cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) - dh_h.release() - if cutlass.const_expr(self.use_dh0): if is_compute_warp: for k_block in cutlass.range_constexpr(self.num_k_blocks): From 3375d277f074a97566e3f0a956c55c71a30d1ae4 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Mon, 18 May 2026 09:55:59 +0800 Subject: [PATCH 20/28] Optimize SM90 bwd_dhu pipeline overlap and dv2 store buffering --- cula/ops/chunk_delta_h_bwd.py | 473 ++++++++++++++++++++-------------- 1 file changed, 286 insertions(+), 187 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index cb435ab..1d815d0 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -54,13 +54,48 @@ BT = 64 BV = 64 BK = 128 -NUM_THREADS = 192 +NUM_THREADS = 224 +_DUMMY_TENSOR_CACHE_MAX = 32 +_dummy_tensor_cache: dict[tuple, torch.Tensor] = {} +_nonvarlen_metadata_cache: dict[tuple, tuple[torch.Tensor, torch.Tensor]] = {} def make_thread_cooperative_group(size: int): return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) +def _device_key(device: torch.device) -> tuple[str, int | None]: + device = torch.device(device) + 
index = device.index + if device.type == "cuda" and index is None: + index = torch.cuda.current_device() + return device.type, index + + +def _cached_empty(shape: tuple[int, ...], *, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + key = (_device_key(device), dtype, tuple(int(x) for x in shape)) + tensor = _dummy_tensor_cache.get(key) + if tensor is None: + if len(_dummy_tensor_cache) >= _DUMMY_TENSOR_CACHE_MAX: + _dummy_tensor_cache.clear() + tensor = torch.empty(shape, device=device, dtype=dtype) + _dummy_tensor_cache[key] = tensor + return tensor + + +def _cached_nonvarlen_metadata(B: int, T: int, NT: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + key = (_device_key(device), int(B), int(T), int(NT)) + metadata = _nonvarlen_metadata_cache.get(key) + if metadata is None: + if len(_nonvarlen_metadata_cache) >= _DUMMY_TENSOR_CACHE_MAX: + _nonvarlen_metadata_cache.clear() + cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.int32) * T + chunk_offsets = torch.arange(B + 1, device=device, dtype=torch.int32) * NT + metadata = (cu_seqlens, chunk_offsets) + _nonvarlen_metadata_cache[key] = metadata + return metadata + + class ChunkDeltaRuleBwdDHUSm90: def __init__( self, @@ -110,16 +145,22 @@ def __init__( self.num_compute_warps = 4 self.num_compute_threads = self.threads_per_warp * self.num_compute_warps self.load_warp_id = 4 - self.store_warp_id = 5 + self.load_current_warp_id = 5 + self.store_warp_id = 6 self.num_threads = NUM_THREADS self.num_regs_compute = 232 self.num_regs_other = 40 - self.input_stage = 3 + self.k_stage = 3 + self.dv_stage = 2 + self.do_stage = 2 + self.q_stage = 3 + self.w_stage = 3 + self.gk_stage = 3 self.dh_store_stage = 2 self.dv2_store_stage = 2 self.io_dtype = cutlass.BFloat16 self.acc_dtype = cutlass.Float32 - self.buffer_align_bytes = 1024 + self.buffer_align_bytes = 128 self.mma_tiler = (BT, BV, self.BK) self.kdh_mma_tiler = (BV, BT, self.BK) @@ -297,26 +338,44 @@ def __call__( 
warpgroup.OperandSource.RMEM, ) - a_smem_layout_staged = sm90_utils.make_smem_layout_b( + k_smem_layout_staged = sm90_utils.make_smem_layout_b( utils.LayoutEnum.ROW_MAJOR, self.kdh_mma_tiler, self.io_dtype, - self.input_stage, + self.k_stage, + ) + dv_smem_layout_staged = sm90_utils.make_smem_layout_a( + utils.LayoutEnum.COL_MAJOR, + self.update_mma_tiler, + self.io_dtype, + self.dv_stage, + ) + do_smem_layout_staged = sm90_utils.make_smem_layout_a( + utils.LayoutEnum.COL_MAJOR, + self.update_mma_tiler, + self.io_dtype, + self.do_stage, ) - update_a_smem_layout_staged = sm90_utils.make_smem_layout_a( + q_smem_layout_staged = sm90_utils.make_smem_layout_b( utils.LayoutEnum.COL_MAJOR, self.update_mma_tiler, self.io_dtype, - self.input_stage, + self.q_stage, ) - update_b_smem_layout_staged = sm90_utils.make_smem_layout_b( + w_smem_layout_staged = sm90_utils.make_smem_layout_b( utils.LayoutEnum.COL_MAJOR, self.update_mma_tiler, self.io_dtype, - self.input_stage, + self.w_stage, + ) + dv2_smem_layout_staged = sm90_utils.make_smem_layout_a( + utils.LayoutEnum.COL_MAJOR, + self.update_mma_tiler, + self.io_dtype, + self.dv2_store_stage, ) gk_smem_layout_staged = cute.make_layout( - (self.BK, 1, self.input_stage), + (self.BK, 1, self.gk_stage), stride=(1, self.BK, self.BK), ) dh_smem_layout_staged = sm90_utils.make_smem_layout_epi( @@ -331,31 +390,31 @@ def __call__( tma_atom_k, tma_tensor_k = cpasync.make_tiled_tma_atom( tma_load_op, k_tk, - cute.slice_(a_smem_layout_staged, (None, None, 0)), + cute.slice_(k_smem_layout_staged, (None, None, 0)), (self.BT, self.BK), ) tma_atom_dv, tma_tensor_dv = cpasync.make_tiled_tma_atom( tma_load_op, dv_vt, - cute.slice_(update_a_smem_layout_staged, (None, None, 0)), + cute.slice_(dv_smem_layout_staged, (None, None, 0)), (self.BV, self.BT), ) tma_atom_do, tma_tensor_do = cpasync.make_tiled_tma_atom( tma_load_op, do_vt, - cute.slice_(update_a_smem_layout_staged, (None, None, 0)), + cute.slice_(do_smem_layout_staged, (None, None, 0)), 
(self.BV, self.BT), ) tma_atom_q, tma_tensor_q = cpasync.make_tiled_tma_atom( tma_load_op, q_kt, - cute.slice_(update_b_smem_layout_staged, (None, None, 0)), + cute.slice_(q_smem_layout_staged, (None, None, 0)), (self.BK, self.BT), ) tma_atom_w, tma_tensor_w = cpasync.make_tiled_tma_atom( tma_load_op, w_kt, - cute.slice_(update_b_smem_layout_staged, (None, None, 0)), + cute.slice_(w_smem_layout_staged, (None, None, 0)), (self.BK, self.BT), ) tma_atom_gk, tma_tensor_gk = cpasync.make_tiled_tma_atom( @@ -373,42 +432,40 @@ def __call__( tma_atom_dv2, tma_tensor_dv2 = cpasync.make_tiled_tma_atom( tma_store_op, dv2_vt, - cute.slice_(update_a_smem_layout_staged, (None, None, 0)), + cute.slice_(dv2_smem_layout_staged, (None, None, 0)), (self.BV, self.BT), ) - self.tma_k_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(a_smem_layout_staged, (None, None, 0))) - self.tma_dv_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_a_smem_layout_staged, (None, None, 0))) - self.tma_do_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_a_smem_layout_staged, (None, None, 0))) - self.tma_q_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_b_smem_layout_staged, (None, None, 0))) - self.tma_w_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(update_b_smem_layout_staged, (None, None, 0))) + self.tma_k_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(k_smem_layout_staged, (None, None, 0))) + self.tma_dv_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(dv_smem_layout_staged, (None, None, 0))) + self.tma_do_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(do_smem_layout_staged, (None, None, 0))) + self.tma_q_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(q_smem_layout_staged, (None, None, 0))) + self.tma_w_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(w_smem_layout_staged, (None, None, 0))) self.tma_gk_bytes = cute.size_in_bytes(cutlass.Float32, cute.slice_(gk_smem_layout_staged, (None, None, 0))) @cute.struct class 
SharedStorage: - load_k_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] - load_dv_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] - load_do_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] - load_q_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] - load_w_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] - load_gk_mbar: cute.struct.MemRange[Int64, self.input_stage * 2] + load_k_mbar: cute.struct.MemRange[Int64, self.k_stage * 2] + load_dv_mbar: cute.struct.MemRange[Int64, self.dv_stage * 2] + load_do_mbar: cute.struct.MemRange[Int64, self.do_stage * 2] + load_q_mbar: cute.struct.MemRange[Int64, self.q_stage * 2] + load_w_mbar: cute.struct.MemRange[Int64, self.w_stage * 2] + load_gk_mbar: cute.struct.MemRange[Int64, self.gk_stage * 2] store_dh_mbar: cute.struct.MemRange[Int64, self.dh_store_stage * 2] store_dv2_mbar: cute.struct.MemRange[Int64, self.dv2_store_stage * 2] - store_dv2_done_mbar: cute.struct.MemRange[Int64, self.dv2_store_stage * 2] - sDv2Stage: cute.struct.MemRange[Int32, self.dv2_store_stage] sA: cute.struct.Align[ - cute.struct.MemRange[self.io_dtype, cute.cosize(a_smem_layout_staged)], + cute.struct.MemRange[self.io_dtype, cute.cosize(k_smem_layout_staged)], self.buffer_align_bytes, ] sUA: cute.struct.Align[ - cute.struct.MemRange[self.io_dtype, cute.cosize(update_a_smem_layout_staged)], + cute.struct.MemRange[self.io_dtype, cute.cosize(dv_smem_layout_staged)], self.buffer_align_bytes, ] sDo: cute.struct.Align[ - cute.struct.MemRange[self.io_dtype, cute.cosize(update_a_smem_layout_staged)], + cute.struct.MemRange[self.io_dtype, cute.cosize(do_smem_layout_staged)], self.buffer_align_bytes, ] sGK: cute.struct.Align[ - cute.struct.MemRange[cutlass.Float32, BK * self.input_stage], + cute.struct.MemRange[cutlass.Float32, BK * self.gk_stage], 128, ] sG: cute.struct.Align[ @@ -416,11 +473,15 @@ class SharedStorage: 128, ] sUB: cute.struct.Align[ - cute.struct.MemRange[self.io_dtype, 
cute.cosize(update_b_smem_layout_staged)], + cute.struct.MemRange[self.io_dtype, cute.cosize(q_smem_layout_staged)], self.buffer_align_bytes, ] sW: cute.struct.Align[ - cute.struct.MemRange[self.io_dtype, cute.cosize(update_b_smem_layout_staged)], + cute.struct.MemRange[self.io_dtype, cute.cosize(w_smem_layout_staged)], + self.buffer_align_bytes, + ] + sDv2: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(dv2_smem_layout_staged)], self.buffer_align_bytes, ] sDh: cute.struct.Align[ @@ -447,9 +508,12 @@ class SharedStorage: tiled_mma, update_tiled_mma, qdo_tiled_mma, - a_smem_layout_staged, - update_a_smem_layout_staged, - update_b_smem_layout_staged, + k_smem_layout_staged, + dv_smem_layout_staged, + do_smem_layout_staged, + dv2_smem_layout_staged, + q_smem_layout_staged, + w_smem_layout_staged, tma_atom_k, tma_tensor_k, tma_atom_dv, @@ -494,9 +558,12 @@ def kernel( tiled_mma: cute.TiledMma, update_tiled_mma: cute.TiledMma, qdo_tiled_mma: cute.TiledMma, - a_smem_layout_staged: cute.ComposedLayout, - update_a_smem_layout_staged: cute.ComposedLayout, - update_b_smem_layout_staged: cute.ComposedLayout, + k_smem_layout_staged: cute.ComposedLayout, + dv_smem_layout_staged: cute.ComposedLayout, + do_smem_layout_staged: cute.ComposedLayout, + dv2_smem_layout_staged: cute.ComposedLayout, + q_smem_layout_staged: cute.ComposedLayout, + w_smem_layout_staged: cute.ComposedLayout, tma_atom_k: cute.CopyAtom, tma_tensor_k: cute.Tensor, tma_atom_dv: cute.CopyAtom, @@ -538,58 +605,59 @@ def kernel( smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) - sUA = storage.sUA.get_tensor(update_a_smem_layout_staged.outer, swizzle=update_a_smem_layout_staged.inner) - sDv2Stage = storage.sDv2Stage.get_tensor(cute.make_layout((self.dv2_store_stage,))) - sDo = storage.sDo.get_tensor(update_a_smem_layout_staged.outer, swizzle=update_a_smem_layout_staged.inner) - 
sGK = storage.sGK.get_tensor(cute.make_layout((BK, 1, self.input_stage), stride=(1, BK, BK))) + sA = storage.sA.get_tensor(k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner) + sUA = storage.sUA.get_tensor(dv_smem_layout_staged.outer, swizzle=dv_smem_layout_staged.inner) + sDo = storage.sDo.get_tensor(do_smem_layout_staged.outer, swizzle=do_smem_layout_staged.inner) + sGK = storage.sGK.get_tensor(cute.make_layout((BK, 1, self.gk_stage), stride=(1, BK, BK))) sG = storage.sG.get_tensor(cute.make_layout((BT, 2), stride=(1, BT))) - sUB = storage.sUB.get_tensor(update_b_smem_layout_staged.outer, swizzle=update_b_smem_layout_staged.inner) - sW = storage.sW.get_tensor(update_b_smem_layout_staged.outer, swizzle=update_b_smem_layout_staged.inner) + sUB = storage.sUB.get_tensor(q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner) + sW = storage.sW.get_tensor(w_smem_layout_staged.outer, swizzle=w_smem_layout_staged.inner) + sDv2 = storage.sDv2.get_tensor(dv2_smem_layout_staged.outer, swizzle=dv2_smem_layout_staged.inner) sDh = storage.sDh.get_tensor(dh_smem_layout_staged.outer, swizzle=dh_smem_layout_staged.inner) if warp_idx == self.load_warp_id: cpasync.prefetch_descriptor(tma_atom_k) cpasync.prefetch_descriptor(tma_atom_dv) + if cutlass.const_expr(self.use_gk): + cpasync.prefetch_descriptor(tma_atom_gk) + if warp_idx == self.load_current_warp_id: cpasync.prefetch_descriptor(tma_atom_do) cpasync.prefetch_descriptor(tma_atom_q) cpasync.prefetch_descriptor(tma_atom_w) - if cutlass.const_expr(self.use_gk): - cpasync.prefetch_descriptor(tma_atom_gk) if warp_idx == self.store_warp_id: cpasync.prefetch_descriptor(tma_atom_dh) cpasync.prefetch_descriptor(tma_atom_dv2) load_k_P, load_k_C = pipeline.PipelineTmaAsync.create( - num_stages=self.input_stage, + num_stages=self.k_stage, producer_group=make_thread_cooperative_group(1), consumer_group=make_thread_cooperative_group(self.num_compute_warps), tx_count=self.tma_k_bytes, 
barrier_storage=storage.load_k_mbar.data_ptr(), ).make_participants() load_dv_P, load_dv_C = pipeline.PipelineTmaAsync.create( - num_stages=self.input_stage, + num_stages=self.dv_stage, producer_group=make_thread_cooperative_group(1), consumer_group=make_thread_cooperative_group(self.num_compute_warps), tx_count=self.tma_dv_bytes, barrier_storage=storage.load_dv_mbar.data_ptr(), ).make_participants() load_do_P, load_do_C = pipeline.PipelineTmaAsync.create( - num_stages=self.input_stage, + num_stages=self.do_stage, producer_group=make_thread_cooperative_group(1), consumer_group=make_thread_cooperative_group(self.num_compute_warps), tx_count=self.tma_do_bytes, barrier_storage=storage.load_do_mbar.data_ptr(), ).make_participants() load_q_P, load_q_C = pipeline.PipelineTmaAsync.create( - num_stages=self.input_stage, + num_stages=self.q_stage, producer_group=make_thread_cooperative_group(1), consumer_group=make_thread_cooperative_group(self.num_compute_warps), tx_count=self.tma_q_bytes, barrier_storage=storage.load_q_mbar.data_ptr(), ).make_participants() load_w_P, load_w_C = pipeline.PipelineTmaAsync.create( - num_stages=self.input_stage, + num_stages=self.w_stage, producer_group=make_thread_cooperative_group(1), consumer_group=make_thread_cooperative_group(self.num_compute_warps), tx_count=self.tma_w_bytes, @@ -597,7 +665,7 @@ def kernel( ).make_participants() if cutlass.const_expr(self.use_gk): load_gk_P, load_gk_C = pipeline.PipelineTmaAsync.create( - num_stages=self.input_stage, + num_stages=self.gk_stage, producer_group=make_thread_cooperative_group(1), consumer_group=make_thread_cooperative_group(self.num_compute_warps), tx_count=self.tma_gk_bytes, @@ -615,13 +683,6 @@ def kernel( consumer_group=make_thread_cooperative_group(self.threads_per_warp), barrier_storage=storage.store_dv2_mbar.data_ptr(), ).make_participants() - store_dv2_done_P, store_dv2_done_C = pipeline.PipelineAsync.create( - num_stages=self.dv2_store_stage, - 
producer_group=make_thread_cooperative_group(self.threads_per_warp), - consumer_group=make_thread_cooperative_group(self.num_compute_threads), - barrier_storage=storage.store_dv2_done_mbar.data_ptr(), - ).make_participants() - if cutlass.const_expr(self.is_varlen): tma_tensor_k_use = cute.domain_offset((seq_start, 0, (0, 0)), tma_tensor_k) tma_tensor_dv_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_dv) @@ -666,7 +727,7 @@ def kernel( tma_atom_dh, tma_tensor_dh_use[None, None, (None, i_h, state_b)], (self.BV, self.BK), sDh ) _, bSG_sDv2, bSG_gDv2 = self._epilog_partition( - tma_atom_dv2, tma_tensor_dv2_use[None, None, (i_h, data_b)], (self.BV, self.BT), sUA + tma_atom_dv2, tma_tensor_dv2_use[None, None, (i_h, data_b)], (self.BV, self.BT), sDv2 ) is_compute_warp = warp_idx < self.num_compute_warps @@ -685,6 +746,8 @@ def kernel( tUsB = update_thr_mma.partition_B(sUB) tWsB = update_thr_mma.partition_B(sW) tUrA = update_thr_mma.make_fragment_A(tUsA) + tDv2sA = update_thr_mma.partition_A(sDv2) + tDv2rA = update_thr_mma.make_fragment_A(tDv2sA) tUrB = update_thr_mma.make_fragment_B(tUsB) tWrB = update_thr_mma.make_fragment_B(tWsB) if cutlass.const_expr(self.use_g): @@ -736,7 +799,6 @@ def kernel( tRS_sDh = thr_copy_dh_r2s.partition_D(sDh) rDh_shape = cute.shape(thr_copy_dh_r2s.partition_S(sDh)) tRS_rDh_layout = cute.make_layout(rDh_shape[:3]) - tRS_rDh_out = cute.make_rmem_tensor_like(tRS_rDh_layout, self.io_dtype) # Initialize carried dh state in register blocks. 
if is_compute_warp: @@ -810,7 +872,7 @@ def kernel( bSG_sGK[None, gk_h.index], tma_bar_ptr=gk_h.barrier, ) - if warp_idx == self.load_warp_id: + if warp_idx == self.load_current_warp_id: do_h = load_do_P.acquire_and_advance() cute.copy(tma_atom_do, bSG_gDo[(None, i_v_tile, i_t)], bSG_sDo[None, do_h.index], tma_bar_ptr=do_h.barrier) q_h = load_q_P.acquire_and_advance() @@ -822,9 +884,12 @@ def kernel( if is_compute_warp: # Match chunk_delta_h.py's h_out overlap: publish the carried # state to the store pipeline before the chunk GEMM chain. + rState0_bf16 = cute.make_rmem_tensor(rStates[0].shape, self.io_dtype) + rState0_bf16.store(rStates[0].load().to(self.io_dtype)) dh_h = store_dh_P.acquire_and_advance() - tRS_rState = tiled_copy_dh_r2s.retile(rStates[0]) - tRS_rDh_out.store(tRS_rState.load().to(self.io_dtype)) + tRS_rState = tiled_copy_dh_r2s.retile(rState0_bf16) + tRS_rDh_out = cute.make_rmem_tensor_like(tRS_rDh_layout, self.io_dtype) + tRS_rDh_out.store(tRS_rState.load()) cute.copy( tiled_copy_dh_r2s, tRS_rDh_out, @@ -837,7 +902,10 @@ def kernel( for k_block in cutlass.range_constexpr(self.num_k_blocks): k_wait = load_k_C.wait_and_advance() rState = rStates[k_block] - rState_op = self.make_acc_into_op(rState, tiled_mma.tv_layout_A, self.io_dtype) + if cutlass.const_expr(k_block == 0): + rState_op = self.make_acc_into_op(rState0_bf16, tiled_mma.tv_layout_A, self.io_dtype) + else: + rState_op = self.make_acc_into_op(rState, tiled_mma.tv_layout_A, self.io_dtype) cute.nvgpu.warpgroup.fence() for kp in cutlass.range(cute.size(tKrB, mode=[2]), unroll_full=True): tiled_mma.set( @@ -867,15 +935,45 @@ def kernel( g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) sG[local_tidx, 0] = g_decay sG[local_tidx, 1] = g_exp - cute.nvgpu.warpgroup.wait_group(0) + if cutlass.const_expr((not self.use_g) and (not self.is_varlen) and (self.num_k_blocks == 1)): + do_wait_early = load_do_C.wait_and_advance() + q_wait_early = load_q_C.wait_and_advance() + if 
cutlass.const_expr(self.use_gk): + gk_wait_early = load_gk_C.wait_and_advance() + acc_qdo.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_qdo, + tUrDo[None, None, kp, do_wait_early.index], + tUrB[None, None, kp, q_wait_early.index], + acc_qdo, + ) + cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(self.use_gk): + gk_last = sGK[local_tidx, 0, gk_wait_early.index].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) + else: + k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) + sGK[local_tidx, 0, gk_wait_early.index] = k_decay + self.gk_precompute_bar.arrive_and_wait() + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait_early.index] + cute.nvgpu.warpgroup.wait_group(1) + else: + cute.nvgpu.warpgroup.wait_group(0) k_wait.release() dv_wait = load_dv_C.wait_and_advance() dv_stage = dv_wait.index dv2_store_h = store_dv2_P.acquire_and_advance() - if local_tidx == 0: - sDv2Stage[dv2_store_h.index] = dv_stage - cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) + dv2_stage = dv2_store_h.index + if cutlass.const_expr(self.use_g): + cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): v_rel, t_rel = tCcDV[ei] t_idx = chunk_start + t_rel @@ -886,99 +984,16 @@ def kernel( out = out * sG[t_rel, 0] out = out + sUA[v_rel, t_rel, dv_stage].to(self.acc_dtype) out_bf16 = out.to(self.io_dtype) - sUA[v_rel, t_rel, dv_stage] = out_bf16 + sDv2[v_rel, t_rel, dv2_stage] = out_bf16 + if remaining < self.BT and t_idx < seq_len: + dv2[data_b, seq_start + chunk_start + t_rel, i_h, v_base + v_rel] = out_bf16 
cute.arch.fence_proxy("async.shared", space="cta") dv2_store_h.commit() - cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) - - if remaining < self.BT: - linear_store = local_tidx - while linear_store < self.BV * self.BT: - v_rel = linear_store // self.BT - t_rel = linear_store - v_rel * self.BT - if t_rel < remaining: - dv2[data_b, seq_start + chunk_start + t_rel, i_h, v_base + v_rel] = sUA[ - v_rel, - t_rel, - dv_stage, - ] - linear_store += self.num_compute_threads + dv_wait.release() # dh += scale * do^T @ q - dv2^T @ w. - do_wait = load_do_C.wait_and_advance() - if cutlass.const_expr(self.use_g): - for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): - v_rel, t_rel = tCcDV[ei] - t_idx = chunk_start + t_rel - do_scaled = Float32(0.0) - if t_idx < seq_len: - do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) * sG[t_rel, 1] - acc_dv[ei] = do_scaled - rDo_op = self.make_acc_into_op(acc_dv, qdo_tiled_mma.tv_layout_A, self.io_dtype) - do_wait.release() - if cutlass.const_expr((not self.use_g) and self.is_varlen): - linear_do = local_tidx - while linear_do < self.BV * self.BT: - v_rel = linear_do // self.BT - t_rel = linear_do - v_rel * self.BT - t_idx = chunk_start + t_rel - do_scaled = Float32(0.0) - if t_idx < seq_len: - do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) - sDo[v_rel, t_rel, do_wait.index] = do_scaled.to(self.io_dtype) - linear_do += self.num_compute_threads - cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) - - for k_block in cutlass.range_constexpr(self.num_k_blocks): - rState = rStates[k_block] - q_wait = load_q_C.wait_and_advance() - if cutlass.const_expr(self.use_gk): - gk_wait = load_gk_C.wait_and_advance() - acc_qdo.fill(0.0) - cute.nvgpu.warpgroup.fence() - if cutlass.const_expr(self.use_g): - for kp in cutlass.range(cute.size(qdo_tUrB, mode=[2]), unroll_full=True): - qdo_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) - 
cute.gemm( - qdo_tiled_mma, - acc_qdo, - rDo_op[None, None, kp], - qdo_tUrB[None, None, kp, q_wait.index], - acc_qdo, - ) - else: - for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): - update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) - cute.gemm( - update_tiled_mma, - acc_qdo, - tUrDo[None, None, kp, do_wait.index], - tUrB[None, None, kp, q_wait.index], - acc_qdo, - ) - cute.nvgpu.warpgroup.commit_group() - - # QDO does not consume rState, so hide g/gk state decay under its WGMMA latency. - if cutlass.const_expr(self.use_g): - for ei in cutlass.range(cute.size(rState), unroll_full=True): - rState[ei] = rState[ei] * g_last_exp - if cutlass.const_expr(self.use_gk): - gk_last = sGK[local_tidx, 0, gk_wait.index].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) - else: - k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) - sGK[local_tidx, 0, gk_wait.index] = k_decay - self.gk_precompute_bar.arrive_and_wait() - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait.index] - - cute.nvgpu.warpgroup.wait_group(0) - q_wait.release() - if cutlass.const_expr(self.use_gk): - gk_wait.release() - + if cutlass.const_expr((not self.use_g) and (not self.is_varlen) and (self.num_k_blocks == 1)): + rState = rStates[0] w_wait = load_w_C.wait_and_advance() acc_wdv.fill(0.0) cute.nvgpu.warpgroup.fence() @@ -987,27 +1002,115 @@ def kernel( cute.gemm( update_tiled_mma, acc_wdv, - tUrA[None, None, kp, dv_stage], + tDv2rA[None, None, kp, dv2_stage], tWrB[None, None, kp, w_wait.index], acc_wdv, ) cute.nvgpu.warpgroup.commit_group() cute.nvgpu.warpgroup.wait_group(0) + q_wait_early.release() + if cutlass.const_expr(self.use_gk): + gk_wait_early.release() for ei in cutlass.range(cute.size(rState), unroll_full=True): update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] 
rState[ei] = rState[ei] + update w_wait.release() - if cutlass.const_expr(not self.use_g): - do_wait.release() - # sUA[dv_stage] has three consumers after dv2 is written: - # direct tail stores (above), W^T@dv2 (completed here), and - # full-tile TMA stores by the store warp. Keep the load_dv stage - # owned until the store warp signals done, otherwise the load - # warp could refill this stage while a TMA store still reads it. - dv2_done_h = store_dv2_done_C.wait_and_advance() - dv2_done_h.release() - dv_wait.release() + do_wait_early.release() + else: + do_wait = load_do_C.wait_and_advance() + if cutlass.const_expr(self.use_g): + for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): + v_rel, t_rel = tCcDV[ei] + t_idx = chunk_start + t_rel + do_scaled = Float32(0.0) + if t_idx < seq_len: + do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) * sG[t_rel, 1] + acc_dv[ei] = do_scaled + rDo_op = self.make_acc_into_op(acc_dv, qdo_tiled_mma.tv_layout_A, self.io_dtype) + do_wait.release() + if cutlass.const_expr((not self.use_g) and self.is_varlen): + linear_do = local_tidx + while linear_do < self.BV * self.BT: + v_rel = linear_do // self.BT + t_rel = linear_do - v_rel * self.BT + t_idx = chunk_start + t_rel + do_scaled = Float32(0.0) + if t_idx < seq_len: + do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) + sDo[v_rel, t_rel, do_wait.index] = do_scaled.to(self.io_dtype) + linear_do += self.num_compute_threads + cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) + + for k_block in cutlass.range_constexpr(self.num_k_blocks): + rState = rStates[k_block] + q_wait = load_q_C.wait_and_advance() + if cutlass.const_expr(self.use_gk): + gk_wait = load_gk_C.wait_and_advance() + acc_qdo.fill(0.0) + cute.nvgpu.warpgroup.fence() + if cutlass.const_expr(self.use_g): + for kp in cutlass.range(cute.size(qdo_tUrB, mode=[2]), unroll_full=True): + qdo_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + 
cute.gemm( + qdo_tiled_mma, + acc_qdo, + rDo_op[None, None, kp], + qdo_tUrB[None, None, kp, q_wait.index], + acc_qdo, + ) + else: + for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_qdo, + tUrDo[None, None, kp, do_wait.index], + tUrB[None, None, kp, q_wait.index], + acc_qdo, + ) + cute.nvgpu.warpgroup.commit_group() + + # QDO does not consume rState, so hide g/gk state decay under its WGMMA latency. + if cutlass.const_expr(self.use_g): + for ei in cutlass.range(cute.size(rState), unroll_full=True): + rState[ei] = rState[ei] * g_last_exp + if cutlass.const_expr(self.use_gk): + gk_last = sGK[local_tidx, 0, gk_wait.index].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) + else: + k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) + sGK[local_tidx, 0, gk_wait.index] = k_decay + self.gk_precompute_bar.arrive_and_wait() + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait.index] + + w_wait = load_w_C.wait_and_advance() + acc_wdv.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_wdv, + tDv2rA[None, None, kp, dv2_stage], + tWrB[None, None, kp, w_wait.index], + acc_wdv, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + q_wait.release() + if cutlass.const_expr(self.use_gk): + gk_wait.release() + + for ei in cutlass.range(cute.size(rState), unroll_full=True): + update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] + rState[ei] = rState[ei] + update + w_wait.release() + if cutlass.const_expr(not self.use_g): + do_wait.release() if warp_idx == 
self.store_warp_id: dh_h = store_dh_C.wait_and_advance() @@ -1017,20 +1120,17 @@ def kernel( dh_h.release() dv2_store_h = store_dv2_C.wait_and_advance() - dv2_done_h = store_dv2_done_P.acquire_and_advance() - # One done token is committed per chunk. Tail chunks skip TMA - # because the tile would cross sequence bounds, but still - # publish done so compute and store pipeline phases stay paired. + # Tail chunks skip TMA because the tile would cross sequence + # bounds. The store pipeline itself keeps sDv2 stages from + # being overwritten before this warp releases them. if remaining >= self.BT: - dv_stage = sDv2Stage[dv2_store_h.index] cute.copy( tma_atom_dv2, - bSG_sDv2[None, dv_stage], + bSG_sDv2[None, dv2_store_h.index], bSG_gDv2[(None, i_v_tile, i_t)], ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) - dv2_done_h.commit() dv2_store_h.release() if cutlass.const_expr(self.use_dh0): @@ -1224,8 +1324,7 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( else: N = B NT = math.ceil(T / BT) - cu_seqlens_arg = torch.arange(B + 1, device=q.device, dtype=torch.int32) * T - chunk_offsets = torch.arange(B + 1, device=q.device, dtype=torch.int32) * NT + cu_seqlens_arg, chunk_offsets = _cached_nonvarlen_metadata(B, T, NT, q.device) scale_value = 1.0 if scale is None else float(scale) state_shape = (N, H, V, K) if transpose_state_layout else (N, H, K, V) @@ -1233,10 +1332,10 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None dv2 = torch.empty_like(dv) - g_arg = g if g is not None else torch.empty(B, T, H, device=q.device, dtype=torch.float32) - gk_arg = gk if gk is not None else torch.empty(B, T, H, K, device=q.device, dtype=torch.float32) - dht_arg = dht if dht is not None else torch.empty(state_shape, device=q.device, dtype=torch.float32) - dh0_arg = dh0 if dh0 is not None else torch.empty(state_shape, device=q.device, dtype=torch.float32) + g_arg = g if g is not None else 
_cached_empty((B, T, H), device=q.device, dtype=torch.float32) + gk_arg = gk if gk is not None else _cached_empty((B, T, H, K), device=q.device, dtype=torch.float32) + dht_arg = dht if dht is not None else _cached_empty(state_shape, device=q.device, dtype=torch.float32) + dh0_arg = dh0 if dh0 is not None else _cached_empty(state_shape, device=q.device, dtype=torch.float32) if g is not None and (g.dtype != torch.float32 or not g.is_contiguous()): raise ValueError("g must be contiguous float32.") if g is not None and tuple(g.shape) != (B, T, H): From 3dac02600d6e349c9d835383e15e8613dbf633aa Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Tue, 19 May 2026 11:10:04 +0800 Subject: [PATCH 21/28] format --- cula/ops/chunk_delta_h_bwd.py | 692 ++++++++++++++++------------------ 1 file changed, 329 insertions(+), 363 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 1d815d0..fdf9142 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -13,28 +13,25 @@ # limitations under the License. """ -SM90 CuTe DSL implementation for chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64. +Chunk Gated Delta Rule Backward DHU Kernel (SM90 WGMMA) -This Hopper tensor-core path is scoped to match cula/ops/chunk_delta_h.py: - - fixed chunk size BT=64 - - K=V=128, BV=64 - - non-varlen tensors [B, T, H, D] and packed varlen tensors - - state layout [B, NT, H, K, V] or [B, NT, H, V, K] - - non-persistent scheduling +Hopper tensor-core path aligned with cula/ops/chunk_delta_h.py: +- fixed chunk size BT=64 +- K=V=128, BV=64 +- non-varlen tensors [B, T, H, D] and packed varlen tensors +- state layout [B, NT, H, K, V] or [B, NT, H, V, K] -The recurrence is the Triton bwd_dhu recurrence: +The recurrence follows FLA's bwd_dhu: dv2 = dv + K @ dh dh = decay(dh) + scale * Q^T @ do - W^T @ dv2 -Each CTA owns one BV tile and one (batch, head). 
WGMMA computes the three -64x64 GEMMs per chunk; scalar CUDA code only stages operands and applies the -elementwise recurrence. +Each CTA owns one BV tile and one (batch, head). WGMMA computes the three +64x64 GEMMs per chunk while CUDA threads carry dh in registers. """ from __future__ import annotations import functools -import math import cuda.bindings.driver as cuda import cutlass @@ -138,9 +135,7 @@ def __init__( self.BT = BT self.BV = BV - self.BK = head_dim_k - self.num_k_blocks = head_dim_k // self.BK - self.num_v_tiles = (head_dim_v + BV - 1) // BV + self.BK = BK self.threads_per_warp = 32 self.num_compute_warps = 4 self.num_compute_threads = self.threads_per_warp * self.num_compute_warps @@ -162,9 +157,9 @@ def __init__( self.acc_dtype = cutlass.Float32 self.buffer_align_bytes = 128 - self.mma_tiler = (BT, BV, self.BK) - self.kdh_mma_tiler = (BV, BT, self.BK) - self.update_mma_tiler = (BV, self.BK, BT) + # K=BK=128, so the carried dh state is a single BV x BK register tile. + self.kdh_mma_tiler = (self.BV, self.BT, self.BK) + self.update_mma_tiler = (self.BV, self.BK, self.BT) self.atom_layout_mnk = (1, 1, 1) self.cluster_shape_mnk = (1, 1, 1) self.gk_precompute_bar = pipeline.NamedBarrier( @@ -206,59 +201,17 @@ def __call__( NT_total = self.NT - q_layout = cute.make_layout( - (self.B, self.T, self.H, self.K), - stride=(self.T * self.H * self.K, self.H * self.K, self.K, 1), - ) - q = cute.make_tensor(q_ptr, q_layout) - k = cute.make_tensor(k_ptr, q_layout) - w = cute.make_tensor(w_ptr, q_layout) - - v_layout = cute.make_layout( - (self.B, self.T, self.H, self.V), - stride=(self.T * self.H * self.V, self.H * self.V, self.V, 1), - ) - do = cute.make_tensor(do_ptr, v_layout) - dv = cute.make_tensor(dv_ptr, v_layout) - dv2 = cute.make_tensor(dv2_ptr, v_layout) - + # ===================== GMEM layouts ===================== g_layout = cute.make_layout( (self.B, self.T, self.H), stride=(self.T * self.H, self.H, 1), ) g = cute.make_tensor(g_ptr, g_layout) - 
gk_layout = cute.make_layout( - (self.B, self.T, self.H, self.K), - stride=(self.T * self.H * self.K, self.H * self.K, self.K, 1), - ) - gk = cute.make_tensor(gk_ptr, gk_layout) cu_seqlens = cute.make_tensor(cu_seqlens_ptr, cute.make_layout((self.N + 1,))) chunk_offsets = cute.make_tensor(chunk_offsets_ptr, cute.make_layout((self.N + 1,))) - if cutlass.const_expr(self.transpose_state_layout): - state_layout = cute.make_layout( - (self.B, NT_total, self.H, self.V, self.K), - stride=( - NT_total * self.H * self.K * self.V, - self.H * self.K * self.V, - self.K * self.V, - self.K, - 1, - ), - ) - else: - state_layout = cute.make_layout( - (self.B, NT_total, self.H, self.K, self.V), - stride=( - NT_total * self.H * self.K * self.V, - self.H * self.K * self.V, - self.K * self.V, - self.V, - 1, - ), - ) - dh = cute.make_tensor(dh_ptr, state_layout) + # dh TMA store view: (V, K) tile with layout selected by the requested state layout. if cutlass.const_expr(self.transpose_state_layout): dh_tma_layout = cute.make_layout( (self.V, self.K, (NT_total, self.H, self.B)), @@ -288,6 +241,7 @@ def __call__( dht = cute.make_tensor(dht_ptr, final_layout) dh0 = cute.make_tensor(dh0_ptr, final_layout) + # TMA operand views. Varlen shifts the T dimension with domain_offset below. 
tk_layout = cute.make_layout( (self.T, self.K, (self.H, self.B)), stride=(self.H * self.K, 1, (self.K, self.T * self.H * self.K)) ) @@ -306,7 +260,13 @@ def __call__( do_vt = cute.make_tensor(do_ptr, vt_layout) dv_vt = cute.make_tensor(dv_ptr, vt_layout) dv2_vt = cute.make_tensor(dv2_ptr, vt_layout) + dv2_layout = cute.make_layout( + (self.B, self.T, self.H, self.V), + stride=(self.T * self.H * self.V, self.H * self.V, self.V, 1), + ) + dv2 = cute.make_tensor(dv2_ptr, dv2_layout) + # ===================== MMA setup ===================== tiled_mma = sm90_utils.make_trivial_tiled_mma( self.io_dtype, self.io_dtype, @@ -338,6 +298,7 @@ def __call__( warpgroup.OperandSource.RMEM, ) + # ===================== SMEM layouts ===================== k_smem_layout_staged = sm90_utils.make_smem_layout_b( utils.LayoutEnum.ROW_MAJOR, self.kdh_mma_tiler, @@ -387,6 +348,7 @@ def __call__( tma_load_op = cpasync.CopyBulkTensorTileG2SOp() tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + # ===================== TMA descriptors ===================== tma_atom_k, tma_tensor_k = cpasync.make_tiled_tma_atom( tma_load_op, k_tk, @@ -442,6 +404,7 @@ def __call__( self.tma_w_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(w_smem_layout_staged, (None, None, 0))) self.tma_gk_bytes = cute.size_in_bytes(cutlass.Float32, cute.slice_(gk_smem_layout_staged, (None, None, 0))) + # ===================== SharedStorage ===================== @cute.struct class SharedStorage: load_k_mbar: cute.struct.MemRange[Int64, self.k_stage * 2] @@ -452,11 +415,11 @@ class SharedStorage: load_gk_mbar: cute.struct.MemRange[Int64, self.gk_stage * 2] store_dh_mbar: cute.struct.MemRange[Int64, self.dh_store_stage * 2] store_dv2_mbar: cute.struct.MemRange[Int64, self.dv2_store_stage * 2] - sA: cute.struct.Align[ + sK: cute.struct.Align[ cute.struct.MemRange[self.io_dtype, cute.cosize(k_smem_layout_staged)], self.buffer_align_bytes, ] - sUA: cute.struct.Align[ + sDv: cute.struct.Align[ 
cute.struct.MemRange[self.io_dtype, cute.cosize(dv_smem_layout_staged)], self.buffer_align_bytes, ] @@ -465,14 +428,14 @@ class SharedStorage: self.buffer_align_bytes, ] sGK: cute.struct.Align[ - cute.struct.MemRange[cutlass.Float32, BK * self.gk_stage], + cute.struct.MemRange[cutlass.Float32, self.BK * self.gk_stage], 128, ] sG: cute.struct.Align[ - cute.struct.MemRange[cutlass.Float32, BT * 2], + cute.struct.MemRange[cutlass.Float32, self.BT * 2], 128, ] - sUB: cute.struct.Align[ + sQ: cute.struct.Align[ cute.struct.MemRange[self.io_dtype, cute.cosize(q_smem_layout_staged)], self.buffer_align_bytes, ] @@ -492,16 +455,9 @@ class SharedStorage: self.shared_storage = SharedStorage self.kernel( - q, - k, - w, g, - gk, dht, dh0, - do, - dh, - dv, dv2, cu_seqlens, chunk_offsets, @@ -542,16 +498,9 @@ class SharedStorage: @cute.kernel def kernel( self, - q: cute.Tensor, - k: cute.Tensor, - w: cute.Tensor, g: cute.Tensor, - gk: cute.Tensor, dht: cute.Tensor, dh0: cute.Tensor, - do: cute.Tensor, - dh: cute.Tensor, - dv: cute.Tensor, dv2: cute.Tensor, cu_seqlens: cute.Tensor, chunk_offsets: cute.Tensor, @@ -582,52 +531,43 @@ def kernel( tma_atom_dv2: cute.CopyAtom, tma_tensor_dv2: cute.Tensor, ): - tidx, _, _ = cute.arch.thread_idx() warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - i_v_tile, i_bh, _ = cute.arch.block_idx() - i_n = i_bh // self.H - i_h = i_bh - i_n * self.H - data_b = i_n - state_b = i_n - seq_start = Int32(0) + tidx, _, _ = cute.arch.thread_idx() + + # ===================== Block indices ===================== + v_tile_idx, bh_idx, _ = cute.arch.block_idx() + bidx = bh_idx // self.H + hidx = bh_idx - bidx * self.H + data_bidx = bidx + state_bidx = bidx + tok_offset = Int32(0) seq_len = self.T NT = (self.T + self.BT - 1) // self.BT - chunk_base = Int32(0) + chunk_off = Int32(0) if cutlass.const_expr(self.is_varlen): - data_b = Int32(0) - state_b = Int32(0) - seq_start = cu_seqlens[i_n] - seq_len = cu_seqlens[i_n + 1] - seq_start + data_bidx = 
Int32(0) + state_bidx = Int32(0) + tok_offset = cu_seqlens[bidx] + seq_len = cu_seqlens[bidx + 1] - tok_offset NT = (seq_len + self.BT - 1) // self.BT - chunk_base = chunk_offsets[i_n] - v_base = i_v_tile * self.BV + chunk_off = chunk_offsets[bidx] + v_tile_base = v_tile_idx * self.BV smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - sA = storage.sA.get_tensor(k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner) - sUA = storage.sUA.get_tensor(dv_smem_layout_staged.outer, swizzle=dv_smem_layout_staged.inner) + # ===================== SMEM views ===================== + sK = storage.sK.get_tensor(k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner) + sDv = storage.sDv.get_tensor(dv_smem_layout_staged.outer, swizzle=dv_smem_layout_staged.inner) sDo = storage.sDo.get_tensor(do_smem_layout_staged.outer, swizzle=do_smem_layout_staged.inner) - sGK = storage.sGK.get_tensor(cute.make_layout((BK, 1, self.gk_stage), stride=(1, BK, BK))) - sG = storage.sG.get_tensor(cute.make_layout((BT, 2), stride=(1, BT))) - sUB = storage.sUB.get_tensor(q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner) + sGK = storage.sGK.get_tensor(cute.make_layout((self.BK, 1, self.gk_stage), stride=(1, self.BK, self.BK))) + sG = storage.sG.get_tensor(cute.make_layout((self.BT, 2), stride=(1, self.BT))) + sQ = storage.sQ.get_tensor(q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner) sW = storage.sW.get_tensor(w_smem_layout_staged.outer, swizzle=w_smem_layout_staged.inner) sDv2 = storage.sDv2.get_tensor(dv2_smem_layout_staged.outer, swizzle=dv2_smem_layout_staged.inner) sDh = storage.sDh.get_tensor(dh_smem_layout_staged.outer, swizzle=dh_smem_layout_staged.inner) - if warp_idx == self.load_warp_id: - cpasync.prefetch_descriptor(tma_atom_k) - cpasync.prefetch_descriptor(tma_atom_dv) - if cutlass.const_expr(self.use_gk): - cpasync.prefetch_descriptor(tma_atom_gk) - if warp_idx == self.load_current_warp_id: - 
cpasync.prefetch_descriptor(tma_atom_do) - cpasync.prefetch_descriptor(tma_atom_q) - cpasync.prefetch_descriptor(tma_atom_w) - if warp_idx == self.store_warp_id: - cpasync.prefetch_descriptor(tma_atom_dh) - cpasync.prefetch_descriptor(tma_atom_dv2) - + # ===================== Pipelines ===================== load_k_P, load_k_C = pipeline.PipelineTmaAsync.create( num_stages=self.k_stage, producer_group=make_thread_cooperative_group(1), @@ -683,16 +623,20 @@ def kernel( consumer_group=make_thread_cooperative_group(self.threads_per_warp), barrier_storage=storage.store_dv2_mbar.data_ptr(), ).make_participants() + + # ===================== TMA partitions ===================== + # Varlen shifts token-indexed tensors by tok_offset; dh uses chunk_off + # because state storage is compact across sequences. if cutlass.const_expr(self.is_varlen): - tma_tensor_k_use = cute.domain_offset((seq_start, 0, (0, 0)), tma_tensor_k) - tma_tensor_dv_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_dv) - tma_tensor_do_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_do) - tma_tensor_q_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_q) - tma_tensor_w_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_w) - tma_tensor_dh_use = cute.domain_offset((0, 0, (chunk_base, 0, 0)), tma_tensor_dh) - tma_tensor_dv2_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_dv2) + tma_tensor_k_use = cute.domain_offset((tok_offset, 0, (0, 0)), tma_tensor_k) + tma_tensor_dv_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_dv) + tma_tensor_do_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_do) + tma_tensor_q_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_q) + tma_tensor_w_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_w) + tma_tensor_dh_use = cute.domain_offset((0, 0, (chunk_off, 0, 0)), tma_tensor_dh) + tma_tensor_dv2_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_dv2) if 
cutlass.const_expr(self.use_gk): - tma_tensor_gk_use = cute.domain_offset((0, seq_start, (0, 0)), tma_tensor_gk) + tma_tensor_gk_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_gk) else: tma_tensor_k_use = tma_tensor_k tma_tensor_dv_use = tma_tensor_dv @@ -705,29 +649,29 @@ def kernel( tma_tensor_gk_use = tma_tensor_gk _, bSG_sK, bSG_gK = self._epilog_partition( - tma_atom_k, tma_tensor_k_use[None, None, (i_h, data_b)], (self.BT, self.BK), sA + tma_atom_k, tma_tensor_k_use[None, None, (hidx, data_bidx)], (self.BT, self.BK), sK ) _, bSG_sDv, bSG_gDv = self._epilog_partition( - tma_atom_dv, tma_tensor_dv_use[None, None, (i_h, data_b)], (self.BV, self.BT), sUA + tma_atom_dv, tma_tensor_dv_use[None, None, (hidx, data_bidx)], (self.BV, self.BT), sDv ) _, bSG_sDo, bSG_gDo = self._epilog_partition( - tma_atom_do, tma_tensor_do_use[None, None, (i_h, data_b)], (self.BV, self.BT), sDo + tma_atom_do, tma_tensor_do_use[None, None, (hidx, data_bidx)], (self.BV, self.BT), sDo ) _, bSG_sQ, bSG_gQ = self._epilog_partition( - tma_atom_q, tma_tensor_q_use[None, None, (i_h, data_b)], (self.BK, self.BT), sUB + tma_atom_q, tma_tensor_q_use[None, None, (hidx, data_bidx)], (self.BK, self.BT), sQ ) _, bSG_sW, bSG_gW = self._epilog_partition( - tma_atom_w, tma_tensor_w_use[None, None, (i_h, data_b)], (self.BK, self.BT), sW + tma_atom_w, tma_tensor_w_use[None, None, (hidx, data_bidx)], (self.BK, self.BT), sW ) if cutlass.const_expr(self.use_gk): _, bSG_sGK, bSG_gGK = self._epilog_partition( - tma_atom_gk, tma_tensor_gk_use[None, None, (i_h, data_b)], (self.BK, 1), sGK + tma_atom_gk, tma_tensor_gk_use[None, None, (hidx, data_bidx)], (self.BK, 1), sGK ) _, bSG_sDh, bSG_gDh = self._epilog_partition( - tma_atom_dh, tma_tensor_dh_use[None, None, (None, i_h, state_b)], (self.BV, self.BK), sDh + tma_atom_dh, tma_tensor_dh_use[None, None, (None, hidx, state_bidx)], (self.BV, self.BK), sDh ) _, bSG_sDv2, bSG_gDv2 = self._epilog_partition( - tma_atom_dv2, tma_tensor_dv2_use[None, None, 
(i_h, data_b)], (self.BV, self.BT), sDv2 + tma_atom_dv2, tma_tensor_dv2_use[None, None, (hidx, data_bidx)], (self.BV, self.BT), sDv2 ) is_compute_warp = warp_idx < self.num_compute_warps @@ -737,13 +681,14 @@ def kernel( else: cute.arch.setmaxregister_decrease(self.num_regs_other) + # ===================== MMA fragments ===================== thr_mma = tiled_mma.get_slice(local_tidx) update_thr_mma = update_tiled_mma.get_slice(local_tidx) - tKsB = thr_mma.partition_B(sA) + tKsB = thr_mma.partition_B(sK) tKrB = thr_mma.make_fragment_B(tKsB) - tUsA = update_thr_mma.partition_A(sUA) - tUsB = update_thr_mma.partition_B(sUB) + tUsA = update_thr_mma.partition_A(sDv) + tUsB = update_thr_mma.partition_B(sQ) tWsB = update_thr_mma.partition_B(sW) tUrA = update_thr_mma.make_fragment_A(tUsA) tDv2sA = update_thr_mma.partition_A(sDv2) @@ -752,30 +697,20 @@ def kernel( tWrB = update_thr_mma.make_fragment_B(tWsB) if cutlass.const_expr(self.use_g): qdo_thr_mma = qdo_tiled_mma.get_slice(local_tidx) - qdo_tUsB = qdo_thr_mma.partition_B(sUB) + qdo_tUsB = qdo_thr_mma.partition_B(sQ) qdo_tUrB = qdo_thr_mma.make_fragment_B(qdo_tUsB) else: tUsDo = update_thr_mma.partition_A(sDo) tUrDo = update_thr_mma.make_fragment_A(tUsDo) - cDV = cute.make_identity_tensor((BV, BT)) + cDV = cute.make_identity_tensor((self.BV, self.BT)) tCcDV = thr_mma.partition_C(cDV) - acc_dv = thr_mma.make_fragment_C(thr_mma.partition_shape_C((BV, BT))) + acc_dv = thr_mma.make_fragment_C(thr_mma.partition_shape_C((self.BV, self.BT))) - cState = cute.make_identity_tensor((BV, self.BK)) + cState = cute.make_identity_tensor((self.BV, self.BK)) tUcState = update_thr_mma.partition_C(cState) - state_shape = update_thr_mma.partition_shape_C((BV, self.BK)) - rState0 = update_thr_mma.make_fragment_C(state_shape) - if cutlass.const_expr(self.num_k_blocks == 1): - rStates = (rState0,) - elif cutlass.const_expr(self.num_k_blocks == 2): - rState1 = update_thr_mma.make_fragment_C(state_shape) - rStates = (rState0, rState1) - else: - 
rState1 = update_thr_mma.make_fragment_C(state_shape) - rState2 = update_thr_mma.make_fragment_C(state_shape) - rState3 = update_thr_mma.make_fragment_C(state_shape) - rStates = (rState0, rState1, rState2, rState3) + state_shape = update_thr_mma.partition_shape_C((self.BV, self.BK)) + rState = update_thr_mma.make_fragment_C(state_shape) acc_qdo = update_thr_mma.make_fragment_C(state_shape) acc_wdv = update_thr_mma.make_fragment_C(state_shape) dh_smem_layout_enum = ( @@ -800,94 +735,119 @@ def kernel( rDh_shape = cute.shape(thr_copy_dh_r2s.partition_S(sDh)) tRS_rDh_layout = cute.make_layout(rDh_shape[:3]) - # Initialize carried dh state in register blocks. + # Initialize carried dh state in registers. if is_compute_warp: - for k_block in cutlass.range_constexpr(self.num_k_blocks): - k_base = k_block * self.BK - rState = rStates[k_block] - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - v_idx = v_base + v_rel - k_idx = k_base + k_rel - init = Float32(0.0) - if cutlass.const_expr(self.use_dht): - if cutlass.const_expr(self.transpose_state_layout): - init = dht[i_n, i_h, v_idx, k_idx].to(self.acc_dtype) - else: - init = dht[i_n, i_h, k_idx, v_idx].to(self.acc_dtype) - rState[ei] = init - - if warp_idx == self.load_warp_id and NT > 0: - first_chunk = NT - 1 - k_h = load_k_P.acquire_and_advance() - cute.copy(tma_atom_k, bSG_gK[(None, first_chunk, 0)], bSG_sK[None, k_h.index], tma_bar_ptr=k_h.barrier) - dv_h = load_dv_P.acquire_and_advance() - cute.copy( - tma_atom_dv, - bSG_gDv[(None, i_v_tile, first_chunk)], - bSG_sDv[None, dv_h.index], - tma_bar_ptr=dv_h.barrier, - ) + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + v_idx = v_tile_base + v_rel + init = Float32(0.0) + if cutlass.const_expr(self.use_dht): + if cutlass.const_expr(self.transpose_state_layout): + init = dht[bidx, hidx, v_idx, k_rel].to(self.acc_dtype) + else: + init = dht[bidx, hidx, k_rel, v_idx].to(self.acc_dtype) 
+ rState[ei] = init + + # ========================================================================= + # WARP SPECIALIZATION + # load_warp_id : preloads K, dv, and optional gk for the next reverse chunk + # load_current_warp_id : loads do, q, and w for the current reverse chunk + # compute warps : carry dh in registers and run WGMMA + # store_warp_id : stores dh and dv2 after compute warps publish SMEM tiles + # ========================================================================= + # ===== Reverse chunk loop ===== + # Pipeline: preload(prev chunk) -> Phase 1 (publish dh + K@dh) + # -> Phase 2 (dv2) -> Phase 3 (QDO + decay) -> Phase 4 (WDV + dh update). + if warp_idx == self.load_warp_id: + cpasync.prefetch_descriptor(tma_atom_k) + cpasync.prefetch_descriptor(tma_atom_dv) if cutlass.const_expr(self.use_gk): - gk_h = load_gk_P.acquire_and_advance() - cute.copy( - tma_atom_gk, - bSG_gGK[(None, 0, seq_len - 1)], - bSG_sGK[None, gk_h.index], - tma_bar_ptr=gk_h.barrier, - ) - - for chunk_rev in cutlass.range(0, NT, unroll=0): - i_t = NT - 1 - chunk_rev - next_i_t = i_t - 1 - chunk_start = i_t * self.BT - chunk_end = cutlass.min(chunk_start + self.BT, seq_len) - remaining = chunk_end - chunk_start - last_idx = chunk_end - 1 - g_last = Float32(0.0) - g_last_exp = Float32(1.0) - if cutlass.const_expr(self.use_g): - g_last = g[data_b, seq_start + last_idx, i_h].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - g_last_exp = cute.exp2(g_last, fastmath=self.use_fast_math) - else: - g_last_exp = cute.exp(g_last, fastmath=self.use_fast_math) + cpasync.prefetch_descriptor(tma_atom_gk) - if warp_idx == self.load_warp_id and next_i_t >= 0: + if NT > 0: + first_chunk = NT - 1 k_h = load_k_P.acquire_and_advance() - cute.copy(tma_atom_k, bSG_gK[(None, next_i_t, 0)], bSG_sK[None, k_h.index], tma_bar_ptr=k_h.barrier) + cute.copy(tma_atom_k, bSG_gK[(None, first_chunk, 0)], bSG_sK[None, k_h.index], tma_bar_ptr=k_h.barrier) dv_h = load_dv_P.acquire_and_advance() cute.copy( 
tma_atom_dv, - bSG_gDv[(None, i_v_tile, next_i_t)], + bSG_gDv[(None, v_tile_idx, first_chunk)], bSG_sDv[None, dv_h.index], tma_bar_ptr=dv_h.barrier, ) if cutlass.const_expr(self.use_gk): - next_gk_idx = cutlass.min(next_i_t * self.BT + self.BT, seq_len) - 1 gk_h = load_gk_P.acquire_and_advance() cute.copy( tma_atom_gk, - bSG_gGK[(None, 0, next_gk_idx)], + bSG_gGK[(None, 0, seq_len - 1)], bSG_sGK[None, gk_h.index], tma_bar_ptr=gk_h.barrier, ) - if warp_idx == self.load_current_warp_id: + + for chunk_rev in cutlass.range(0, NT, unroll=0): + chunk_idx = NT - 1 - chunk_rev + next_chunk_idx = chunk_idx - 1 + if next_chunk_idx >= 0: + k_h = load_k_P.acquire_and_advance() + cute.copy(tma_atom_k, bSG_gK[(None, next_chunk_idx, 0)], bSG_sK[None, k_h.index], tma_bar_ptr=k_h.barrier) + dv_h = load_dv_P.acquire_and_advance() + cute.copy( + tma_atom_dv, + bSG_gDv[(None, v_tile_idx, next_chunk_idx)], + bSG_sDv[None, dv_h.index], + tma_bar_ptr=dv_h.barrier, + ) + if cutlass.const_expr(self.use_gk): + next_gk_idx = cutlass.min(next_chunk_idx * self.BT + self.BT, seq_len) - 1 + gk_h = load_gk_P.acquire_and_advance() + cute.copy( + tma_atom_gk, + bSG_gGK[(None, 0, next_gk_idx)], + bSG_sGK[None, gk_h.index], + tma_bar_ptr=gk_h.barrier, + ) + + elif warp_idx == self.load_current_warp_id: + cpasync.prefetch_descriptor(tma_atom_do) + cpasync.prefetch_descriptor(tma_atom_q) + cpasync.prefetch_descriptor(tma_atom_w) + + for chunk_rev in cutlass.range(0, NT, unroll=0): + chunk_idx = NT - 1 - chunk_rev do_h = load_do_P.acquire_and_advance() - cute.copy(tma_atom_do, bSG_gDo[(None, i_v_tile, i_t)], bSG_sDo[None, do_h.index], tma_bar_ptr=do_h.barrier) + cute.copy( + tma_atom_do, bSG_gDo[(None, v_tile_idx, chunk_idx)], bSG_sDo[None, do_h.index], tma_bar_ptr=do_h.barrier + ) q_h = load_q_P.acquire_and_advance() - cute.copy(tma_atom_q, bSG_gQ[(None, 0, i_t)], bSG_sQ[None, q_h.index], tma_bar_ptr=q_h.barrier) + cute.copy(tma_atom_q, bSG_gQ[(None, 0, chunk_idx)], bSG_sQ[None, q_h.index], 
tma_bar_ptr=q_h.barrier) w_h = load_w_P.acquire_and_advance() - cute.copy(tma_atom_w, bSG_gW[(None, 0, i_t)], bSG_sW[None, w_h.index], tma_bar_ptr=w_h.barrier) - # dv2 = dv + K @ dh. Compute the equivalent (dh @ K^T) tile so the - # register-carried state can feed WGMMA as an RMEM A operand. - if is_compute_warp: - # Match chunk_delta_h.py's h_out overlap: publish the carried - # state to the store pipeline before the chunk GEMM chain. - rState0_bf16 = cute.make_rmem_tensor(rStates[0].shape, self.io_dtype) - rState0_bf16.store(rStates[0].load().to(self.io_dtype)) + cute.copy(tma_atom_w, bSG_gW[(None, 0, chunk_idx)], bSG_sW[None, w_h.index], tma_bar_ptr=w_h.barrier) + + elif is_compute_warp: + for chunk_rev in cutlass.range(0, NT, unroll=0): + chunk_idx = NT - 1 - chunk_rev + chunk_start = chunk_idx * self.BT + chunk_end = cutlass.min(chunk_start + self.BT, seq_len) + remaining = chunk_end - chunk_start + last_idx = chunk_end - 1 + g_last = Float32(0.0) + g_last_exp = Float32(1.0) + if cutlass.const_expr(self.use_g): + g_last = g[data_bidx, tok_offset + last_idx, hidx].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + g_last_exp = cute.exp2(g_last, fastmath=self.use_fast_math) + else: + g_last_exp = cute.exp(g_last, fastmath=self.use_fast_math) + + # ======================================== + # Phase 1: Publish dh + start K @ dh + # ======================================== + # Publish carried dh to the store pipeline before the GEMM chain, + # matching chunk_delta_h.py's h_out overlap pattern. 
+ rState_bf16 = cute.make_rmem_tensor(rState.shape, self.io_dtype) + rState_bf16.store(rState.load().to(self.io_dtype)) dh_h = store_dh_P.acquire_and_advance() - tRS_rState = tiled_copy_dh_r2s.retile(rState0_bf16) + tRS_rState = tiled_copy_dh_r2s.retile(rState_bf16) tRS_rDh_out = cute.make_rmem_tensor_like(tRS_rDh_layout, self.io_dtype) tRS_rDh_out.store(tRS_rState.load()) cute.copy( @@ -898,76 +858,75 @@ def kernel( cute.arch.fence_proxy("async.shared", space="cta") dh_h.commit() + # dv2 = dv + K @ dh. Compute the equivalent dh @ K^T tile so + # the register-carried state can feed WGMMA as an RMEM A operand. acc_dv.fill(0.0) - for k_block in cutlass.range_constexpr(self.num_k_blocks): - k_wait = load_k_C.wait_and_advance() - rState = rStates[k_block] - if cutlass.const_expr(k_block == 0): - rState_op = self.make_acc_into_op(rState0_bf16, tiled_mma.tv_layout_A, self.io_dtype) - else: - rState_op = self.make_acc_into_op(rState, tiled_mma.tv_layout_A, self.io_dtype) + k_wait = load_k_C.wait_and_advance() + rState_op = self.make_acc_into_op(rState_bf16, tiled_mma.tv_layout_A, self.io_dtype) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tKrB, mode=[2]), unroll_full=True): + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + tiled_mma, + acc_dv, + rState_op[None, None, kp], + tKrB[None, None, kp, k_wait.index], + acc_dv, + ) + cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(self.use_g): + if local_tidx < self.BT: + t_idx = chunk_start + local_tidx + g_decay = Float32(0.0) + g_exp = Float32(0.0) + if t_idx < seq_len: + g_cur = g[data_bidx, tok_offset + t_idx, hidx].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + g_decay = cute.exp2(g_last - g_cur, fastmath=self.use_fast_math) + g_exp = cute.exp2(g_cur, fastmath=self.use_fast_math) + else: + g_decay = cute.exp(g_last - g_cur, fastmath=self.use_fast_math) + g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) + sG[local_tidx, 0] = 
g_decay + sG[local_tidx, 1] = g_exp + if cutlass.const_expr((not self.use_g) and (not self.is_varlen)): + # Phase 3 is independent of K@dh, so overlap QDO and optional gk decay + # with the first GEMM in the no-scalar-g non-varlen fast path. + do_wait_early = load_do_C.wait_and_advance() + q_wait_early = load_q_C.wait_and_advance() + if cutlass.const_expr(self.use_gk): + gk_wait_early = load_gk_C.wait_and_advance() + acc_qdo.fill(0.0) cute.nvgpu.warpgroup.fence() - for kp in cutlass.range(cute.size(tKrB, mode=[2]), unroll_full=True): - tiled_mma.set( - cute.nvgpu.warpgroup.Field.ACCUMULATE, - cutlass.Boolean((k_block != 0) or (kp != 0)), - ) + for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) cute.gemm( - tiled_mma, - acc_dv, - rState_op[None, None, kp], - tKrB[None, None, kp, k_wait.index], - acc_dv, + update_tiled_mma, + acc_qdo, + tUrDo[None, None, kp, do_wait_early.index], + tUrB[None, None, kp, q_wait_early.index], + acc_qdo, ) cute.nvgpu.warpgroup.commit_group() - if cutlass.const_expr(self.use_g): - if local_tidx < self.BT: - t_idx = chunk_start + local_tidx - g_decay = Float32(0.0) - g_exp = Float32(0.0) - if t_idx < seq_len: - g_cur = g[data_b, seq_start + t_idx, i_h].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - g_decay = cute.exp2(g_last - g_cur, fastmath=self.use_fast_math) - g_exp = cute.exp2(g_cur, fastmath=self.use_fast_math) - else: - g_decay = cute.exp(g_last - g_cur, fastmath=self.use_fast_math) - g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) - sG[local_tidx, 0] = g_decay - sG[local_tidx, 1] = g_exp - if cutlass.const_expr((not self.use_g) and (not self.is_varlen) and (self.num_k_blocks == 1)): - do_wait_early = load_do_C.wait_and_advance() - q_wait_early = load_q_C.wait_and_advance() - if cutlass.const_expr(self.use_gk): - gk_wait_early = load_gk_C.wait_and_advance() - acc_qdo.fill(0.0) - cute.nvgpu.warpgroup.fence() 
- for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): - update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) - cute.gemm( - update_tiled_mma, - acc_qdo, - tUrDo[None, None, kp, do_wait_early.index], - tUrB[None, None, kp, q_wait_early.index], - acc_qdo, - ) - cute.nvgpu.warpgroup.commit_group() - if cutlass.const_expr(self.use_gk): - gk_last = sGK[local_tidx, 0, gk_wait_early.index].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) - else: - k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) - sGK[local_tidx, 0, gk_wait_early.index] = k_decay - self.gk_precompute_bar.arrive_and_wait() - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait_early.index] - cute.nvgpu.warpgroup.wait_group(1) - else: - cute.nvgpu.warpgroup.wait_group(0) - k_wait.release() + if cutlass.const_expr(self.use_gk): + gk_last = sGK[local_tidx, 0, gk_wait_early.index].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) + else: + k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) + sGK[local_tidx, 0, gk_wait_early.index] = k_decay + self.gk_precompute_bar.arrive_and_wait() + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait_early.index] + cute.nvgpu.warpgroup.wait_group(1) + else: + cute.nvgpu.warpgroup.wait_group(0) + k_wait.release() + # ======================================== + # Phase 2: dv2 = dv + K @ dh + # ======================================== dv_wait = load_dv_C.wait_and_advance() dv_stage = dv_wait.index dv2_store_h = store_dv2_P.acquire_and_advance() @@ -982,18 +941,19 @@ def kernel( out = acc_dv[ei] if cutlass.const_expr(self.use_g): out = out * sG[t_rel, 0] - out = out + sUA[v_rel, t_rel, 
dv_stage].to(self.acc_dtype) + out = out + sDv[v_rel, t_rel, dv_stage].to(self.acc_dtype) out_bf16 = out.to(self.io_dtype) sDv2[v_rel, t_rel, dv2_stage] = out_bf16 if remaining < self.BT and t_idx < seq_len: - dv2[data_b, seq_start + chunk_start + t_rel, i_h, v_base + v_rel] = out_bf16 + dv2[data_bidx, tok_offset + chunk_start + t_rel, hidx, v_tile_base + v_rel] = out_bf16 cute.arch.fence_proxy("async.shared", space="cta") dv2_store_h.commit() dv_wait.release() - # dh += scale * do^T @ q - dv2^T @ w. - if cutlass.const_expr((not self.use_g) and (not self.is_varlen) and (self.num_k_blocks == 1)): - rState = rStates[0] + # ======================================== + # Phase 3/4: dh += scale * do^T @ q - dv2^T @ w + # ======================================== + if cutlass.const_expr((not self.use_g) and (not self.is_varlen)): w_wait = load_w_C.wait_and_advance() acc_wdv.fill(0.0) cute.nvgpu.warpgroup.fence() @@ -1020,6 +980,7 @@ def kernel( else: do_wait = load_do_C.wait_and_advance() if cutlass.const_expr(self.use_g): + # Phase 3a: materialize gated do in registers for QDO. for ei in cutlass.range(cute.size(acc_dv), unroll_full=True): v_rel, t_rel = tCcDV[ei] t_idx = chunk_start + t_rel @@ -1030,6 +991,7 @@ def kernel( rDo_op = self.make_acc_into_op(acc_dv, qdo_tiled_mma.tv_layout_A, self.io_dtype) do_wait.release() if cutlass.const_expr((not self.use_g) and self.is_varlen): + # Phase 3a: zero padded do positions in SMEM for varlen tails. 
linear_do = local_tidx while linear_do < self.BV * self.BT: v_rel = linear_do // self.BT @@ -1042,79 +1004,97 @@ def kernel( linear_do += self.num_compute_threads cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) - for k_block in cutlass.range_constexpr(self.num_k_blocks): - rState = rStates[k_block] - q_wait = load_q_C.wait_and_advance() - if cutlass.const_expr(self.use_gk): - gk_wait = load_gk_C.wait_and_advance() - acc_qdo.fill(0.0) - cute.nvgpu.warpgroup.fence() - if cutlass.const_expr(self.use_g): - for kp in cutlass.range(cute.size(qdo_tUrB, mode=[2]), unroll_full=True): - qdo_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) - cute.gemm( - qdo_tiled_mma, - acc_qdo, - rDo_op[None, None, kp], - qdo_tUrB[None, None, kp, q_wait.index], - acc_qdo, - ) - else: - for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): - update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) - cute.gemm( - update_tiled_mma, - acc_qdo, - tUrDo[None, None, kp, do_wait.index], - tUrB[None, None, kp, q_wait.index], - acc_qdo, - ) - cute.nvgpu.warpgroup.commit_group() - - # QDO does not consume rState, so hide g/gk state decay under its WGMMA latency. 
- if cutlass.const_expr(self.use_g): - for ei in cutlass.range(cute.size(rState), unroll_full=True): - rState[ei] = rState[ei] * g_last_exp - if cutlass.const_expr(self.use_gk): - gk_last = sGK[local_tidx, 0, gk_wait.index].to(self.acc_dtype) - if cutlass.const_expr(self.use_exp2): - k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) - else: - k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) - sGK[local_tidx, 0, gk_wait.index] = k_decay - self.gk_precompute_bar.arrive_and_wait() - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait.index] - - w_wait = load_w_C.wait_and_advance() - acc_wdv.fill(0.0) - cute.nvgpu.warpgroup.fence() - for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): + # Phase 3b: QDO plus scalar/key decay while QDO is in flight. + q_wait = load_q_C.wait_and_advance() + if cutlass.const_expr(self.use_gk): + gk_wait = load_gk_C.wait_and_advance() + acc_qdo.fill(0.0) + cute.nvgpu.warpgroup.fence() + if cutlass.const_expr(self.use_g): + for kp in cutlass.range(cute.size(qdo_tUrB, mode=[2]), unroll_full=True): + qdo_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + qdo_tiled_mma, + acc_qdo, + rDo_op[None, None, kp], + qdo_tUrB[None, None, kp, q_wait.index], + acc_qdo, + ) + else: + for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) cute.gemm( update_tiled_mma, - acc_wdv, - tDv2rA[None, None, kp, dv2_stage], - tWrB[None, None, kp, w_wait.index], - acc_wdv, + acc_qdo, + tUrDo[None, None, kp, do_wait.index], + tUrB[None, None, kp, q_wait.index], + acc_qdo, ) - cute.nvgpu.warpgroup.commit_group() - cute.nvgpu.warpgroup.wait_group(0) - q_wait.release() - if cutlass.const_expr(self.use_gk): - gk_wait.release() + cute.nvgpu.warpgroup.commit_group() + # QDO does not consume rState, so hide g/gk 
state decay under its WGMMA latency. + if cutlass.const_expr(self.use_g): + for ei in cutlass.range(cute.size(rState), unroll_full=True): + rState[ei] = rState[ei] * g_last_exp + if cutlass.const_expr(self.use_gk): + gk_last = sGK[local_tidx, 0, gk_wait.index].to(self.acc_dtype) + if cutlass.const_expr(self.use_exp2): + k_decay = cute.exp2(gk_last, fastmath=self.use_fast_math) + else: + k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) + sGK[local_tidx, 0, gk_wait.index] = k_decay + self.gk_precompute_bar.arrive_and_wait() for ei in cutlass.range(cute.size(rState), unroll_full=True): - update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] - rState[ei] = rState[ei] + update - w_wait.release() + v_rel, k_rel = tUcState[ei] + rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait.index] + + # Phase 4: WDV and final dh update. + w_wait = load_w_C.wait_and_advance() + acc_wdv.fill(0.0) + cute.nvgpu.warpgroup.fence() + for kp in cutlass.range(cute.size(tUrA, mode=[2]), unroll_full=True): + update_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + update_tiled_mma, + acc_wdv, + tDv2rA[None, None, kp, dv2_stage], + tWrB[None, None, kp, w_wait.index], + acc_wdv, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + q_wait.release() + if cutlass.const_expr(self.use_gk): + gk_wait.release() + + for ei in cutlass.range(cute.size(rState), unroll_full=True): + update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] + rState[ei] = rState[ei] + update + w_wait.release() if cutlass.const_expr(not self.use_g): do_wait.release() - if warp_idx == self.store_warp_id: + if cutlass.const_expr(self.use_dh0): + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + v_idx = v_tile_base + v_rel + if cutlass.const_expr(self.transpose_state_layout): + dh0[bidx, hidx, v_idx, k_rel] = rState[ei] + else: + dh0[bidx, hidx, k_rel, v_idx] = rState[ei] + + elif warp_idx == 
self.store_warp_id: + cpasync.prefetch_descriptor(tma_atom_dh) + cpasync.prefetch_descriptor(tma_atom_dv2) + + for chunk_rev in cutlass.range(0, NT, unroll=0): + chunk_idx = NT - 1 - chunk_rev + chunk_start = chunk_idx * self.BT + chunk_end = cutlass.min(chunk_start + self.BT, seq_len) + remaining = chunk_end - chunk_start + dh_h = store_dh_C.wait_and_advance() - cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, i_v_tile, 0, i_t)]) + cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, v_tile_idx, 0, chunk_idx)]) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) dh_h.release() @@ -1127,26 +1107,12 @@ def kernel( cute.copy( tma_atom_dv2, bSG_sDv2[None, dv2_store_h.index], - bSG_gDv2[(None, i_v_tile, i_t)], + bSG_gDv2[(None, v_tile_idx, chunk_idx)], ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) dv2_store_h.release() - if cutlass.const_expr(self.use_dh0): - if is_compute_warp: - for k_block in cutlass.range_constexpr(self.num_k_blocks): - k_base = k_block * self.BK - rState = rStates[k_block] - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - v_idx = v_base + v_rel - k_idx = k_base + k_rel - if cutlass.const_expr(self.transpose_state_layout): - dh0[i_n, i_h, v_idx, k_idx] = rState[ei] - else: - dh0[i_n, i_h, k_idx, v_idx] = rState[ei] - @cute.jit def _epilog_partition(self, atom, gC_mnl, epi_tile, sC): gC_epi = cute.flat_divide(gC_mnl, epi_tile) @@ -1323,7 +1289,7 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( cu_seqlens_arg = cu_seqlens.int() if cu_seqlens.dtype != torch.int32 else cu_seqlens else: N = B - NT = math.ceil(T / BT) + NT = (T + BT - 1) // BT cu_seqlens_arg, chunk_offsets = _cached_nonvarlen_metadata(B, T, NT, q.device) scale_value = 1.0 if scale is None else float(scale) From ef8521015283a919872f65752ac9e1442237fd7f Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Tue, 19 May 2026 15:43:28 
+0800 Subject: [PATCH 22/28] del empty cache --- cula/ops/chunk_delta_h_bwd.py | 46 +++++------------------------------ 1 file changed, 6 insertions(+), 40 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index fdf9142..eaef106 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -52,47 +52,12 @@ BV = 64 BK = 128 NUM_THREADS = 224 -_DUMMY_TENSOR_CACHE_MAX = 32 -_dummy_tensor_cache: dict[tuple, torch.Tensor] = {} -_nonvarlen_metadata_cache: dict[tuple, tuple[torch.Tensor, torch.Tensor]] = {} def make_thread_cooperative_group(size: int): return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) -def _device_key(device: torch.device) -> tuple[str, int | None]: - device = torch.device(device) - index = device.index - if device.type == "cuda" and index is None: - index = torch.cuda.current_device() - return device.type, index - - -def _cached_empty(shape: tuple[int, ...], *, device: torch.device, dtype: torch.dtype) -> torch.Tensor: - key = (_device_key(device), dtype, tuple(int(x) for x in shape)) - tensor = _dummy_tensor_cache.get(key) - if tensor is None: - if len(_dummy_tensor_cache) >= _DUMMY_TENSOR_CACHE_MAX: - _dummy_tensor_cache.clear() - tensor = torch.empty(shape, device=device, dtype=dtype) - _dummy_tensor_cache[key] = tensor - return tensor - - -def _cached_nonvarlen_metadata(B: int, T: int, NT: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: - key = (_device_key(device), int(B), int(T), int(NT)) - metadata = _nonvarlen_metadata_cache.get(key) - if metadata is None: - if len(_nonvarlen_metadata_cache) >= _DUMMY_TENSOR_CACHE_MAX: - _nonvarlen_metadata_cache.clear() - cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.int32) * T - chunk_offsets = torch.arange(B + 1, device=device, dtype=torch.int32) * NT - metadata = (cu_seqlens, chunk_offsets) - _nonvarlen_metadata_cache[key] = metadata - return metadata - - class ChunkDeltaRuleBwdDHUSm90: def __init__( self, @@ 
-1290,7 +1255,8 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( else: N = B NT = (T + BT - 1) // BT - cu_seqlens_arg, chunk_offsets = _cached_nonvarlen_metadata(B, T, NT, q.device) + cu_seqlens_arg = torch.empty(B + 1, device=q.device, dtype=torch.int32) + chunk_offsets = torch.empty(B + 1, device=q.device, dtype=torch.int32) scale_value = 1.0 if scale is None else float(scale) state_shape = (N, H, V, K) if transpose_state_layout else (N, H, K, V) @@ -1298,10 +1264,10 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None dv2 = torch.empty_like(dv) - g_arg = g if g is not None else _cached_empty((B, T, H), device=q.device, dtype=torch.float32) - gk_arg = gk if gk is not None else _cached_empty((B, T, H, K), device=q.device, dtype=torch.float32) - dht_arg = dht if dht is not None else _cached_empty(state_shape, device=q.device, dtype=torch.float32) - dh0_arg = dh0 if dh0 is not None else _cached_empty(state_shape, device=q.device, dtype=torch.float32) + g_arg = g if g is not None else torch.empty(B, T, H, device=q.device, dtype=torch.float32) + gk_arg = gk if gk is not None else torch.empty(B, T, H, K, device=q.device, dtype=torch.float32) + dht_arg = dht if dht is not None else torch.empty(state_shape, device=q.device, dtype=torch.float32) + dh0_arg = dh0 if dh0 is not None else torch.empty(state_shape, device=q.device, dtype=torch.float32) if g is not None and (g.dtype != torch.float32 or not g.is_contiguous()): raise ValueError("g must be contiguous float32.") if g is not None and tuple(g.shape) != (B, T, H): From 47e62d96d060e23c8bc7acc0227bf5c2b9aeb5c6 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Tue, 19 May 2026 17:47:52 +0800 Subject: [PATCH 23/28] format --- cula/ops/chunk_delta_h_bwd.py | 145 ++++++++++++++++++++++++---------- 1 file changed, 104 insertions(+), 41 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index eaef106..30ba23b 
100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -42,7 +42,7 @@ import cutlass.utils.hopper_helpers as sm90_utils import torch from cutlass.cute.nvgpu import cpasync -from cutlass.cute.runtime import from_dlpack +from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream from cutlass.cute.typing import Float32, Int32, Int64 from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets @@ -504,14 +504,12 @@ def kernel( bidx = bh_idx // self.H hidx = bh_idx - bidx * self.H data_bidx = bidx - state_bidx = bidx tok_offset = Int32(0) seq_len = self.T NT = (self.T + self.BT - 1) // self.BT chunk_off = Int32(0) if cutlass.const_expr(self.is_varlen): data_bidx = Int32(0) - state_bidx = Int32(0) tok_offset = cu_seqlens[bidx] seq_len = cu_seqlens[bidx + 1] - tok_offset NT = (seq_len + self.BT - 1) // self.BT @@ -633,7 +631,7 @@ def kernel( tma_atom_gk, tma_tensor_gk_use[None, None, (hidx, data_bidx)], (self.BK, 1), sGK ) _, bSG_sDh, bSG_gDh = self._epilog_partition( - tma_atom_dh, tma_tensor_dh_use[None, None, (None, hidx, state_bidx)], (self.BV, self.BK), sDh + tma_atom_dh, tma_tensor_dh_use[None, None, (None, hidx, data_bidx)], (self.BV, self.BK), sDh ) _, bSG_sDv2, bSG_gDv2 = self._epilog_partition( tma_atom_dv2, tma_tensor_dv2_use[None, None, (hidx, data_bidx)], (self.BV, self.BT), sDv2 @@ -1114,10 +1112,6 @@ def make_acc_into_op(self, acc, operand_layout_tv, element_type): return operand -def _as_cute(tensor: torch.Tensor): - return from_dlpack(tensor, assumed_align=16) - - @functools.lru_cache(maxsize=64) def _compile_bwd_dhu_sm90( B: int, @@ -1155,42 +1149,112 @@ def _compile_bwd_dhu_sm90( use_fast_math=USE_FAST_MATH, ) - q_fake = torch.empty(B, T, H, K, device="cuda", dtype=torch.bfloat16) - k_fake = torch.empty_like(q_fake) - w_fake = torch.empty_like(q_fake) - do_fake = torch.empty(B, T, H, V, device="cuda", dtype=torch.bfloat16) - dv_fake = torch.empty_like(do_fake) - dv2_fake = torch.empty_like(do_fake) - 
g_fake = torch.empty(B, T, H, device="cuda", dtype=torch.float32) - gk_fake = torch.empty(B, T, H, K, device="cuda", dtype=torch.float32) + q_fake = make_fake_compact_tensor( + cutlass.BFloat16, + (B, T, H, K), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + k_fake = make_fake_compact_tensor( + cutlass.BFloat16, + (B, T, H, K), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + w_fake = make_fake_compact_tensor( + cutlass.BFloat16, + (B, T, H, K), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + do_fake = make_fake_compact_tensor( + cutlass.BFloat16, + (B, T, H, V), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + dv_fake = make_fake_compact_tensor( + cutlass.BFloat16, + (B, T, H, V), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + dv2_fake = make_fake_compact_tensor( + cutlass.BFloat16, + (B, T, H, V), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + g_fake = make_fake_compact_tensor( + cutlass.Float32, + (B, T, H), + stride_order=(2, 1, 0), + assumed_align=128, + ) + gk_fake = make_fake_compact_tensor( + cutlass.Float32, + (B, T, H, K), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) if transpose_state_layout: - dht_fake = torch.empty(N, H, V, K, device="cuda", dtype=torch.float32) - dh0_fake = torch.empty_like(dht_fake) - dh_fake = torch.empty(B, NT, H, V, K, device="cuda", dtype=torch.bfloat16) + dht_fake = make_fake_compact_tensor( + cutlass.Float32, + (N, H, V, K), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + dh0_fake = make_fake_compact_tensor( + cutlass.Float32, + (N, H, V, K), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + dh_fake = make_fake_compact_tensor( + cutlass.BFloat16, + (B, NT, H, V, K), + stride_order=(4, 3, 2, 1, 0), + assumed_align=128, + ) else: - dht_fake = torch.empty(N, H, K, V, device="cuda", dtype=torch.float32) - dh0_fake = torch.empty_like(dht_fake) - dh_fake = torch.empty(B, NT, H, K, V, device="cuda", dtype=torch.bfloat16) - cu_fake = torch.empty(N + 1, 
device="cuda", dtype=torch.int32) - offsets_fake = torch.empty(N + 1, device="cuda", dtype=torch.int32) - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + dht_fake = make_fake_compact_tensor( + cutlass.Float32, + (N, H, K, V), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + dh0_fake = make_fake_compact_tensor( + cutlass.Float32, + (N, H, K, V), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + dh_fake = make_fake_compact_tensor( + cutlass.BFloat16, + (B, NT, H, K, V), + stride_order=(4, 3, 2, 1, 0), + assumed_align=128, + ) + cu_fake = make_fake_compact_tensor(cutlass.Int32, (N + 1,), assumed_align=128) + offsets_fake = make_fake_compact_tensor(cutlass.Int32, (N + 1,), assumed_align=128) + stream_fake = make_fake_stream(use_tvm_ffi_env_stream=True) return cute.compile( kernel, - _as_cute(q_fake), - _as_cute(k_fake), - _as_cute(w_fake), - _as_cute(g_fake), - _as_cute(gk_fake), - _as_cute(dht_fake), - _as_cute(dh0_fake), - _as_cute(do_fake), - _as_cute(dh_fake), - _as_cute(dv_fake), - _as_cute(dv2_fake), - _as_cute(cu_fake), - _as_cute(offsets_fake), - stream=stream, + q_fake, + k_fake, + w_fake, + g_fake, + gk_fake, + dht_fake, + dh0_fake, + do_fake, + dh_fake, + dv_fake, + dv2_fake, + cu_fake, + offsets_fake, + stream_fake, options="--enable-tvm-ffi", ) @@ -1300,8 +1364,7 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( transpose_state_layout, scale_value, ) - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - compiled(q, k, w, g_arg, gk_arg, dht_arg, dh0_arg, do, dh, dv, dv2, cu_seqlens_arg, chunk_offsets, stream) + compiled(q, k, w, g_arg, gk_arg, dht_arg, dh0_arg, do, dh, dv, dv2, cu_seqlens_arg, chunk_offsets) return dh, dh0, dv2 From bad0189634d21c4671865644c22bf52da533b538 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Tue, 19 May 2026 18:41:30 +0800 Subject: [PATCH 24/28] only concrete problem size to compiled kernel --- cula/ops/chunk_delta_h_bwd.py | 111 ++++++++++++++++------------------ 1 file 
changed, 51 insertions(+), 60 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 30ba23b..dd667d4 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -61,10 +61,6 @@ def make_thread_cooperative_group(size: int): class ChunkDeltaRuleBwdDHUSm90: def __init__( self, - batch_size: int, - seq_len: int, - num_sequences: int, - total_chunks: int, num_heads: int, head_dim_k: int, head_dim_v: int, @@ -81,10 +77,6 @@ def __init__( assert head_dim_k == 128 and head_dim_v == 128, ( f"SM90 bwd_dhu currently aligns with ChunkDeltaRuleFwdH and requires K=V=128, got K={head_dim_k}, V={head_dim_v}" ) - self.B = batch_size - self.T = seq_len - self.N = num_sequences - self.NT = total_chunks self.H = num_heads self.K = head_dim_k self.V = head_dim_v @@ -148,6 +140,7 @@ def __call__( dv2_in: cute.Tensor, cu_seqlens_in: cute.Tensor, chunk_offsets_in: cute.Tensor, + problem_size: tuple[Int32, Int32, Int32, Int32], stream: cuda.CUstream, ): q_ptr = q_in.iterator @@ -164,27 +157,27 @@ def __call__( cu_seqlens_ptr = cu_seqlens_in.iterator chunk_offsets_ptr = chunk_offsets_in.iterator - NT_total = self.NT + B, T, N, NT_total = problem_size # ===================== GMEM layouts ===================== g_layout = cute.make_layout( - (self.B, self.T, self.H), - stride=(self.T * self.H, self.H, 1), + (B, T, self.H), + stride=(T * self.H, self.H, 1), ) g = cute.make_tensor(g_ptr, g_layout) - cu_seqlens = cute.make_tensor(cu_seqlens_ptr, cute.make_layout((self.N + 1,))) - chunk_offsets = cute.make_tensor(chunk_offsets_ptr, cute.make_layout((self.N + 1,))) + cu_seqlens = cute.make_tensor(cu_seqlens_ptr, cute.make_layout((N + 1,))) + chunk_offsets = cute.make_tensor(chunk_offsets_ptr, cute.make_layout((N + 1,))) # dh TMA store view: (V, K) tile with layout selected by the requested state layout. 
if cutlass.const_expr(self.transpose_state_layout): dh_tma_layout = cute.make_layout( - (self.V, self.K, (NT_total, self.H, self.B)), + (self.V, self.K, (NT_total, self.H, B)), stride=(self.K, 1, (self.H * self.K * self.V, self.K * self.V, NT_total * self.H * self.K * self.V)), ) else: dh_tma_layout = cute.make_layout( - (self.V, self.K, (NT_total, self.H, self.B)), + (self.V, self.K, (NT_total, self.H, B)), stride=(1, self.V, (self.H * self.K * self.V, self.K * self.V, NT_total * self.H * self.K * self.V)), ) dh_tma_tile = (self.BV, self.BK) @@ -195,39 +188,33 @@ def __call__( if cutlass.const_expr(self.transpose_state_layout): final_layout = cute.make_layout( - (self.N, self.H, self.V, self.K), + (N, self.H, self.V, self.K), stride=(self.H * self.K * self.V, self.K * self.V, self.K, 1), ) else: final_layout = cute.make_layout( - (self.N, self.H, self.K, self.V), + (N, self.H, self.K, self.V), stride=(self.H * self.K * self.V, self.K * self.V, self.V, 1), ) dht = cute.make_tensor(dht_ptr, final_layout) dh0 = cute.make_tensor(dh0_ptr, final_layout) # TMA operand views. Varlen shifts the T dimension with domain_offset below. 
- tk_layout = cute.make_layout( - (self.T, self.K, (self.H, self.B)), stride=(self.H * self.K, 1, (self.K, self.T * self.H * self.K)) - ) + tk_layout = cute.make_layout((T, self.K, (self.H, B)), stride=(self.H * self.K, 1, (self.K, T * self.H * self.K))) k_tk = cute.make_tensor(k_ptr, tk_layout) - kt_layout = cute.make_layout( - (self.K, self.T, (self.H, self.B)), stride=(1, self.H * self.K, (self.K, self.T * self.H * self.K)) - ) + kt_layout = cute.make_layout((self.K, T, (self.H, B)), stride=(1, self.H * self.K, (self.K, T * self.H * self.K))) q_kt = cute.make_tensor(q_ptr, kt_layout) w_kt = cute.make_tensor(w_ptr, kt_layout) gk_kt = cute.make_tensor(gk_ptr, kt_layout) - vt_layout = cute.make_layout( - (self.V, self.T, (self.H, self.B)), stride=(1, self.H * self.V, (self.V, self.T * self.H * self.V)) - ) + vt_layout = cute.make_layout((self.V, T, (self.H, B)), stride=(1, self.H * self.V, (self.V, T * self.H * self.V))) do_vt = cute.make_tensor(do_ptr, vt_layout) dv_vt = cute.make_tensor(dv_ptr, vt_layout) dv2_vt = cute.make_tensor(dv2_ptr, vt_layout) dv2_layout = cute.make_layout( - (self.B, self.T, self.H, self.V), - stride=(self.T * self.H * self.V, self.H * self.V, self.V, 1), + (B, T, self.H, self.V), + stride=(T * self.H * self.V, self.H * self.V, self.V, 1), ) dv2 = cute.make_tensor(dv2_ptr, dv2_layout) @@ -426,6 +413,7 @@ class SharedStorage: dv2, cu_seqlens, chunk_offsets, + problem_size, tiled_mma, update_tiled_mma, qdo_tiled_mma, @@ -453,7 +441,7 @@ class SharedStorage: tma_atom_dv2, tma_tensor_dv2, ).launch( - grid=[cute.ceil_div(self.V, self.BV), self.N * self.H, 1], + grid=[cute.ceil_div(self.V, self.BV), N * self.H, 1], block=[self.num_threads, 1, 1], cluster=self.cluster_shape_mnk, stream=stream, @@ -469,6 +457,7 @@ def kernel( dv2: cute.Tensor, cu_seqlens: cute.Tensor, chunk_offsets: cute.Tensor, + problem_size: tuple[Int32, Int32, Int32, Int32], tiled_mma: cute.TiledMma, update_tiled_mma: cute.TiledMma, qdo_tiled_mma: cute.TiledMma, @@ -498,6 
+487,7 @@ def kernel( ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx, _, _ = cute.arch.thread_idx() + B, T, N, NT_total = problem_size # ===================== Block indices ===================== v_tile_idx, bh_idx, _ = cute.arch.block_idx() @@ -505,8 +495,8 @@ def kernel( hidx = bh_idx - bidx * self.H data_bidx = bidx tok_offset = Int32(0) - seq_len = self.T - NT = (self.T + self.BT - 1) // self.BT + seq_len = T + NT = (T + self.BT - 1) // self.BT chunk_off = Int32(0) if cutlass.const_expr(self.is_varlen): data_bidx = Int32(0) @@ -1114,10 +1104,6 @@ def make_acc_into_op(self, acc, operand_layout_tv, element_type): @functools.lru_cache(maxsize=64) def _compile_bwd_dhu_sm90( - B: int, - T: int, - N: int, - NT: int, H: int, K: int, V: int, @@ -1130,11 +1116,12 @@ def _compile_bwd_dhu_sm90( transpose_state_layout: bool, scale: float, ): + """Compile one bwd_dhu kernel variant. + + B, T, N, and NT are symbolic during compilation and are passed as runtime + problem_size values, matching the forward kernel's dynamic-shape pattern. 
+ """ kernel = ChunkDeltaRuleBwdDHUSm90( - batch_size=B, - seq_len=T, - num_sequences=N, - total_chunks=NT, num_heads=H, head_dim_k=K, head_dim_v=V, @@ -1149,94 +1136,100 @@ def _compile_bwd_dhu_sm90( use_fast_math=USE_FAST_MATH, ) + sym_b = cute.sym_int() + sym_t = cute.sym_int() + sym_n = cute.sym_int() + sym_nt = cute.sym_int() + sym_meta = cute.sym_int() + q_fake = make_fake_compact_tensor( cutlass.BFloat16, - (B, T, H, K), + (sym_b, sym_t, H, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) k_fake = make_fake_compact_tensor( cutlass.BFloat16, - (B, T, H, K), + (sym_b, sym_t, H, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) w_fake = make_fake_compact_tensor( cutlass.BFloat16, - (B, T, H, K), + (sym_b, sym_t, H, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) do_fake = make_fake_compact_tensor( cutlass.BFloat16, - (B, T, H, V), + (sym_b, sym_t, H, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) dv_fake = make_fake_compact_tensor( cutlass.BFloat16, - (B, T, H, V), + (sym_b, sym_t, H, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) dv2_fake = make_fake_compact_tensor( cutlass.BFloat16, - (B, T, H, V), + (sym_b, sym_t, H, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) g_fake = make_fake_compact_tensor( cutlass.Float32, - (B, T, H), + (sym_b, sym_t, H), stride_order=(2, 1, 0), assumed_align=128, ) gk_fake = make_fake_compact_tensor( cutlass.Float32, - (B, T, H, K), + (sym_b, sym_t, H, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) if transpose_state_layout: dht_fake = make_fake_compact_tensor( cutlass.Float32, - (N, H, V, K), + (sym_n, H, V, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) dh0_fake = make_fake_compact_tensor( cutlass.Float32, - (N, H, V, K), + (sym_n, H, V, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) dh_fake = make_fake_compact_tensor( cutlass.BFloat16, - (B, NT, H, V, K), + (sym_b, sym_nt, H, V, K), stride_order=(4, 3, 2, 1, 0), assumed_align=128, ) else: dht_fake = make_fake_compact_tensor( 
cutlass.Float32, - (N, H, K, V), + (sym_n, H, K, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) dh0_fake = make_fake_compact_tensor( cutlass.Float32, - (N, H, K, V), + (sym_n, H, K, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) dh_fake = make_fake_compact_tensor( cutlass.BFloat16, - (B, NT, H, K, V), + (sym_b, sym_nt, H, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128, ) - cu_fake = make_fake_compact_tensor(cutlass.Int32, (N + 1,), assumed_align=128) - offsets_fake = make_fake_compact_tensor(cutlass.Int32, (N + 1,), assumed_align=128) + cu_fake = make_fake_compact_tensor(cutlass.Int32, (sym_meta,), assumed_align=128) + offsets_fake = make_fake_compact_tensor(cutlass.Int32, (sym_meta,), assumed_align=128) stream_fake = make_fake_stream(use_tvm_ffi_env_stream=True) return cute.compile( @@ -1254,6 +1247,7 @@ def _compile_bwd_dhu_sm90( dv2_fake, cu_fake, offsets_fake, + (Int32(1), Int32(1), Int32(1), Int32(1)), stream_fake, options="--enable-tvm-ffi", ) @@ -1348,10 +1342,6 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( raise ValueError(f"h0 must have shape {state_shape} for this state layout, got {tuple(h0.shape)}.") compiled = _compile_bwd_dhu_sm90( - B, - T, - N, - NT, H, K, V, @@ -1364,7 +1354,8 @@ def chunk_gated_delta_rule_bwd_dhu_sm90( transpose_state_layout, scale_value, ) - compiled(q, k, w, g_arg, gk_arg, dht_arg, dh0_arg, do, dh, dv, dv2, cu_seqlens_arg, chunk_offsets) + problem_size = (Int32(B), Int32(T), Int32(N), Int32(NT)) + compiled(q, k, w, g_arg, gk_arg, dht_arg, dh0_arg, do, dh, dv, dv2, cu_seqlens_arg, chunk_offsets, problem_size) return dh, dh0, dv2 From a1aae3dd7ea2f621c88934b1039968d329f86670 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Wed, 20 May 2026 08:41:05 +0800 Subject: [PATCH 25/28] optimize gk decay way --- cula/ops/chunk_delta_h_bwd.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index dd667d4..ed420ec 
100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -869,9 +869,6 @@ def kernel( k_decay = cute.exp(gk_last, fastmath=self.use_fast_math) sGK[local_tidx, 0, gk_wait_early.index] = k_decay self.gk_precompute_bar.arrive_and_wait() - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait_early.index] cute.nvgpu.warpgroup.wait_group(1) else: cute.nvgpu.warpgroup.wait_group(0) @@ -922,14 +919,18 @@ def kernel( cute.nvgpu.warpgroup.commit_group() cute.nvgpu.warpgroup.wait_group(0) q_wait_early.release() - if cutlass.const_expr(self.use_gk): - gk_wait_early.release() for ei in cutlass.range(cute.size(rState), unroll_full=True): update = acc_qdo[ei] * Float32(self.scale) - acc_wdv[ei] - rState[ei] = rState[ei] + update + if cutlass.const_expr(self.use_gk): + v_rel, k_rel = tUcState[ei] + rState[ei] = rState[ei] * sGK[k_rel, 0, gk_wait_early.index] + update + else: + rState[ei] = rState[ei] + update w_wait.release() do_wait_early.release() + if cutlass.const_expr(self.use_gk): + gk_wait_early.release() else: do_wait = load_do_C.wait_and_advance() if cutlass.const_expr(self.use_g): From 9f4553baca7c6c5c56a988f6803c93009e745cc6 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Wed, 20 May 2026 09:01:21 +0800 Subject: [PATCH 26/28] dh tma load --- cula/ops/chunk_delta_h_bwd.py | 83 +++++++++++++++++++++++++++++------ 1 file changed, 69 insertions(+), 14 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index ed420ec..36d0915 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -109,7 +109,7 @@ def __init__( self.w_stage = 3 self.gk_stage = 3 self.dh_store_stage = 2 - self.dv2_store_stage = 2 + self.dv2_store_stage = 1 self.io_dtype = cutlass.BFloat16 self.acc_dtype = cutlass.Float32 self.buffer_align_bytes = 128 @@ -198,6 +198,17 @@ def __call__( ) dht = cute.make_tensor(dht_ptr, 
final_layout) dh0 = cute.make_tensor(dh0_ptr, final_layout) + if cutlass.const_expr(self.transpose_state_layout): + dht_tma_layout = cute.make_layout( + (self.V, self.K, (self.H, N)), + stride=(self.K, 1, (self.K * self.V, self.H * self.K * self.V)), + ) + else: + dht_tma_layout = cute.make_layout( + (self.V, self.K, (self.H, N)), + stride=(1, self.V, (self.K * self.V, self.H * self.K * self.V)), + ) + dht_tma = cute.make_tensor(dht_ptr, dht_tma_layout) # TMA operand views. Varlen shifts the T dimension with domain_offset below. tk_layout = cute.make_layout((T, self.K, (self.H, B)), stride=(self.H * self.K, 1, (self.K, T * self.H * self.K))) @@ -291,6 +302,11 @@ def __call__( (self.BK, 1, self.gk_stage), stride=(1, self.BK, self.BK), ) + dht_smem_layout = ( + cute.make_layout((self.BV, self.BK), stride=(self.BK, 1)) + if cutlass.const_expr(self.transpose_state_layout) + else cute.make_layout((self.BV, self.BK), stride=(1, self.BV)) + ) dh_smem_layout_staged = sm90_utils.make_smem_layout_epi( self.io_dtype, dh_smem_layout_enum, @@ -337,6 +353,12 @@ def __call__( cute.slice_(gk_smem_layout_staged, (None, None, 0)), (self.BK, 1), ) + tma_atom_dht, tma_tensor_dht = cpasync.make_tiled_tma_atom( + tma_load_op, + dht_tma, + dht_smem_layout, + (self.BV, self.BK), + ) tma_atom_dh, tma_tensor_dh = cpasync.make_tiled_tma_atom( tma_store_op, dh_tma, @@ -355,6 +377,7 @@ def __call__( self.tma_q_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(q_smem_layout_staged, (None, None, 0))) self.tma_w_bytes = cute.size_in_bytes(self.io_dtype, cute.slice_(w_smem_layout_staged, (None, None, 0))) self.tma_gk_bytes = cute.size_in_bytes(cutlass.Float32, cute.slice_(gk_smem_layout_staged, (None, None, 0))) + self.tma_dht_bytes = cute.size_in_bytes(cutlass.Float32, dht_smem_layout) # ===================== SharedStorage ===================== @cute.struct @@ -365,6 +388,7 @@ class SharedStorage: load_q_mbar: cute.struct.MemRange[Int64, self.q_stage * 2] load_w_mbar: 
cute.struct.MemRange[Int64, self.w_stage * 2] load_gk_mbar: cute.struct.MemRange[Int64, self.gk_stage * 2] + load_dht_mbar: cute.struct.MemRange[Int64, 2] store_dh_mbar: cute.struct.MemRange[Int64, self.dh_store_stage * 2] store_dv2_mbar: cute.struct.MemRange[Int64, self.dv2_store_stage * 2] sK: cute.struct.Align[ @@ -435,6 +459,9 @@ class SharedStorage: tma_tensor_w, tma_atom_gk, tma_tensor_gk, + tma_atom_dht, + tma_tensor_dht, + dht_smem_layout, tma_atom_dh, tma_tensor_dh, dh_smem_layout_staged, @@ -479,6 +506,9 @@ def kernel( tma_tensor_w: cute.Tensor, tma_atom_gk: cute.CopyAtom, tma_tensor_gk: cute.Tensor, + tma_atom_dht: cute.CopyAtom, + tma_tensor_dht: cute.Tensor, + dht_smem_layout: cute.Layout, tma_atom_dh: cute.CopyAtom, tma_tensor_dh: cute.Tensor, dh_smem_layout_staged: cute.ComposedLayout, @@ -519,6 +549,7 @@ def kernel( sW = storage.sW.get_tensor(w_smem_layout_staged.outer, swizzle=w_smem_layout_staged.inner) sDv2 = storage.sDv2.get_tensor(dv2_smem_layout_staged.outer, swizzle=dv2_smem_layout_staged.inner) sDh = storage.sDh.get_tensor(dh_smem_layout_staged.outer, swizzle=dh_smem_layout_staged.inner) + sDht = cute.make_tensor(cute.recast_ptr(storage.sDh.data_ptr(), dtype=cutlass.Float32), dht_smem_layout) # ===================== Pipelines ===================== load_k_P, load_k_C = pipeline.PipelineTmaAsync.create( @@ -564,6 +595,14 @@ def kernel( tx_count=self.tma_gk_bytes, barrier_storage=storage.load_gk_mbar.data_ptr(), ).make_participants() + if cutlass.const_expr(self.use_dht): + load_dht_P, load_dht_C = pipeline.PipelineTmaAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.num_compute_warps), + tx_count=self.tma_dht_bytes, + barrier_storage=storage.load_dht_mbar.data_ptr(), + ).make_participants() store_dh_P, store_dh_C = pipeline.PipelineAsync.create( num_stages=self.dh_store_stage, producer_group=make_thread_cooperative_group(self.num_compute_threads), @@ -590,6 
+629,7 @@ def kernel( tma_tensor_dv2_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_dv2) if cutlass.const_expr(self.use_gk): tma_tensor_gk_use = cute.domain_offset((0, tok_offset, (0, 0)), tma_tensor_gk) + tma_tensor_dht_use = tma_tensor_dht else: tma_tensor_k_use = tma_tensor_k tma_tensor_dv_use = tma_tensor_dv @@ -600,6 +640,7 @@ def kernel( tma_tensor_dv2_use = tma_tensor_dv2 if cutlass.const_expr(self.use_gk): tma_tensor_gk_use = tma_tensor_gk + tma_tensor_dht_use = tma_tensor_dht _, bSG_sK, bSG_gK = self._epilog_partition( tma_atom_k, tma_tensor_k_use[None, None, (hidx, data_bidx)], (self.BT, self.BK), sK @@ -620,6 +661,10 @@ def kernel( _, bSG_sGK, bSG_gGK = self._epilog_partition( tma_atom_gk, tma_tensor_gk_use[None, None, (hidx, data_bidx)], (self.BK, 1), sGK ) + if cutlass.const_expr(self.use_dht): + _, bSG_sDht, bSG_gDht = self._epilog_partition( + tma_atom_dht, tma_tensor_dht_use[None, None, (hidx, bidx)], (self.BV, self.BK), sDht + ) _, bSG_sDh, bSG_gDh = self._epilog_partition( tma_atom_dh, tma_tensor_dh_use[None, None, (None, hidx, data_bidx)], (self.BV, self.BK), sDh ) @@ -688,19 +733,6 @@ def kernel( rDh_shape = cute.shape(thr_copy_dh_r2s.partition_S(sDh)) tRS_rDh_layout = cute.make_layout(rDh_shape[:3]) - # Initialize carried dh state in registers. 
- if is_compute_warp: - for ei in cutlass.range(cute.size(rState), unroll_full=True): - v_rel, k_rel = tUcState[ei] - v_idx = v_tile_base + v_rel - init = Float32(0.0) - if cutlass.const_expr(self.use_dht): - if cutlass.const_expr(self.transpose_state_layout): - init = dht[bidx, hidx, v_idx, k_rel].to(self.acc_dtype) - else: - init = dht[bidx, hidx, k_rel, v_idx].to(self.acc_dtype) - rState[ei] = init - # ========================================================================= # WARP SPECIALIZATION # load_warp_id : preloads K, dv, and optional gk for the next reverse chunk @@ -716,6 +748,15 @@ def kernel( cpasync.prefetch_descriptor(tma_atom_dv) if cutlass.const_expr(self.use_gk): cpasync.prefetch_descriptor(tma_atom_gk) + if cutlass.const_expr(self.use_dht): + cpasync.prefetch_descriptor(tma_atom_dht) + dht_h = load_dht_P.acquire_and_advance() + cute.copy( + tma_atom_dht, + bSG_gDht[(None, v_tile_idx, 0)], + bSG_sDht[None], + tma_bar_ptr=dht_h.barrier, + ) if NT > 0: first_chunk = NT - 1 @@ -777,6 +818,20 @@ def kernel( cute.copy(tma_atom_w, bSG_gW[(None, 0, chunk_idx)], bSG_sW[None, w_h.index], tma_bar_ptr=w_h.barrier) elif is_compute_warp: + # Initialize carried dh state in registers. dht is loaded by the + # load warp into the sDh backing buffer before sDh is used for + # per-chunk output stores. 
+ if cutlass.const_expr(self.use_dht): + dht_h = load_dht_C.wait_and_advance() + for ei in cutlass.range(cute.size(rState), unroll_full=True): + v_rel, k_rel = tUcState[ei] + init = Float32(0.0) + if cutlass.const_expr(self.use_dht): + init = sDht[v_rel, k_rel].to(self.acc_dtype) + rState[ei] = init + if cutlass.const_expr(self.use_dht): + dht_h.release() + for chunk_rev in cutlass.range(0, NT, unroll=0): chunk_idx = NT - 1 - chunk_rev chunk_start = chunk_idx * self.BT From ec18ae7f10d613f87c06bbbe19bc5e5e77732b0e Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Wed, 20 May 2026 09:39:46 +0800 Subject: [PATCH 27/28] dh0 tma store --- cula/ops/chunk_delta_h_bwd.py | 86 +++++++++++++++++++++++++++-------- 1 file changed, 67 insertions(+), 19 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index 36d0915..c037443 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -209,6 +209,7 @@ def __call__( stride=(1, self.V, (self.K * self.V, self.H * self.K * self.V)), ) dht_tma = cute.make_tensor(dht_ptr, dht_tma_layout) + dh0_tma = cute.make_tensor(dh0_ptr, dht_tma_layout) # TMA operand views. Varlen shifts the T dimension with domain_offset below. 
tk_layout = cute.make_layout((T, self.K, (self.H, B)), stride=(self.H * self.K, 1, (self.K, T * self.H * self.K))) @@ -359,6 +360,12 @@ def __call__( dht_smem_layout, (self.BV, self.BK), ) + tma_atom_dh0, tma_tensor_dh0 = cpasync.make_tiled_tma_atom( + tma_store_op, + dh0_tma, + dht_smem_layout, + (self.BV, self.BK), + ) tma_atom_dh, tma_tensor_dh = cpasync.make_tiled_tma_atom( tma_store_op, dh_tma, @@ -390,6 +397,7 @@ class SharedStorage: load_gk_mbar: cute.struct.MemRange[Int64, self.gk_stage * 2] load_dht_mbar: cute.struct.MemRange[Int64, 2] store_dh_mbar: cute.struct.MemRange[Int64, self.dh_store_stage * 2] + store_dh0_mbar: cute.struct.MemRange[Int64, 2] store_dv2_mbar: cute.struct.MemRange[Int64, self.dv2_store_stage * 2] sK: cute.struct.Align[ cute.struct.MemRange[self.io_dtype, cute.cosize(k_smem_layout_staged)], @@ -461,6 +469,8 @@ class SharedStorage: tma_tensor_gk, tma_atom_dht, tma_tensor_dht, + tma_atom_dh0, + tma_tensor_dh0, dht_smem_layout, tma_atom_dh, tma_tensor_dh, @@ -508,6 +518,8 @@ def kernel( tma_tensor_gk: cute.Tensor, tma_atom_dht: cute.CopyAtom, tma_tensor_dht: cute.Tensor, + tma_atom_dh0: cute.CopyAtom, + tma_tensor_dh0: cute.Tensor, dht_smem_layout: cute.Layout, tma_atom_dh: cute.CopyAtom, tma_tensor_dh: cute.Tensor, @@ -550,6 +562,7 @@ def kernel( sDv2 = storage.sDv2.get_tensor(dv2_smem_layout_staged.outer, swizzle=dv2_smem_layout_staged.inner) sDh = storage.sDh.get_tensor(dh_smem_layout_staged.outer, swizzle=dh_smem_layout_staged.inner) sDht = cute.make_tensor(cute.recast_ptr(storage.sDh.data_ptr(), dtype=cutlass.Float32), dht_smem_layout) + sDh0 = cute.make_tensor(cute.recast_ptr(storage.sK.data_ptr(), dtype=cutlass.Float32), dht_smem_layout) # ===================== Pipelines ===================== load_k_P, load_k_C = pipeline.PipelineTmaAsync.create( @@ -603,6 +616,13 @@ def kernel( tx_count=self.tma_dht_bytes, barrier_storage=storage.load_dht_mbar.data_ptr(), ).make_participants() + if cutlass.const_expr(self.use_dh0): + store_dh0_P, 
store_dh0_C = pipeline.PipelineAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group(self.num_compute_threads), + consumer_group=make_thread_cooperative_group(self.threads_per_warp), + barrier_storage=storage.store_dh0_mbar.data_ptr(), + ).make_participants() store_dh_P, store_dh_C = pipeline.PipelineAsync.create( num_stages=self.dh_store_stage, producer_group=make_thread_cooperative_group(self.num_compute_threads), @@ -665,6 +685,10 @@ def kernel( _, bSG_sDht, bSG_gDht = self._epilog_partition( tma_atom_dht, tma_tensor_dht_use[None, None, (hidx, bidx)], (self.BV, self.BK), sDht ) + if cutlass.const_expr(self.use_dh0): + _, bSG_sDh0, bSG_gDh0 = self._epilog_partition( + tma_atom_dh0, tma_tensor_dh0[None, None, (hidx, bidx)], (self.BV, self.BK), sDh0 + ) _, bSG_sDh, bSG_gDh = self._epilog_partition( tma_atom_dh, tma_tensor_dh_use[None, None, (None, hidx, data_bidx)], (self.BV, self.BK), sDh ) @@ -897,13 +921,28 @@ def kernel( g_exp = cute.exp(g_cur, fastmath=self.use_fast_math) sG[local_tidx, 0] = g_decay sG[local_tidx, 1] = g_exp - if cutlass.const_expr((not self.use_g) and (not self.is_varlen)): + if cutlass.const_expr(not self.use_g): # Phase 3 is independent of K@dh, so overlap QDO and optional gk decay - # with the first GEMM in the no-scalar-g non-varlen fast path. + # with the first GEMM in the no-scalar-g fast path. For + # varlen tails, zero padded do positions before QDO so TMA + # overfetch into the next sequence cannot contribute. 
do_wait_early = load_do_C.wait_and_advance() q_wait_early = load_q_C.wait_and_advance() if cutlass.const_expr(self.use_gk): gk_wait_early = load_gk_C.wait_and_advance() + if cutlass.const_expr(self.is_varlen): + if remaining < self.BT: + linear_do = local_tidx + while linear_do < self.BV * self.BT: + v_rel = linear_do // self.BT + t_rel = linear_do - v_rel * self.BT + t_idx = chunk_start + t_rel + do_scaled = Float32(0.0) + if t_idx < seq_len: + do_scaled = sDo[v_rel, t_rel, do_wait_early.index].to(self.acc_dtype) + sDo[v_rel, t_rel, do_wait_early.index] = do_scaled.to(self.io_dtype) + linear_do += self.num_compute_threads + cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) acc_qdo.fill(0.0) cute.nvgpu.warpgroup.fence() for kp in cutlass.range(cute.size(tUrDo, mode=[2]), unroll_full=True): @@ -958,7 +997,7 @@ def kernel( # ======================================== # Phase 3/4: dh += scale * do^T @ q - dv2^T @ w # ======================================== - if cutlass.const_expr((not self.use_g) and (not self.is_varlen)): + if cutlass.const_expr(not self.use_g): w_wait = load_w_C.wait_and_advance() acc_wdv.fill(0.0) cute.nvgpu.warpgroup.fence() @@ -1001,17 +1040,18 @@ def kernel( do_wait.release() if cutlass.const_expr((not self.use_g) and self.is_varlen): # Phase 3a: zero padded do positions in SMEM for varlen tails. 
- linear_do = local_tidx - while linear_do < self.BV * self.BT: - v_rel = linear_do // self.BT - t_rel = linear_do - v_rel * self.BT - t_idx = chunk_start + t_rel - do_scaled = Float32(0.0) - if t_idx < seq_len: - do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) - sDo[v_rel, t_rel, do_wait.index] = do_scaled.to(self.io_dtype) - linear_do += self.num_compute_threads - cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) + if remaining < self.BT: + linear_do = local_tidx + while linear_do < self.BV * self.BT: + v_rel = linear_do // self.BT + t_rel = linear_do - v_rel * self.BT + t_idx = chunk_start + t_rel + do_scaled = Float32(0.0) + if t_idx < seq_len: + do_scaled = sDo[v_rel, t_rel, do_wait.index].to(self.acc_dtype) + sDo[v_rel, t_rel, do_wait.index] = do_scaled.to(self.io_dtype) + linear_do += self.num_compute_threads + cute.arch.barrier(barrier_id=2, number_of_threads=self.num_compute_threads) # Phase 3b: QDO plus scalar/key decay while QDO is in flight. 
q_wait = load_q_C.wait_and_advance() @@ -1084,17 +1124,18 @@ def kernel( do_wait.release() if cutlass.const_expr(self.use_dh0): + dh0_h = store_dh0_P.acquire_and_advance() for ei in cutlass.range(cute.size(rState), unroll_full=True): v_rel, k_rel = tUcState[ei] - v_idx = v_tile_base + v_rel - if cutlass.const_expr(self.transpose_state_layout): - dh0[bidx, hidx, v_idx, k_rel] = rState[ei] - else: - dh0[bidx, hidx, k_rel, v_idx] = rState[ei] + sDh0[v_rel, k_rel] = rState[ei] + cute.arch.fence_proxy("async.shared", space="cta") + dh0_h.commit() elif warp_idx == self.store_warp_id: cpasync.prefetch_descriptor(tma_atom_dh) cpasync.prefetch_descriptor(tma_atom_dv2) + if cutlass.const_expr(self.use_dh0): + cpasync.prefetch_descriptor(tma_atom_dh0) for chunk_rev in cutlass.range(0, NT, unroll=0): chunk_idx = NT - 1 - chunk_rev @@ -1122,6 +1163,13 @@ def kernel( cute.arch.cp_async_bulk_wait_group(0, read=True) dv2_store_h.release() + if cutlass.const_expr(self.use_dh0): + dh0_h = store_dh0_C.wait_and_advance() + cute.copy(tma_atom_dh0, bSG_sDh0[None], bSG_gDh0[(None, v_tile_idx, 0)]) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + dh0_h.release() + @cute.jit def _epilog_partition(self, atom, gC_mnl, epi_tile, sC): gC_epi = cute.flat_divide(gC_mnl, epi_tile) From 793d3c76b98a4247292af23eb29aba3b02a1db23 Mon Sep 17 00:00:00 2001 From: yechenzhi <136920488@qq.com> Date: Thu, 21 May 2026 12:15:34 +0800 Subject: [PATCH 28/28] dv2 store stage --- cula/ops/chunk_delta_h_bwd.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cula/ops/chunk_delta_h_bwd.py b/cula/ops/chunk_delta_h_bwd.py index c037443..a6cef26 100644 --- a/cula/ops/chunk_delta_h_bwd.py +++ b/cula/ops/chunk_delta_h_bwd.py @@ -109,7 +109,7 @@ def __init__( self.w_stage = 3 self.gk_stage = 3 self.dh_store_stage = 2 - self.dv2_store_stage = 1 + self.dv2_store_stage = 2 self.io_dtype = cutlass.BFloat16 self.acc_dtype = cutlass.Float32 
self.buffer_align_bytes = 128 @@ -1146,8 +1146,6 @@ def kernel( dh_h = store_dh_C.wait_and_advance() cute.copy(tma_atom_dh, bSG_sDh[None, dh_h.index], bSG_gDh[(None, v_tile_idx, 0, chunk_idx)]) cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) - dh_h.release() dv2_store_h = store_dv2_C.wait_and_advance() # Tail chunks skip TMA because the tile would cross sequence @@ -1160,7 +1158,8 @@ def kernel( bSG_gDv2[(None, v_tile_idx, chunk_idx)], ) cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.cp_async_bulk_wait_group(0, read=True) + dh_h.release() dv2_store_h.release() if cutlass.const_expr(self.use_dh0):