49 changes: 49 additions & 0 deletions lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py
@@ -9,6 +9,9 @@
# ruff: noqa: E501
import torch
from einops import rearrange
import functools
import os
from lightllm.utils.log_utils import init_logger

from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
@@ -19,6 +22,36 @@
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd

logger = init_logger(__name__)


@functools.lru_cache(maxsize=1)
def _flashqla_chunk_gated_delta_rule():
    if os.environ.get("LIGHTLLM_DISABLE_FLASHQLA", "0").lower() in ["1", "true", "yes"]:
        return None
    try:
        import flash_qla
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    if torch.cuda.get_device_capability() < (9, 0):
        return None
    tv = torch.__version__.split("+")[0].split(".")
    if (int(tv[0]), int(tv[1])) < (2, 8):
        return None
    cv = torch.version.cuda
    if cv is None:
        return None
    cv_parts = cv.split(".")
    if (int(cv_parts[0]), int(cv_parts[1])) < (12, 8):
Comment on lines +41 to +47
Contributor
high

The version checks for PyTorch (>= 2.8) and CUDA (>= 12.8) appear to be placeholders or typos, as these versions are not yet released (current stable versions are typically PyTorch 2.5/2.6 and CUDA 12.4/12.6). As written, this logic will disable the FlashQLA backend for almost all current environments. Please verify if these should be lower versions (e.g., PyTorch 2.4 and CUDA 12.1).

        return None
    logger.info(
        "qwen3next chunk_gated_delta_rule: using FlashQLA backend (flash_qla.chunk_gated_delta_rule); "
        "set LIGHTLLM_DISABLE_FLASHQLA=1 to fall back to the FLA Triton kernels."
    )
    return flash_qla.chunk_gated_delta_rule
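
Relating to the reviewer's note above: a minimal, hypothetical diagnostic (not part of this PR) that prints each condition the gate inspects, so one can see which check ends up disabling the FlashQLA backend in a given environment:

import os
import torch

print("LIGHTLLM_DISABLE_FLASHQLA =", os.environ.get("LIGHTLLM_DISABLE_FLASHQLA", "0"))
print("torch version             =", torch.__version__)      # gate requires >= 2.8 as written
print("torch.version.cuda        =", torch.version.cuda)     # gate requires >= 12.8 as written
print("cuda available            =", torch.cuda.is_available())
if torch.cuda.is_available():
    # gate requires compute capability >= (9, 0), i.e. Hopper or newer
    print("device capability         =", torch.cuda.get_device_capability())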


def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
@@ -183,6 +216,22 @@ def chunk_gated_delta_rule(
cu_seqlens=cu_seqlens
)
"""
    flashqla_fn = _flashqla_chunk_gated_delta_rule()
    if flashqla_fn is not None and not head_first:
        return flashqla_fn(
            q=q.contiguous(),
            k=k.contiguous(),
            v=v.contiguous(),
            g=g.contiguous(),
            beta=beta.contiguous(),
            scale=scale,
            initial_state=initial_state.contiguous() if initial_state is not None else None,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
Comment on lines +219 to +233
Contributor
high

The scale parameter is passed to flashqla_fn before it is assigned its default value. In the fallback path (lines 257-258), scale defaults to k.shape[-1] ** -0.5 if it is None. If flash_qla.chunk_gated_delta_rule does not handle None for the scale argument, this will lead to incorrect results or a crash. You should move the default scale calculation before the FlashQLA dispatch logic.

Suggested change
-    flashqla_fn = _flashqla_chunk_gated_delta_rule()
-    if flashqla_fn is not None and not head_first:
-        return flashqla_fn(
-            q=q.contiguous(),
-            k=k.contiguous(),
-            v=v.contiguous(),
-            g=g.contiguous(),
-            beta=beta.contiguous(),
-            scale=scale,
-            initial_state=initial_state.contiguous() if initial_state is not None else None,
-            output_final_state=output_final_state,
-            cu_seqlens=cu_seqlens,
-            head_first=head_first,
-            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
-        )
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    flashqla_fn = _flashqla_chunk_gated_delta_rule()
+    if flashqla_fn is not None and not head_first:
+        return flashqla_fn(
+            q=q.contiguous(),
+            k=k.contiguous(),
+            v=v.contiguous(),
+            g=g.contiguous(),
+            beta=beta.contiguous(),
+            scale=scale,
+            initial_state=initial_state.contiguous() if initial_state is not None else None,
+            output_final_state=output_final_state,
+            cu_seqlens=cu_seqlens,
+            head_first=head_first,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+        )
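
For context on the default the reviewer refers to, a small illustrative snippet (assuming k has shape [B, T, H, K] in the head_first=False layout, so k.shape[-1] is the head dimension; the tensor here is made up):

import torch

k = torch.randn(1, 16, 4, 64)    # hypothetical [B, T, H, K] tensor with head dim 64
scale = k.shape[-1] ** -0.5      # 1 / sqrt(64) == 0.125, the usual 1/sqrt(head_dim) attention scaling
print(scale)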


    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."