[BUG]: Regression of branch type inference #70

@SchrodingerZhu

Description

Version

1.1.0

Version

13.1

Which installation method(s) does this occur on?

No response

Describe the bug.

The latest code changes the type-inference behavior for `j` across the two branches:

            if i % 2 == 0:
                j = step
            else:
                j = Tc - 1 - step

At commit 9af1b63, the following code passes the frontend:

# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

import argparse
import cuda.tile as ct
try:
    import cuda.tile_experimental as ct_experimental
except ImportError:
    ct_experimental = None
import torch
import math
import sys

from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend
from utils.benchmark import report_benchmark
from types import SimpleNamespace


import numpy as np
from cuda.tile import RoundingMode as RMd


INV_LOG_2 = 1.0 / math.log(2)
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]
TILE_X = 2

@ct.kernel(occupancy=1)
def fmha_kernel(Q, K, V, Out,
                qk_scale: float,
                input_pos: int,
                TILE_D: ConstInt,  # TILE_D = hidden_size
                H: ConstInt,
                TILE_M: ConstInt,
                TILE_N: ConstInt,
                QUERY_GROUP_SIZE: ConstInt,
                CAUSAL: ConstBool,
                EVEN_K: ConstBool,
                TILE_X: ConstInt):
    """
    cuTile kernel for Fused Multi-Head Attention (FMHA).
    Computes attention output for a specific batch item and head, using tiling and online softmax.
    """
    # Map block IDs to batch and head indices
    bid_start = ct.bid(0) * TILE_X
    bid_y = ct.bid(1)
    batch_idx = bid_y // H
    head_idx = bid_y % H
    off_kv_h = head_idx // QUERY_GROUP_SIZE

    # Adjust qk_scale for exp2
    qk_scale = qk_scale * INV_LOG_2
    for i in range(0, TILE_X):
        bid_x = bid_start + i
        # Initialize offsets for current query tile (M-dimension)
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=np.int32)  # [TILE_M]
        offs_m += input_pos
        offs_m = offs_m[:, None]  # [TILE_M, 1]

        # Initialize local offsets for key/value tile (N-dimension)
        offs_n_tile = ct.arange(TILE_N, dtype=np.int32)  # [TILE_N]
        offs_n_tile = offs_n_tile[None, :]  # [1, TILE_N]

        # Initialize online softmax accumulators in float32 for stability
        m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)

        # Load query tile for this batch, head, and M-chunk
        q = ct.load(
            Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
        ).reshape((TILE_M, TILE_D))  # [TILE_M, TILE_D]

        # loop over k, v and update accumulator
        m_end = input_pos + (bid_x + 1) * TILE_M
        k_seqlen = K.shape[2]
        if CAUSAL:
            # when kv pos could exceed q pos
            mask_start = (input_pos + bid_x * TILE_M) // TILE_N
            # when kv pos could exceed k_seqlen
            mask_start = min(mask_start, k_seqlen // TILE_N)
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N

        # Loop over K, V blocks (N-dimension chunks)
        for step in range(0, Tc):
            if i % 2 == 0:
                j = step
            else:
                j = Tc - 1 - step
            # --- Compute QK product ---
            k = ct.load(
                K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2,
            )
            k = k.reshape((TILE_D, TILE_N))  # [TILE_D, TILE_N]
            qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
            qk = ct.mma(q, k, qk)  # [TILE_M, TILE_N]

            # --- Apply Causal Masking ---
            if (CAUSAL or not EVEN_K) and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool)
                # out of bound mask
                if not EVEN_K:
                    mask = mask & (offs_n < k_seqlen)
                # causal mask
                if CAUSAL:
                    mask = mask & (offs_m >= offs_n)  # [TILE_M, TILE_N]
                mask = ct.where(mask, 0.0, -np.inf)  # [TILE_M, TILE_N]
                qk += mask

            # --- Online Softmax Update ---
            # Applying qk_scale after reduce_max improves performance.
            m_ij = max(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale)
            qk = qk * qk_scale - m_ij  # [TILE_M, TILE_N]

            # attention weights
            p = ct.exp2(qk, flush_to_zero=True)  # [TILE_M, TILE_N]
            l_ij = ct.sum(p, axis=-1, keepdims=True)  # [TILE_M, 1]
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)  # [TILE_M, 1]
            # update m_i and l_i
            l_i = l_i * alpha + l_ij  # [TILE_M, 1]
            # scale acc
            acc = acc * alpha  # [TILE_M, TILE_D]

            # --- Compute PV product ---
            v = ct.load(
                V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
                latency=4,
            ).reshape((TILE_N, TILE_D))  # [TILE_N, TILE_D]
            p = p.astype(Q.dtype)
            acc = ct.mma(p, v, acc)  # [TILE_M, TILE_D]
            m_i = m_ij  # [TILE_M, 1]

        # --- Final Normalization and Store ---
        acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)


# --- Wrapper function to launch the FMHA kernel ---
def cutile_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                qk_scale: float | None = None,
                input_pos: int = 0,
                tile_m: int = 128,
                tile_n: int = 128,
                query_group_size: int = 1,
                causal: bool = False) -> torch.Tensor:
    """
    Performs Fused Multi-Head Attention (FMHA) using a cuTile kernel.

    Args:
        Q (torch.Tensor): Query tensor (Batch, Heads, SeqLen_Q, D_k).
        K (torch.Tensor): Key tensor (Batch, KV_Heads, SeqLen_KV, D_k).
        V (torch.Tensor): Value tensor (Batch, KV_Heads, SeqLen_KV, D_v).
        qk_scale (float, optional): Scaling factor for QK dot product. Defaults to 1/sqrt(D_k).
        input_pos (int, optional): Global start pos for queries (causal masking). Defaults to 0.
        tile_m (int): Tile size for Query sequence length (M dimension).
        tile_n (int): Tile size for Key/Value sequence length (N dimension).
        query_group_size (int): Number of query heads per key/value head.
        causal (bool): If True, applies causal masking.

    Returns:
        torch.Tensor: Output tensor (Batch, Heads, SeqLen_Q, D_v).
    """
    # --- Input Validation ---
    if Q.ndim != 4 or K.ndim != 4 or V.ndim != 4:
        raise ValueError("Input tensors Q, K, V must be 4D (Batch, Heads, SeqLen, Dim).")
    if Q.shape[0] != K.shape[0] or Q.shape[0] != V.shape[0]:
        raise ValueError("Batch dimensions must match for Q, K, V.")
    if Q.shape[1] % query_group_size != 0:
        raise ValueError("Number of query heads must be divisible by query_group_size.")
    if K.shape[1] * query_group_size != Q.shape[1]:
        raise ValueError("K_Heads * query_group_size must equal Q_Heads.")
    if Q.shape[3] != K.shape[3]:
        raise ValueError("D_k (last dim of Q and K) must match.")
    if K.shape[2] != V.shape[2]:
        raise ValueError("SeqLen_KV (dim 2 of K and V) must match.")
    if Q.device != K.device or Q.device != V.device or not Q.is_cuda:
        raise ValueError("All input tensors must be on the same CUDA device.")
    if Q.dtype != K.dtype or Q.dtype != V.dtype:
        raise ValueError("All input tensors must have the same data type.")

    Batch, Heads, SeqLen_Q, D_k = Q.shape
    _, KV_Heads, SeqLen_KV, D_v = V.shape
    even_k = (SeqLen_KV % tile_n) == 0

    if qk_scale is None:
        qk_scale = 1.0 / math.sqrt(D_k)

    # --- Create Output Tensor ---
    Out = torch.empty((Batch, Heads, SeqLen_Q, D_v), dtype=Q.dtype, device=Q.device)

    # --- Calculate Grid Dimensions ---
    grid_x = math.ceil(math.ceil(SeqLen_Q / tile_m) / TILE_X)  # we manually tile x by 2
    grid_y = Batch * Heads
    grid = (grid_x, grid_y, 1)

    # --- Launch the FMHA Kernel ---
    ct.launch(torch.cuda.current_stream(), grid, fmha_kernel, (
        Q, K, V, Out,
        qk_scale,
        input_pos,
        D_k,
        Heads,
        tile_m,
        tile_n,
        query_group_size,
        causal,
        even_k,
        TILE_X,
    ))

    return Out


# --- Wrapper function to launch the FMHA kernel with autotuning ---
def cutile_autotune_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                         qk_scale: float,
                         input_pos: int = 0,
                         query_group_size: int = 1,
                         causal: bool = False) -> tuple[torch.Tensor, dict[str, int]]:
    """
    Performs Fused Multi-Head Attention (FMHA) using a cuTile kernel with autotuning.

    Args:
        Q (torch.Tensor): Query tensor (Batch, Heads, SeqLen_Q, D_k).
        K (torch.Tensor): Key tensor (Batch, KV_Heads, SeqLen_KV, D_k).
        V (torch.Tensor): Value tensor (Batch, KV_Heads, SeqLen_KV, D_v).
        qk_scale (float): Scaling factor for QK dot product.
        input_pos (int, optional): Global start pos for queries (causal masking). Defaults to 0.
        query_group_size (int): Number of query heads per key/value head.
        causal (bool): If True, applies causal masking.

    Returns:
        torch.Tensor: Output tensor (Batch, Heads, SeqLen_Q, D_v).
        dict[str, int]: The best configuration found by the autotuner.
    """
    Batch, Heads, SeqLen_Q, D_k = Q.shape
    _, KV_Heads, SeqLen_KV, D_v = V.shape

    # --- Create Output Tensor ---
    Out = torch.empty((Batch, Heads, SeqLen_Q, D_v), dtype=Q.dtype, device=Q.device)

    # --- Tune/Get the best configuration for the FMHA Kernel ---
    tuned_result = ct_experimental.autotune_launch(
        torch.cuda.current_stream(),
        grid_fn=lambda cfg: (math.ceil(math.ceil(SeqLen_Q / cfg.TILE_M) / TILE_X), Batch * Heads, 1),
        kernel=fmha_kernel,
        args_fn=lambda cfg: (
            Q, K, V, Out,
            qk_scale, input_pos, D_k, Heads,
            cfg.TILE_M, cfg.TILE_N, query_group_size, causal, (SeqLen_KV % cfg.TILE_N) == 0, TILE_X
        ),
        hints_fn=lambda cfg: {
            "num_ctas": cfg.num_ctas,
            "occupancy": cfg.occupancy,
        },
        search_space=[
            SimpleNamespace(TILE_M=256, TILE_N=128, num_ctas=1, occupancy=2),
            SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=2, occupancy=2),
            SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=2),
            SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=1),
            SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=4),
            SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=2, occupancy=1),
            SimpleNamespace(TILE_M=64, TILE_N=32, num_ctas=1, occupancy=2),
            SimpleNamespace(TILE_M=256, TILE_N=32, num_ctas=2, occupancy=2),
            SimpleNamespace(TILE_M=32, TILE_N=32, num_ctas=1, occupancy=1),
        ],
    )

    return Out, tuned_result.tuned_config


def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
               is_causal: bool, enable_gqa: bool) -> torch.Tensor:
    backend = SDPBackend.CUDNN_ATTENTION \
            if (Q.shape[2] == K.shape[2]) \
            else SDPBackend.FLASH_ATTENTION
    with sdpa_kernel(backend):
        ret = scaled_dot_product_attention(Q, K, V,
                                           is_causal=is_causal,
                                           enable_gqa=enable_gqa)
    return ret

Minimum reproducible example
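
A smaller sketch that keeps only the pieces involved in the error: a compile-time Python loop over TILE_X, a trip count Tc derived from a runtime shape via ct.cdiv, and the `if i % 2 == 0` branch that assigns `j` differently on each path. The kernel name, tensor shapes, float32 dtype, and the launch helper are illustrative assumptions, not taken from the report above, and this sketch is unverified.

import math
import numpy as np
import torch
import cuda.tile as ct

ConstInt = ct.Constant[int]
TILE_X = 2


@ct.kernel(occupancy=1)
def branch_repro_kernel(K, Out, TILE_N: ConstInt, TILE_D: ConstInt, TILE_X: ConstInt):
    bid_start = ct.bid(0) * TILE_X
    k_seqlen = K.shape[2]
    Tc = ct.cdiv(k_seqlen, TILE_N)  # trip count derived from a runtime shape
    for i in range(0, TILE_X):      # Python-level loop, as in the kernel above
        bid_x = bid_start + i
        acc = ct.full((TILE_N, TILE_D), 0.0, dtype=np.float32)
        for step in range(0, Tc):
            # The branch the frontend rejects on HEAD: `j` is a plain int32 on
            # one path and (per the error below) a Tile[int32,()] on the other.
            if i % 2 == 0:
                j = step
            else:
                j = Tc - 1 - step
            v = ct.load(
                K, index=(0, 0, j, 0), shape=(1, 1, TILE_N, TILE_D),
            ).reshape((TILE_N, TILE_D))
            acc = acc + v  # assumes K is float32, so no cast is needed
        ct.store(Out, index=(0, 0, bid_x, 0),
                 tile=acc.reshape((1, 1, TILE_N, TILE_D)))


def run_branch_repro(seq_len: int = 1024, tile_n: int = 128, tile_d: int = 64):
    # Hypothetical launch helper mirroring the ct.launch call in the report.
    K = torch.randn((1, 1, seq_len, tile_d), dtype=torch.float32, device="cuda")
    Out = torch.empty_like(K)
    grid = (math.ceil(math.ceil(seq_len / tile_n) / TILE_X), 1, 1)
    ct.launch(torch.cuda.current_stream(), grid, branch_repro_kernel,
              (K, Out, tile_n, tile_d, TILE_X))
    return Out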

Relevant log output

(venv) ubuntu@140-238-201-75:~/AttnFMHA/packaged$ python3 AttentionFMHAEntryPoint.py --variant=tile_alt
--- Running cuTile FMHA: variant=tile_alt, quant=fp16, causal=False ---
  Configuration:
    Batch Size: 1
    Number of Heads: 1
    Query Sequence Length: 262144
    KV Sequence Length: 262144
    Head Dimension (D_k): 64
    Value Dimension (D_v): 64
    Quantization: fp16 (torch.float16)
  Input Q shape: torch.Size([1, 1, 262144, 64]), dtype: torch.float16
  Input K shape: torch.Size([1, 1, 262144, 64]), dtype: torch.float16
  Input V shape: torch.Size([1, 1, 262144, 64]), dtype: torch.float16
  Estimated FLOPs: 17592186044416

--- Causal = False ---
Traceback (most recent call last):
  File "/home/ubuntu/AttnFMHA/packaged/AttentionFMHAEntryPoint.py", line 246, in <module>
    main()
  File "/home/ubuntu/AttnFMHA/packaged/AttentionFMHAEntryPoint.py", line 235, in main
    run_benchmark(
  File "/home/ubuntu/AttnFMHA/packaged/AttentionFMHAEntryPoint.py", line 124, in run_benchmark
    output_fmha_cutile = cutile_fmha_fn(
                         ^^^^^^^^^^^^^^^
  File "/home/ubuntu/AttnFMHA/packaged/AttentionFMHATileAlt.py", line 206, in cutile_fmha
    ct.launch(torch.cuda.current_stream(), grid, fmha_kernel, (
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_compile.py", line 258, in __call__
    lib = compile_tile(self.pyfunc, pyfunc_args, self.compiler_options, tile_context)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_compile.py", line 73, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_compile.py", line 185, in compile_tile
    func_ir = _get_final_ir(pyfunc, ir_args, context.config)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_compile.py", line 83, in _get_final_ir
    func_body = hir2ir(func_hir, args, ir_ctx)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 28, in hir2ir
    return run_coroutine(_hir2ir_coroutine(func_hir, args, ir_ctx))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_coroutine_util.py", line 38, in run_coroutine
    raise exc_info[1]
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_coroutine_util.py", line 23, in run_coroutine
    continuation = top.send(ret) if exc_info is None else top.throw(*exc_info)
                   ^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 50, in _hir2ir_coroutine
    await _dispatch_hir_block_inner(func_hir.body, ir_builder)
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 84, in _dispatch_hir_block_inner
    await _dispatch_call(call, scope)
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 136, in _dispatch_call
    retval = await call(callee_var, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 150, in call
    result = await result
             ^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/op_impl.py", line 59, in wrapper
    return await func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/ops.py", line 222, in loop_impl
    await dispatch_hir_block(body)
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 74, in dispatch_hir_block
    await _dispatch_hir_block_inner(block, cur_builder)
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 84, in _dispatch_hir_block_inner
    await _dispatch_call(call, scope)
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 136, in _dispatch_call
    retval = await call(callee_var, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 150, in call
    result = await result
             ^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/op_impl.py", line 59, in wrapper
    return await func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/ops.py", line 222, in loop_impl
    await dispatch_hir_block(body)
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 74, in dispatch_hir_block
    await _dispatch_hir_block_inner(block, cur_builder)
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 84, in _dispatch_hir_block_inner
    await _dispatch_call(call, scope)
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 136, in _dispatch_call
    retval = await call(callee_var, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_passes/hir2ir.py", line 148, in call
    result = impl(*arg_list)
             ^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/op_impl.py", line 70, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/ops.py", line 3901, in load_var_impl
    ret.get_type()  # Trigger an InvalidType check
    ^^^^^^^^^^^^^^
  File "/home/ubuntu/venv/lib/python3.12/site-packages/cuda/tile/_ir/ir.py", line 169, in get_type
    raise TileTypeError(ty.error_message, ty.loc)
cuda.tile._exception.TileTypeError: Type of `j` depends on path taken: Tile[int32,()] (line 94) vs. int32 (line 92)
  "/home/ubuntu/AttnFMHA/packaged/AttentionFMHATileAlt.py", line 94, col 17, in fmha_kernel:
                    j = Tc - 1 - step
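
One possible (untested) workaround while the regression is investigated: fold the branch into arithmetic so `j` has a single, path-independent type. This is only a sketch of the idea; whether the resulting value of `j` is accepted as a load index is an assumption, not something verified here.

            # Untested sketch: blend the two expressions with the compile-time
            # flag `i % 2` instead of branching, so `j` gets one consistent type.
            flip = i % 2  # 0 or 1; `i` is a Python loop index, known at compile time
            j = (1 - flip) * step + flip * (Tc - 1 - step)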

Full env printout

Other/Misc.

No response

Contributing Guidelines

  • I agree to follow cuTile Python's contributing guidelines
  • I have searched the open bugs and have found no duplicates for this bug report
