Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from typing import Optional, TYPE_CHECKING
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.sgl_utils import flash_attn_with_kvcache, flash_attn_with_kvcache_autotune
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
Expand Down Expand Up @@ -222,7 +222,7 @@ def _normal_decode_att(
k_descale, v_descale = None, None # disable quantization
Lq = q.shape[-1]
sm_scale = 1.0 / (Lq ** 0.5)
o = flash_attn_with_kvcache(
o = flash_attn_with_kvcache_autotune(
q=q,
k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]),
Expand Down
2 changes: 1 addition & 1 deletion lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _nsa_decode_att(
kv: torch.Tensor,
att_control: AttControl,
) -> torch.Tensor:
from sgl_kernel.flash_attn import flash_attn_with_kvcache
from lightllm.utils.sgl_utils import flash_attn_with_kvcache

nsa_dict = att_control.nsa_decode_dict
topk_mem_indices = nsa_dict["topk_mem_indices"]
Expand Down
15 changes: 15 additions & 0 deletions lightllm/common/triton_utils/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import orjson
import os
import inspect
import gc
import torch
import torch.distributed as dist
import random
Expand Down Expand Up @@ -274,7 +275,20 @@ def kernel_call():
except (OutOfResources, PTXASError, CompileTimeAssertionFailure, RuntimeError, Exception):
return float("inf")

def _autotune_boundary_sync(self, world_size):
# Clear autotune boundary state so adjacent kernel tuning runs interfere less.
torch.cuda.synchronize()
if world_size > 1:
dist.barrier(group=self._get_autotune_group())
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
if world_size > 1:
dist.barrier(group=self._get_autotune_group())

def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
self._autotune_boundary_sync(world_size)

is_key_all_same = True
if world_size > 1:
all_keys = [None for _ in range(world_size)]
Expand Down Expand Up @@ -355,6 +369,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
)
logger.info(f"Saved configs for {self.kernel_name} - {_static_key}")

self._autotune_boundary_sync(world_size)
logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {static_key} finished")

def _mutate_args_clone(self, args, kwargs):
Expand Down
4 changes: 4 additions & 0 deletions lightllm/models/qwen2_vl/infer_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def init_some_extra_state(self, model):
InferStateInfo.init_some_extra_state(self, model)
if self.is_prefill:
self.position_ids = self.get_mrope_position(self.multimodal_params)
elif self.multimodal_params is None or not any(
p.get("images") or p.get("audios") for p in self.multimodal_params
):
self.position_ids = self.position_ids.unsqueeze(0).expand(3, -1)
else:
b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])]
for batch_idx, p in enumerate(self.multimodal_params):
Expand Down
207 changes: 204 additions & 3 deletions lightllm/utils/sgl_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import torch

from frozendict import frozendict
from lightllm.common.triton_utils.autotuner import AutotuneLevel, Autotuner
from lightllm.utils.envs_utils import get_triton_autotune_level
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)

_DEFAULT_NUM_SPLITS = 0

try:
import sgl_kernel

Expand All @@ -17,16 +25,209 @@
)

try:
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache as _flash_attn_with_kvcache

flash_attn_varlen_func = flash_attn_varlen_func
flash_attn_with_kvcache = flash_attn_with_kvcache
merge_state_v2 = sgl_ops.merge_state_v2
except:
flash_attn_varlen_func = None
flash_attn_with_kvcache = None
_flash_attn_with_kvcache = None
merge_state_v2 = None
logger.warning(
"sgl_kernel is not installed, or the installed version did not support fa3. \
Try to upgrade it."
)


def _flash_attn_kvcache_num_splits_configs():
return [{"num_splits": num_splits} for num_splits in [0, 16, 32]]


def _flash_attn_kvcache_static_key(q, k_cache, v_cache, causal, window_size, softcap, sinks):
return {
"qd": str(q.dtype),
"kd": str(k_cache.dtype),
"vd": str(v_cache.dtype),
"qh": int(q.shape[-2]),
"kh": int(k_cache.shape[-2]),
"hd": int(q.shape[-1]),
"vh": int(v_cache.shape[-1]),
"pb": int(k_cache.shape[-3]),
"c": int(bool(causal)),
"wl": int(window_size[0]),
"wr": int(window_size[1]),
"sc": int(softcap > 0.0),
"sk": int(sinks is not None),
"sgl": getattr(sgl_ops, "__version__", "unknown"),
}
Comment on lines +46 to +62

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _flash_attn_kvcache_static_key, window_size is assumed to be a subscriptable sequence of at least 2 elements. However, in some contexts or configurations, window_size can be passed as an integer (for symmetric window size) or None. Accessing window_size[0] directly will raise a TypeError in those cases. We should handle None and integer types for window_size robustly.

def _flash_attn_kvcache_static_key(q, k_cache, v_cache, causal, window_size, softcap, sinks):
    if window_size is None:
        wl, wr = -1, -1
    elif isinstance(window_size, int):
        wl, wr = window_size, window_size
    else:
        wl, wr = window_size[0], window_size[1]
    return {
        "qd": str(q.dtype),
        "kd": str(k_cache.dtype),
        "vd": str(v_cache.dtype),
        "qh": int(q.shape[-2]),
        "kh": int(k_cache.shape[-2]),
        "hd": int(q.shape[-1]),
        "vh": int(v_cache.shape[-1]),
        "pb": int(k_cache.shape[-3]),
        "c": int(bool(causal)),
        "wl": int(wl),
        "wr": int(wr),
        "sc": int(softcap > 0.0),
        "sk": int(sinks is not None),
        "sgl": getattr(sgl_ops, "__version__", "unknown"),
    }



def _flash_attn_max_q_len(q, max_seqlen_q):
if max_seqlen_q is not None:
return int(max_seqlen_q)
if q.dim() >= 4:
return int(q.shape[1])
return int(q.shape[0])


def _flash_attn_kvcache_run_key(q, page_table, max_seqlen_q):
batch_size = int(page_table.shape[0])
max_q_len = _flash_attn_max_q_len(q, max_seqlen_q)
max_kv_len = int(page_table.shape[1])
return batch_size * 1_000_000_000_000 + max_q_len * 1_000_000 + max_kv_len
Comment on lines +73 to +77

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _flash_attn_kvcache_run_key, using 1_000_000 as the multiplier for max_q_len can lead to key collisions for ultra-long context lengths (where max_kv_len exceeds 1,000,000 tokens). To support context lengths of 1M+ tokens robustly, we should increase the multipliers for max_q_len and batch_size.

Suggested change
def _flash_attn_kvcache_run_key(q, page_table, max_seqlen_q):
batch_size = int(page_table.shape[0])
max_q_len = _flash_attn_max_q_len(q, max_seqlen_q)
max_kv_len = int(page_table.shape[1])
return batch_size * 1_000_000_000_000 + max_q_len * 1_000_000 + max_kv_len
def _flash_attn_kvcache_run_key(q, page_table, max_seqlen_q):
batch_size = int(page_table.shape[0])
max_q_len = _flash_attn_max_q_len(q, max_seqlen_q)
max_kv_len = int(page_table.shape[1])
return batch_size * 1_000_000_000_000_000 + max_q_len * 100_000_000 + max_kv_len



def _flash_attn_is_decode_like(q, page_table, max_seqlen_q=None):
if page_table is None or page_table.dim() < 2:
return False

max_q_len = _flash_attn_max_q_len(q, max_seqlen_q)
if max_q_len <= 0 or int(page_table.shape[1]) <= max_q_len:
return False

q_token_num = int(q.shape[0]) * int(q.shape[1]) if q.dim() >= 4 else int(q.shape[0])
return q_token_num == int(page_table.shape[0]) * max_q_len


def _flash_attn_should_autotune(q, kwargs):
return (
kwargs.get("num_splits", _DEFAULT_NUM_SPLITS) == _DEFAULT_NUM_SPLITS
and kwargs.get("k") is None
and kwargs.get("v") is None
and kwargs.get("out") is None
and kwargs.get("qv") is None
and kwargs.get("q_descale") is None
and kwargs.get("k_descale") is None
and kwargs.get("v_descale") is None
and _flash_attn_is_decode_like(q, kwargs.get("page_table"), kwargs.get("max_seqlen_q"))
)


def _flash_attn_with_kvcache_autotune_call(call_kwargs):
tuner = _flash_attn_with_kvcache_autotuned

if get_triton_autotune_level() == AutotuneLevel.ADAPTIVE_AUTOTUNE:
static_key = frozendict(tuner._static_key(**call_kwargs))
run_key = str(tuner._run_key(**call_kwargs))
tuner._try_load_cache(static_key)

if run_key not in tuner.cached_configs.get(static_key, {}) and not Autotuner.is_autotune_warmup():
Autotuner.start_autotune_warmup()
try:
return tuner(**call_kwargs)
finally:
Autotuner.end_autotune_warmup()

return tuner(**call_kwargs)


def _flash_attn_decode_bench_kv_lens(page_table):
if page_table is None or page_table.dim() < 2:
return []

max_kv_len = int(page_table.shape[1])
if max_kv_len <= 0:
return []

return [10240, max_kv_len]
Comment on lines +124 to +132

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In _flash_attn_decode_bench_kv_lens, returning a hardcoded benchmark sequence length of 10240 when max_kv_len is smaller than 10240 will cause out-of-bounds indexing on the page_table tensor during autotuning. This can lead to illegal memory accesses or GPU page faults. We should cap the benchmark sequence length to max_kv_len if it is smaller than 10240.

Suggested change
def _flash_attn_decode_bench_kv_lens(page_table):
if page_table is None or page_table.dim() < 2:
return []
max_kv_len = int(page_table.shape[1])
if max_kv_len <= 0:
return []
return [10240, max_kv_len]
def _flash_attn_decode_bench_kv_lens(page_table):
if page_table is None or page_table.dim() < 2:
return []
max_kv_len = int(page_table.shape[1])
if max_kv_len <= 0:
return []
if max_kv_len < 10240:
return [max_kv_len]
return [10240, max_kv_len]



class _FlashAttnKvcacheAutotuner(Autotuner):
def _bench(self, *args, n_repeat=3, n_retries=3, **kwargs):
page_table = kwargs.get("page_table")
cache_seqlens = kwargs.get("cache_seqlens")

bench_times = []
for bench_kv_len in _flash_attn_decode_bench_kv_lens(page_table):
bench_kwargs = kwargs.copy()
if isinstance(cache_seqlens, torch.Tensor):
bench_cache_seqlens = cache_seqlens.clone()
bench_cache_seqlens.fill_(bench_kv_len)
else:
bench_cache_seqlens = bench_kv_len
bench_kwargs["cache_seqlens"] = bench_cache_seqlens

cu_seqlens_k_new = bench_kwargs.get("cu_seqlens_k_new")
if isinstance(cu_seqlens_k_new, torch.Tensor) and cu_seqlens_k_new.numel() != 0:
bench_cu_seqlens_k_new = torch.arange(
cu_seqlens_k_new.numel(), device=cu_seqlens_k_new.device, dtype=cu_seqlens_k_new.dtype
)
bench_cu_seqlens_k_new *= bench_kv_len
bench_kwargs["cu_seqlens_k_new"] = bench_cu_seqlens_k_new

bench_times.append(super()._bench(*args, n_repeat=n_repeat, n_retries=n_retries, **bench_kwargs))

if bench_times:
return sum(bench_times) / len(bench_times)

return super()._bench(*args, n_repeat=n_repeat, n_retries=n_retries, **kwargs)


if _flash_attn_with_kvcache is not None and torch.cuda.is_available():

@torch.no_grad()
def _flash_attn_with_kvcache_autotuned_impl(
q,
k_cache,
v_cache,
cache_seqlens=None,
page_table=None,
cu_seqlens_q=None,
cu_seqlens_k_new=None,
max_seqlen_q=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
num_splits=0,
sinks=None,
run_config=None,
**kwargs,
):
if run_config is not None:
num_splits = run_config["num_splits"]
return _flash_attn_with_kvcache(
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
page_table=page_table,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k_new,
max_seqlen_q=max_seqlen_q,
causal=causal,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
sinks=sinks,
**kwargs,
)

_flash_attn_with_kvcache_autotuned = _FlashAttnKvcacheAutotuner(
fn=_flash_attn_with_kvcache_autotuned_impl,
kernel_name="sgl_fa3_kvcache_ns:v1",
configs_gen_func=_flash_attn_kvcache_num_splits_configs,
static_key_func=_flash_attn_kvcache_static_key,
run_key_func=_flash_attn_kvcache_run_key,
)

else:
_flash_attn_with_kvcache_autotuned = None


def _flash_attn_with_kvcache_autotune_wrapper(q, k_cache, v_cache, **kwargs):
if _flash_attn_with_kvcache_autotuned is None or not _flash_attn_should_autotune(q, kwargs):
return _flash_attn_with_kvcache(q=q, k_cache=k_cache, v_cache=v_cache, **kwargs)

call_kwargs = {"q": q, "k_cache": k_cache, "v_cache": v_cache, **kwargs}
call_kwargs.setdefault("causal", False)
call_kwargs.setdefault("window_size", (-1, -1))
call_kwargs.setdefault("softcap", 0.0)
call_kwargs.setdefault("sinks", None)
call_kwargs.setdefault("num_splits", _DEFAULT_NUM_SPLITS)
return _flash_attn_with_kvcache_autotune_call(call_kwargs)


flash_attn_with_kvcache = _flash_attn_with_kvcache
flash_attn_with_kvcache_autotune = (
None if _flash_attn_with_kvcache is None else _flash_attn_with_kvcache_autotune_wrapper
)
Loading