Description
Version: 1.1.0
Version: 13.1
Which installation method(s) does this occur on?
No response
Describe the bug.
The latest code changes how the type of `j` is inferred across the two branches of:
if i % 2 == 0:
    j = step
else:
    j = Tc - 1 - step
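Per the TileTypeError in the log below, `Tc` (the result of `ct.cdiv`) appears to be inferred as Tile[int32, ()] while the loop variable `step` stays a plain int32, so `j` ends up with a different type depending on which branch runs. A minimal sketch of a possible way to sidestep this, untested and only for illustration (the `parity` helper is not part of the original kernel), is to compute `j` without branching so its type is the same on every path:

# Hypothetical workaround sketch (untested, not the original kernel's code):
# fold the branch into arithmetic so `j` always involves `Tc` and therefore
# always gets the same inferred type.
parity = i % 2                            # 0 for even i, 1 for odd i
# parity == 0 -> j == step; parity == 1 -> j == Tc - 1 - step
j = step + parity * (Tc - 1 - 2 * step)

Whether this preserves the intended alternating K/V traversal is an assumption; the underlying question is still why the branched form regressed between the two commits.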
At commit 9af1b63, the following code passed the frontend:
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
import argparse
import cuda.tile as ct
try:
    import cuda.tile_experimental as ct_experimental
except ImportError:
    ct_experimental = None
import torch
import math
import sys
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend
from utils.benchmark import report_benchmark
from types import SimpleNamespace
import numpy as np
from cuda.tile import RoundingMode as RMd
INV_LOG_2 = 1.0 / math.log(2)
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]
TILE_X = 2
@ct.kernel(occupancy=1)
def fmha_kernel(Q, K, V, Out,
                qk_scale: float,
                input_pos: int,
                TILE_D: ConstInt,  # TILE_D = hidden_size
                H: ConstInt,
                TILE_M: ConstInt,
                TILE_N: ConstInt,
                QUERY_GROUP_SIZE: ConstInt,
                CAUSAL: ConstBool,
                EVEN_K: ConstBool,
                TILE_X: ConstInt):
    """
    cuTile kernel for Fused Multi-Head Attention (FMHA).
    Computes attention output for a specific batch item and head, using tiling and online softmax.
    """
    # Map block IDs to batch and head indices
    bid_start = ct.bid(0) * TILE_X
    bid_y = ct.bid(1)
    batch_idx = bid_y // H
    head_idx = bid_y % H
    off_kv_h = head_idx // QUERY_GROUP_SIZE
    # Adjust qk_scale for exp2
    qk_scale = qk_scale * INV_LOG_2
    for i in range(0, TILE_X):
        bid_x = bid_start + i
        # Initialize offsets for current query tile (M-dimension)
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=np.int32)  # [TILE_M]
        offs_m += input_pos
        offs_m = offs_m[:, None]  # [TILE_M, 1]
        # Initialize local offsets for key/value tile (N-dimension)
        offs_n_tile = ct.arange(TILE_N, dtype=np.int32)  # [TILE_N]
        offs_n_tile = offs_n_tile[None, :]  # [1, TILE_N]
        # Initialize online softmax accumulators in float32 for stability
        m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)
        # Load query tile for this batch, head, and M-chunk
        q = ct.load(
            Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
        ).reshape((TILE_M, TILE_D))  # [TILE_M, TILE_D]
        # loop over k, v and update accumulator
        m_end = input_pos + (bid_x + 1) * TILE_M
        k_seqlen = K.shape[2]
        if CAUSAL:
            # when kv pos could exceed q pos
            mask_start = (input_pos + bid_x * TILE_M) // TILE_N
            # when kv pos could exceed k_seqlen
            mask_start = min(mask_start, k_seqlen // TILE_N)
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        # Loop over K, V blocks (N-dimension chunks)
        for step in range(0, Tc):
            if i % 2 == 0:
                j = step
            else:
                j = Tc - 1 - step
            # --- Compute QK product ---
            k = ct.load(
                K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2,
            )
            k = k.reshape((TILE_D, TILE_N))  # [TILE_D, TILE_N]
            qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
            qk = ct.mma(q, k, qk)  # [TILE_M, TILE_N]
            # --- Apply Causal Masking ---
            if (CAUSAL or not EVEN_K) and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool)
                # out of bound mask
                if not EVEN_K:
                    mask = mask & (offs_n < k_seqlen)
                # causal mask
                if CAUSAL:
                    mask = mask & (offs_m >= offs_n)  # [TILE_M, TILE_N]
                mask = ct.where(mask, 0.0, -np.inf)  # [TILE_M, TILE_N]
                qk += mask
            # --- Online Softmax Update ---
            # Moving qk_scale multiplication after reduce_max is to improve performance.
            m_ij = max(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale)
            qk = qk * qk_scale - m_ij  # [TILE_M, TILE_N]
            # attention weights
            p = ct.exp2(qk, flush_to_zero=True)  # [TILE_M, TILE_N]
            l_ij = ct.sum(p, axis=-1, keepdims=True)  # [TILE_M, 1]
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)  # [TILE_M, 1]
            # update m_i and l_i
            l_i = l_i * alpha + l_ij  # [TILE_M, 1]
            # scale acc
            acc = acc * alpha  # [TILE_M, TILE_N]
            # --- Compute PV product ---
            v = ct.load(
                V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
                latency=4,
            ).reshape((TILE_N, TILE_D))  # [TILE_N, TILE_D]
            p = p.astype(Q.dtype)
            acc = ct.mma(p, v, acc)  # [TILE_M, TILE_N]
            m_i = m_ij  # [TILE_M, 1]
        # --- Final Normalization and Store ---
        acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
# --- Wrapper function to launch the FMHA kernel ---
def cutile_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                qk_scale: float | None = None,
                input_pos: int = 0,
                tile_m: int = 128,
                tile_n: int = 128,
                query_group_size: int = 1,
                causal: bool = False) -> torch.Tensor:
    """
    Performs Fused Multi-Head Attention (FMHA) using a cuTile kernel.

    Args:
        Q (torch.Tensor): Query tensor (Batch, Heads, SeqLen_Q, D_k).
        K (torch.Tensor): Key tensor (Batch, KV_Heads, SeqLen_KV, D_k).
        V (torch.Tensor): Value tensor (Batch, KV_Heads, SeqLen_KV, D_v).
        qk_scale (float, optional): Scaling factor for QK dot product. Defaults to 1/sqrt(D_k).
        input_pos (int, optional): Global start pos for queries (causal masking). Defaults to 0.
        tile_m (int): Tile size for Query sequence length (M dimension).
        tile_n (int): Tile size for Key/Value sequence length (N dimension).
        query_group_size (int): Number of query heads per key/value head.
        causal (bool): If True, applies causal masking.

    Returns:
        torch.Tensor: Output tensor (Batch, Heads, SeqLen_Q, D_v).
    """
    # --- Input Validation ---
    if Q.ndim != 4 or K.ndim != 4 or V.ndim != 4:
        raise ValueError("Input tensors Q, K, V must be 4D (Batch, Heads, SeqLen, Dim).")
    if Q.shape[0] != K.shape[0] or Q.shape[0] != V.shape[0]:
        raise ValueError("Batch dimensions must match for Q, K, V.")
    if Q.shape[1] % query_group_size != 0:
        raise ValueError("Number of query heads must be divisible by query_group_size.")
    if K.shape[1] * query_group_size != Q.shape[1]:
        raise ValueError("K_Heads * query_group_size must equal Q_Heads.")
    if Q.shape[3] != K.shape[3]:
        raise ValueError("D_k (last dim of Q and K) must match.")
    if K.shape[2] != V.shape[2]:
        raise ValueError("SeqLen_KV (dim 2 of K and V) must match.")
    if Q.device != K.device or Q.device != V.device or not Q.is_cuda:
        raise ValueError("All input tensors must be on the same CUDA device.")
    if Q.dtype != K.dtype or Q.dtype != V.dtype:
        raise ValueError("All input tensors must have the same data type.")

    Batch, Heads, SeqLen_Q, D_k = Q.shape
    _, KV_Heads, SeqLen_KV, D_v = V.shape
    even_k = (SeqLen_KV % tile_n) == 0
    if qk_scale is None:
        qk_scale = 1.0 / math.sqrt(D_k)

    # --- Create Output Tensor ---
    Out = torch.empty((Batch, Heads, SeqLen_Q, D_v), dtype=Q.dtype, device=Q.device)

    # --- Calculate Grid Dimensions ---
    grid_x = math.ceil(math.ceil(SeqLen_Q / tile_m) / TILE_X)  # we manually tile x by 2
    grid_y = Batch * Heads
    grid = (grid_x, grid_y, 1)

    # --- Launch the FMHA Kernel ---
    ct.launch(torch.cuda.current_stream(), grid, fmha_kernel, (
        Q, K, V, Out,
        qk_scale,
        input_pos,
        D_k,
        Heads,
        tile_m,
        tile_n,
        query_group_size,
        causal,
        even_k,
        TILE_X,
    ))
    return Out
# --- Wrapper function to launch the FMHA kernel with autotuning ---
def cutile_autotune_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                         qk_scale: float,
                         input_pos: int = 0,
                         query_group_size: int = 1,
                         causal: bool = False) -> tuple[torch.Tensor, dict[str, int]]:
    """
    Performs Fused Multi-Head Attention (FMHA) using a cuTile kernel with autotuning.

    Args:
        Q (torch.Tensor): Query tensor (Batch, Heads, SeqLen_Q, D_k).
        K (torch.Tensor): Key tensor (Batch, KV_Heads, SeqLen_KV, D_k).
        V (torch.Tensor): Value tensor (Batch, KV_Heads, SeqLen_KV, D_v).
        qk_scale (float, optional): Scaling factor for QK dot product. Defaults to 1/sqrt(D_k).
        input_pos (int, optional): Global start pos for queries (causal masking). Defaults to 0.
        query_group_size (int): Number of query heads per key/value head.
        causal (bool): If True, applies causal masking.
        autotuner (Autotuner | None): Autotuner object that was injected by the autotune decorator.

    Returns:
        torch.Tensor: Output tensor (Batch, Heads, SeqLen_Q, D_v).
        dict[str, int]: The best configuration found by the autotuner.
    """
    Batch, Heads, SeqLen_Q, D_k = Q.shape
    _, KV_Heads, SeqLen_KV, D_v = V.shape

    # --- Create Output Tensor ---
    Out = torch.empty((Batch, Heads, SeqLen_Q, D_v), dtype=Q.dtype, device=Q.device)

    # --- Tune/Get the best configuration for the FMHA Kernel ---
    tuned_result = ct_experimental.autotune_launch(
        torch.cuda.current_stream(),
        grid_fn=lambda cfg: (math.ceil(SeqLen_Q / cfg.TILE_M), Batch * Heads, 1),
        kernel=fmha_kernel,
        args_fn=lambda cfg: (
            Q, K, V, Out,
            qk_scale, input_pos, D_k, Heads,
            cfg.TILE_M, cfg.TILE_N, query_group_size, causal, (SeqLen_KV % cfg.TILE_N) == 0
        ),
        hints_fn=lambda cfg: {
            "num_ctas": cfg.num_ctas,
            "occupancy": cfg.occupancy,
        },
        search_space=[
            SimpleNamespace(TILE_M=256, TILE_N=128, num_ctas=1, occupancy=2),
            SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=2, occupancy=2),
            SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=2),
            SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=1),
            SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=4),
            SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=2, occupancy=1),
            SimpleNamespace(TILE_M=64, TILE_N=32, num_ctas=1, occupancy=2),
            SimpleNamespace(TILE_M=256, TILE_N=32, num_ctas=2, occupancy=2),
            SimpleNamespace(TILE_M=32, TILE_N=32, num_ctas=1, occupancy=1),
        ],
    )
    return Out, tuned_result.tuned_config
def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
               is_causal: bool, enable_gqa: bool) -> torch.Tensor:
    backend = SDPBackend.CUDNN_ATTENTION \
        if (Q.shape[2] == K.shape[2]) \
        else SDPBackend.FLASH_ATTENTION
    with sdpa_kernel(backend):
        ret = scaled_dot_product_attention(Q, K, V,
                                           is_causal=is_causal,
                                           enable_gqa=enable_gqa)
    return ret

Minimum reproducible example
Relevant log output
(venv) ubuntu@140-238-201-75:~/AttnFMHA/packaged$ python3 AttentionFMHAEntryPoint.py --variant=tile_alt
--- Running cuTile FMHA: variant=tile_alt, quant=fp16, causal=False ---
Configuration:
Batch Size: 1
Number of Heads: 1
Query Sequence Length: 262144
KV Sequence Length: 262144
Head Dimension (D_k): 64
Value Dimension (D_v): 64
Quantization: fp16 (torch.float16)
Input Q shape: torch.Size([1, 1, 262144, 64]), dtype: torch.float16
Input K shape: torch.Size([1, 1, 262144, 64]), dtype: torch.float16
Input V shape: torch.Size([1, 1, 262144, 64]), dtype: torch.float16
Estimated FLOPs: 17592186044416
--- Causal = False ---
Traceback (most recent call last):
File "/home/ubuntu/AttnFMHA/packaged/AttentionFMHAEntryPoint.py", line 246, in <module>
main()
File "/home/ubuntu/AttnFMHA/packaged/AttentionFMHAEntryPoint.py", line 235, in main
run_benchmark(
File "/home/ubuntu/AttnFMHA/packaged/AttentionFMHAEntryPoint.py", line 124, in run_benchmark
output_fmha_cutile = cutile_fmha_fn(
^^^^^^^^^^^^^^^
File "/home/ubuntu/AttnFMHA/packaged/AttentionFMHATileAlt.py", line 206, in cutile_fmha
ct.launch(torch.cuda.current_stream(), grid, fmha_kernel, (
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_compile.py", line 258, in __call__
lib = compile_tile(self.pyfunc, pyfunc_args, self.compiler_options, tile_context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_compile.py", line 73, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_compile.py", line 185, in compile_tile
func_ir = _get_final_ir(pyfunc, ir_args, context.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_compile.py", line 83, in _get_final_ir
func_body = hir2ir(func_hir, args, ir_ctx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 28, in hir2ir
return run_coroutine(_hir2ir_coroutine(func_hir, args, ir_ctx))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_coroutine_util.py", line 38, in run_coroutine
raise exc_info[1]
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_coroutine_util.py", line 23, in run_coroutine
continuation = top.send(ret) if exc_info is None else top.throw(*exc_info)
^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 50, in _hir2ir_coroutine
await _dispatch_hir_block_inner(func_hir.body, ir_builder)
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 84, in _dispatch_hir_block_inner
await _dispatch_call(call, scope)
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 136, in _dispatch_call
retval = await call(callee_var, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 150, in call
result = await result
^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/op_impl.py", line 59, in wrapper
return await func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/ops.py", line 222, in loop_impl
await dispatch_hir_block(body)
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 74, in dispatch_hir_block
await _dispatch_hir_block_inner(block, cur_builder)
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 84, in _dispatch_hir_block_inner
await _dispatch_call(call, scope)
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 136, in _dispatch_call
retval = await call(callee_var, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 150, in call
result = await result
^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/op_impl.py", line 59, in wrapper
return await func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/ops.py", line 222, in loop_impl
await dispatch_hir_block(body)
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 74, in dispatch_hir_block
await _dispatch_hir_block_inner(block, cur_builder)
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 84, in _dispatch_hir_block_inner
await _dispatch_call(call, scope)
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 136, in _dispatch_call
retval = await call(callee_var, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 148, in call
result = impl(*arg_list)
^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/op_impl.py", line 70, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/ops.py", line 3901, in load_var_impl
ret.get_type() # Trigger an InvalidType check
^^^^^^^^^^^^^^
File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/ir.py", line 169, in get_type
raise TileTypeError(ty.error_message, ty.loc)
cuda.tile._exception.TileTypeError: Type of `j` depends on path taken: Tile[int32,()] (line 94) vs. int32 (line 92)
"/home/ubuntu/AttnFMHA/packaged/AttentionFMHATileAlt.py", line 94, col 17, in fmha_kernel:
j = Tc - 1 - step
Full env printout
Other/Misc.
No response
Contributing Guidelines
- I agree to follow cuTile Python's contributing guidelines
- I have searched the open bugs and have found no duplicates for this bug report