Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/bench_la_decode_vs_fla.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def kernel_fla():
# cute kernel: pre-create compiled + stream handle
cute_state_k = state_init.clone().permute(0, 1, 3, 2).reshape(B * H, V, K).contiguous()
out_cute_k = torch.empty(B, H, V, device=device, dtype=dtype)
cache = _get_compiled_kernel(B, 1, H, K, V, scale, USE_FAST_MATH)
cache = _get_compiled_kernel(B, 1, H, K, V, cute_state_k.shape[0], scale, USE_FAST_MATH)
compiled_cute = cache["compiled"]
stream_handle = cuda_drv.CUstream(torch.cuda.current_stream().cuda_stream)

Expand Down
29 changes: 18 additions & 11 deletions cula/ops/la_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ def la_decode_kernel_small_batch_pretranspose(
cute.arch.barrier()

# Get current batch
gSrc_batch = h0_source[(batch_idx, None, None)] # (V, K)
gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (batch_idx, None, 0))
pool_idx = h0_indices[i_n] * HV + i_hv
gSrc_batch = h0_source[(pool_idx, None, None)] # (V, K)
gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (pool_idx, None, 0))

# split tiles in V-dimension
gSrc = cute.local_tile(gSrc_batch, (TILE_V, TILE_K), (None, 0)) # (TILE_V, TILE_K, num_v_tiles)
Expand Down Expand Up @@ -289,8 +290,9 @@ def la_decode_kernel_big_batch_pretranspose(
cute.arch.barrier()

# Get current batch
gSrc_batch = h0_source[(batch_idx, None, None)] # (V, K)
gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (batch_idx, None, 0))
pool_idx = h0_indices[i_n] * HV + i_hv
gSrc_batch = h0_source[(pool_idx, None, None)] # (V, K)
gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (pool_idx, None, 0))

# split tiles in V-dimension
gSrc = cute.local_tile(gSrc_batch, (TILE_V, TILE_K), (None, 0)) # (TILE_V, TILE_K, num_v_tiles)
Expand Down Expand Up @@ -418,7 +420,7 @@ def run_la_decode_kernel_big_batch_pretranspose(
stream: cuda.CUstream,
):
# h0_source: (pool_size*HV, V, K) — state pool; dim 0 is addressed via h0_indices, not by batch
batch_size, v_dim, _k_dim = (
_pool_dim0, v_dim, _k_dim = (
h0_source.layout.shape[0],
h0_source.layout.shape[1],
h0_source.layout.shape[2],
Expand Down Expand Up @@ -477,7 +479,7 @@ def run_la_decode_kernel_big_batch_pretranspose(
TILE_V_BIG,
NUM_STAGES_BIG,
).launch(
grid=(batch_size, 1, 1),
grid=(B * H, 1, 1),
block=[NUM_THREADS_BIG, 1, 1],
smem=smem_bytes,
stream=stream,
Expand All @@ -502,7 +504,7 @@ def run_la_decode_kernel_small_batch_pretranspose(
stream: cuda.CUstream,
):
# h0_source: (pool_size*H, V, K) — state pool; dim 0 is addressed via h0_indices, not by batch
batch_size, v_dim, _k_dim = (
_pool_dim0, v_dim, _k_dim = (
h0_source.layout.shape[0],
h0_source.layout.shape[1],
h0_source.layout.shape[2],
Expand Down Expand Up @@ -561,15 +563,17 @@ def run_la_decode_kernel_small_batch_pretranspose(
TILE_V_SMALL,
NUM_STAGES_SMALL,
).launch(
grid=(batch_size * NUM_BLOCKS_PER_STATE, 1, 1),
grid=(B * H * NUM_BLOCKS_PER_STATE, 1, 1),
block=[NUM_THREADS_SMALL, 1, 1],
smem=smem_bytes,
stream=stream,
)


@functools.cache
def _get_compiled_kernel(B: int, T: int, H: int, K: int, V: int, softmax_scale: float, use_fast_math: bool = True):
def _get_compiled_kernel(
B: int, T: int, H: int, K: int, V: int, pool_dim0: int, softmax_scale: float, use_fast_math: bool = True
):
"""Get or create compiled kernel cache."""
return {}

Expand Down Expand Up @@ -625,10 +629,14 @@ def linear_attention_decode(
raise NotImplementedError(f"CuTe kernel doesn't support K splitting (k_dim_block={k_dim_block})")

# Get compiled kernel (cached)
cache_key = (B, 1, H, HEAD_DIM, HEAD_DIM, softmax_scale, USE_FAST_MATH)
pool_dim0 = s.shape[0]
cache_key = (B, 1, H, HEAD_DIM, HEAD_DIM, pool_dim0, softmax_scale, USE_FAST_MATH)
cache = _get_compiled_kernel(*cache_key)

h0_source = s

# Validate state pool dimensions
assert s.shape[0] % H == 0, f"s.shape[0] must be divisible by H={H}, got {s.shape[0]}"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The implementation assumes that s is a 3D tensor with shape [pool_size * H, V, K] and out is a 3D tensor with shape [B, H, V]. However, the docstrings (lines 585-586 and 607-608) incorrectly describe them as 4D tensors. To prevent runtime indexing errors or incorrect results when users follow the docstrings, it is recommended to explicitly validate the dimensionality of these tensors here.

    # Validate state pool dimensions
    assert s.ndim == 3, f"s must be a 3D tensor [pool_size * H, V, K], got {s.ndim}D"
    assert out.ndim == 3, f"out must be a 3D tensor [B, H, V], got {out.ndim}D"
    assert s.shape[0] % H == 0, f"s.shape[0] must be divisible by H={H}, got {s.shape[0]}"

# First-time compilation
if "compiled" not in cache:
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
Expand All @@ -644,7 +652,6 @@ def linear_attention_decode(
v_view = v
o_view = out

# Use s_offsets directly (pass to kernel but not actually used in current implementation)
h0_indices = s_offsets

# Convert to CuTe format for compilation
Expand Down
263 changes: 263 additions & 0 deletions tests/test_la_decode_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
#!/usr/bin/env python3
# Copyright 2025-2026 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test for Lightning Attention decode state-pool indirect indexing.

Exposes the bug where la_decode ignores s_offsets and indexes state
by flattened batch_idx directly. With identity offsets the bug is invisible.
With non-identity offsets, the kernel reads/writes wrong state slots.
"""

import pathlib
import sys

import pytest
import torch

sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent))

from cula.ops.la_decode import linear_attention_decode


def torch_la_decode_ref(q, k, v, state, decay_scales, scale):
    """Pure PyTorch reference for a single linear-attention decode step.

    The state is decayed per head, the outer product of ``k`` and ``v`` is
    added to it, and the scaled query is contracted against the result.

    Args:
        q: 3-D query tensor; its second dimension is taken as the head
            count ``H``. Cast to float32 and multiplied by ``scale``.
        k: Key tensor, cast to float32; supplies the ``K`` axis of the
            outer product.
        v: Value tensor, cast to float32; supplies the ``V`` axis of the
            outer product.
        state: Recurrent state in [B, H, K, V] (BHKV) layout. Not modified.
        decay_scales: Per-head decay rates; must contain exactly ``H``
            elements because it is viewed as ``(1, H, 1, 1)``.
        scale: Scalar multiplier applied to ``q``.

    Returns:
        Tuple ``(o, state_new)``: ``o`` is the output cast to bfloat16 and
        ``state_new`` is the updated state (left in its computed dtype).

    Raises:
        ValueError: If ``q`` is not 3-D.
    """
    # Only the head count is needed; the unpack still enforces a 3-D ``q``.
    _, num_heads, _ = q.shape

    q_f = q.float() * scale
    k_f = k.float()
    v_f = v.float()

    # exp(-rate) per head, shaped to broadcast over the batch, K and V axes.
    decay = torch.exp(-decay_scales).view(1, num_heads, 1, 1)

    # Rank-1 update: k along the K axis times v along the V axis.
    state_new = state * decay + k_f.unsqueeze(-1) * v_f.unsqueeze(-2)

    o = torch.einsum("bhk,bhkv->bhv", q_f, state_new)
    return o.to(torch.bfloat16), state_new


def run_la_decode_with_pool(q, k, v, state_pool_4d, s_offsets, decay_scales, scale):
    """
    Invoke ``linear_attention_decode`` against a state pool with arbitrary offsets.

    state_pool_4d: [pool_size, H, K, V] — the whole pool in BHKV layout; it is
        cloned first, so the caller's tensor is left untouched.
    s_offsets: [B] — pool slot used by each batch element.

    Returns ``(out, state_out)``: ``out`` is the [B, H, D] bfloat16 buffer
    handed to the kernel, and ``state_out`` is the kernel-side state buffer
    converted back to [pool_size, H, K, V].
    """
    B, H, D = q.shape
    pool_size = state_pool_4d.shape[0]

    # The kernel takes BHVK layout flattened to [pool_size*H, V, K]:
    # swap the trailing two axes on a private copy, then merge pool and head.
    pool_vk = state_pool_4d.clone().transpose(-1, -2).contiguous()
    state_cute = pool_vk.reshape(pool_size * H, D, D)

    out = torch.zeros(B, H, D, device=q.device, dtype=torch.bfloat16)

    kernel_kwargs = dict(
        softmax_scale=scale,
        stride_q=0,
        stride_k=0,
        stride_v=0,
        stride_s=0,
        stride_o=0,
        s_offsets=s_offsets,
        decay_scales=decay_scales,
        HEAD_DIM=D,
        K_SPLIT_DIM=D,
        V_SPLIT_DIM=D,
    )
    linear_attention_decode(q, k, v, state_cute, out, **kernel_kwargs)

    # Undo the layout change so callers can compare against BHKV references.
    pool_back = state_cute.reshape(pool_size, H, D, D)
    state_out = pool_back.transpose(-1, -2).contiguous()
    return out, state_out


# ---------------------------------------------------------------------------
# Test 1: Identity offsets (baseline — should always pass)
# ---------------------------------------------------------------------------
def test_identity_offsets():
    """Baseline: batch element i uses pool slot i (s_offsets = arange(B))."""
    B, H, D = 4, 8, 128
    scale = D**-0.5
    head_ids = torch.arange(H, device="cuda", dtype=torch.float32)
    decay_scales = 0.3 * head_ids / H

    # Keep the draw order q, k, v, state so the seeded values are unchanged.
    torch.manual_seed(42)
    q = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    state_4d = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32) * 0.1

    s_offsets = torch.arange(B, device="cuda", dtype=torch.int32)
    out, _ = run_la_decode_with_pool(q, k, v, state_4d, s_offsets, decay_scales, scale)
    o_ref, _ = torch_la_decode_ref(q, k, v, state_4d, decay_scales, scale)

    # RMSE normalised by the largest reference magnitude.
    got, want = out.float(), o_ref.float()
    rmse = (got - want).pow(2).mean().sqrt().item()
    max_ref = want.abs().max().item()
    rel_err = rmse / (max_ref + 1e-8)

    assert rel_err < 0.01, f"Identity offsets: rel_err={rel_err:.6f}"


# ---------------------------------------------------------------------------
# Test 2: Non-identity offsets (exposes the bug)
# ---------------------------------------------------------------------------
def test_non_identity_offsets():
    """
    Pool of 6 slots, batch of 4, offsets [2, 0, 5, 1].

    Every batch element is mapped to a different, non-sequential pool slot,
    so a kernel that indexed the state by batch position would read slots
    [0, 1, 2, 3] and disagree with the reference.
    """
    B = 4
    POOL_SIZE = 6
    H, D = 8, 128
    scale = D**-0.5
    head_ids = torch.arange(H, device="cuda", dtype=torch.float32)
    decay_scales = 0.3 * head_ids / H

    torch.manual_seed(42)
    q = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)

    # Pool entries are N(0, 1) samples scaled by 0.1.
    state_pool = torch.randn(POOL_SIZE, H, D, D, device="cuda", dtype=torch.float32) * 0.1

    offsets = [2, 0, 5, 1]
    s_offsets = torch.tensor(offsets, device="cuda", dtype=torch.int32)

    out, _ = run_la_decode_with_pool(q, k, v, state_pool, s_offsets, decay_scales, scale)

    # Reference path: gather the intended slot for each batch element -> [B, H, D, D].
    state_selected = state_pool[s_offsets.long()]
    o_ref, _ = torch_la_decode_ref(q, k, v, state_selected, decay_scales, scale)

    got, want = out.float(), o_ref.float()
    rmse = (got - want).pow(2).mean().sqrt().item()
    max_ref = want.abs().max().item()
    rel_err = rmse / (max_ref + 1e-8)

    assert rel_err < 0.01, f"Non-identity offsets {offsets}: rel_err={rel_err:.6f}"


# ---------------------------------------------------------------------------
# Test 3: Reversed offsets (another non-identity pattern)
# ---------------------------------------------------------------------------
def test_reversed_offsets():
    """
    Pool size equals B and the offsets are reversed ([3, 2, 1, 0]):
    batch 0 is mapped to slot 3, batch 3 to slot 0.
    """
    B, H, D = 4, 8, 128
    scale = D**-0.5
    head_ids = torch.arange(H, device="cuda", dtype=torch.float32)
    decay_scales = 0.3 * head_ids / H

    torch.manual_seed(42)
    q = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    state_pool = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32) * 0.1

    offsets = list(reversed(range(B)))
    s_offsets = torch.tensor(offsets, device="cuda", dtype=torch.int32)

    out, _ = run_la_decode_with_pool(q, k, v, state_pool, s_offsets, decay_scales, scale)

    # Reference path: gather the slot each batch element is mapped to.
    state_selected = state_pool[s_offsets.long()]
    o_ref, _ = torch_la_decode_ref(q, k, v, state_selected, decay_scales, scale)

    got, want = out.float(), o_ref.float()
    rmse = (got - want).pow(2).mean().sqrt().item()
    max_ref = want.abs().max().item()
    rel_err = rmse / (max_ref + 1e-8)

    assert rel_err < 0.01, f"Reversed offsets {offsets}: rel_err={rel_err:.6f}"


# ---------------------------------------------------------------------------
# Test 4: State writeback with non-identity offsets
# ---------------------------------------------------------------------------
def test_state_writeback_non_identity():
    """
    Check where the updated state lands in the pool.

    After the decode call, pool slot ``offsets[i]`` must hold the updated
    state of batch element ``i``; every slot not listed in ``offsets`` must
    be identical to its pre-call contents.
    """
    B = 4
    POOL_SIZE = 6
    H, D = 8, 128
    scale = D**-0.5
    head_ids = torch.arange(H, device="cuda", dtype=torch.float32)
    decay_scales = 0.3 * head_ids / H

    torch.manual_seed(42)
    q = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    state_pool = torch.randn(POOL_SIZE, H, D, D, device="cuda", dtype=torch.float32) * 0.1
    # Snapshot of the pool taken before the decode call, used as the baseline.
    state_pool_orig = state_pool.clone()

    offsets = [2, 0, 5, 1]
    s_offsets = torch.tensor(offsets, device="cuda", dtype=torch.int32)

    _, state_out = run_la_decode_with_pool(q, k, v, state_pool, s_offsets, decay_scales, scale)

    # Expected post-step state for each batch element, from the snapshot.
    state_selected = state_pool_orig[s_offsets.long()]
    _, state_ref = torch_la_decode_ref(q, k, v, state_selected, decay_scales, scale)

    # Slots referenced by the offsets must match the reference update.
    for b_idx, pool_slot in enumerate(offsets):
        expected = state_ref[b_idx].float()
        actual = state_out[pool_slot].float()
        slot_rmse = (actual - expected).pow(2).mean().sqrt().item()
        slot_max = expected.abs().max().item()
        slot_rel = slot_rmse / (slot_max + 1e-8)
        assert slot_rel < 0.001, f"State writeback: pool slot {pool_slot} (batch {b_idx}) rel_err={slot_rel:.6f}"

    # Slots absent from the offsets (3 and 4 here) must be untouched.
    inactive = set(range(POOL_SIZE)).difference(offsets)
    for slot in inactive:
        diff = (state_out[slot] - state_pool_orig[slot]).abs().max().item()
        assert diff < 1e-8, f"Inactive pool slot {slot} was modified! max_diff={diff}"


# ---------------------------------------------------------------------------
# Test 5: Big batch (B > 32) with non-identity offsets
# ---------------------------------------------------------------------------
def test_big_batch_non_identity_offsets():
    """
    Larger batch (B=33) against a 40-slot pool with shifted offsets:
    batch element i is mapped to slot (i + 7) % 40.

    NOTE(review): the original test says B=33 selects the big-batch kernel
    path (B > 32); that dispatch threshold is not visible from this file —
    confirm against ``linear_attention_decode``.
    """
    B = 33
    POOL_SIZE = 40
    H, D = 8, 128
    scale = D**-0.5
    head_ids = torch.arange(H, device="cuda", dtype=torch.float32)
    decay_scales = 0.3 * head_ids / H

    torch.manual_seed(42)
    q = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)

    state_pool = torch.randn(POOL_SIZE, H, D, D, device="cuda", dtype=torch.float32) * 0.1

    offsets = [(i + 7) % POOL_SIZE for i in range(B)]
    s_offsets = torch.tensor(offsets, device="cuda", dtype=torch.int32)

    out, _ = run_la_decode_with_pool(q, k, v, state_pool, s_offsets, decay_scales, scale)

    # Reference path: gather the slot each batch element is mapped to.
    state_selected = state_pool[s_offsets.long()]
    o_ref, _ = torch_la_decode_ref(q, k, v, state_selected, decay_scales, scale)

    got, want = out.float(), o_ref.float()
    rmse = (got - want).pow(2).mean().sqrt().item()
    max_ref = want.abs().max().item()
    rel_err = rmse / (max_ref + 1e-8)

    assert rel_err < 0.01, f"Big batch non-identity offsets: rel_err={rel_err:.6f}"


if __name__ == "__main__":
    # pytest.main returns the run's exit code instead of exiting; pass it on
    # so running this file as a script reports failures via the process
    # exit status.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))