Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
7c1f2d1
add neo_chat and neo_chat_moe
shihaobai Mar 27, 2026
0272ec3
neo rope
shihaobai Mar 27, 2026
10f760d
support x2i.
Mar 23, 2026
142e3df
add naive x2i backend.
Mar 25, 2026
a5a0844
interleave api.
Mar 27, 2026
d4aaa77
nit
Mar 27, 2026
99deab6
fix.
Mar 27, 2026
f4fccf2
add lightx2v & openrouter api
shihaobai Apr 3, 2026
06571d0
fix x2v && add openai image-only
shihaobai Apr 7, 2026
0619f9d
add naive option.
Apr 8, 2026
e958f11
x2v support
shihaobai Apr 8, 2026
a1870ba
x2v support
shihaobai Apr 8, 2026
539576b
fix x2v acc, because of seed
shihaobai Apr 8, 2026
29dda22
fix chat template
shihaobai Apr 8, 2026
e4effb8
smart resize
shihaobai Apr 8, 2026
a71817a
nit.
Apr 9, 2026
044158c
fix device.
Apr 9, 2026
067a679
support distributed lightx2v.
Apr 10, 2026
6cf6a10
change task distribute from pub/sub to push/pull
shihaobai Apr 13, 2026
29850a6
use poller.
shihaobai Apr 13, 2026
40a16b0
add num_images for x2v
shihaobai Apr 13, 2026
af5c710
fixup.
Apr 13, 2026
2e7c525
enable thinking & input_image num
shihaobai Apr 13, 2026
c9d19d0
keep the same resolution for it2i
shihaobai Apr 13, 2026
d0967b7
fixup.
Apr 14, 2026
f6ae0d1
workaround for illegal memory access.
Apr 15, 2026
20439fe
fix attention
WANDY666 Apr 20, 2026
35906f5
pass unit test
WANDY666 Apr 20, 2026
1c8e721
add fa3_neo
WANDY666 Apr 20, 2026
7838512
import flash_attn_with_kvcache_neo
WANDY666 Apr 20, 2026
b202de0
fix
WANDY666 Apr 20, 2026
b319915
reduce useless nblock
WANDY666 Apr 21, 2026
468ea98
delete max_image_q_idx and reduce_or in kernel
WANDY666 Apr 21, 2026
77dc66f
use triton when import fa3 failed
WANDY666 Apr 21, 2026
4959f4e
verify fa3 image_token_tag
WANDY666 Apr 21, 2026
77f224a
add t2i/it2i thinking
Charles2530 Apr 22, 2026
3e84165
fix build prompt for tool call
Charles2530 Apr 22, 2026
b2f4565
dynamic_resolution, height/width, seed, image_size
Charles2530 Apr 22, 2026
fa5e1dc
Merge branch 'neo_plus_clean' of https://github.com/ModelTC/LightLLM …
Charles2530 Apr 22, 2026
a4c360d
Merge remote-tracking branch 'origin/main' into neo_plus_clean
Apr 23, 2026
c99e4b3
fix lint.
Apr 23, 2026
2f9489c
add seed rng session
shihaobai May 6, 2026
0613ace
merge main for openai api
shihaobai May 6, 2026
66b420c
Fix Neo image-token attention scope in Triton prefill
WANDY666 May 8, 2026
c09bd8f
delete position ids
WANDY666 May 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lightllm/common/basemodel/attention/base_att.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class AttControl:
nsa_prefill_dict: Dict = None
nsa_decode: bool = False
nsa_decode_dict: Dict = None
image_token_tag: Optional[torch.Tensor] = None


@dataclass
Expand Down
45 changes: 45 additions & 0 deletions lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@
from typing import Optional, TYPE_CHECKING
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.sgl_utils import flash_attn_with_kvcache

try:
    from flash_attn_interface import flash_attn_with_kvcache as flash_attn_with_kvcache_neo
    import inspect

    # Verify this is the neo-patched FA3 build (with image_token_tag support):
    # a stock flash_attn_interface imports fine but lacks the extra argument,
    # so we inspect the signature instead of trusting the import alone.
    _sig = inspect.signature(flash_attn_with_kvcache_neo)
    if "image_token_tag" not in _sig.parameters:
        # Deliberately raised as ImportError so the except clause below turns a
        # non-neo build into the same "not available" fallback as a missing package.
        raise ImportError("flash_attn_interface found but missing image_token_tag support (need neo build)")

    HAS_FLASH_ATTN_INTERFACE = True
except ImportError:
    # Fallback state: callers must check HAS_FLASH_ATTN_INTERFACE before using
    # flash_attn_with_kvcache_neo (it is None here).
    flash_attn_with_kvcache_neo = None
    HAS_FLASH_ATTN_INTERFACE = False
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
Expand Down Expand Up @@ -92,6 +106,37 @@ def _nomarl_prefill_att(
k_descale, v_descale = None, None # disable quantization
Lq = q.shape[-1]
sm_scale = 1.0 / (Lq ** 0.5)

# neo_chat*: image-token bidirectional attention requires flash_attn_interface
# (sgl_kernel's flash_attn_with_kvcache does not support image_token_tag).
if att_control.image_token_tag is not None:
if not HAS_FLASH_ATTN_INTERFACE:
raise ImportError(
"flash_attn_interface (fa3-neo) is required for image_token_tag bidirectional "
"attention. Install it or set LIGHTLLM_NEO_PREFILL_TRITON_BACKEND=1 to use the "
"triton fallback."
)
extra_kwargs = {"image_token_tag": att_control.image_token_tag}
o = flash_attn_with_kvcache_neo(
q=q,
k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]),
page_table=self.page_table,
cache_seqlens=self.infer_state.b_seq_len,
cu_seqlens_q=self.cu_seqlens_q,
cu_seqlens_k_new=self.cu_seqlens_k,
max_seqlen_q=self.infer_state.max_q_seq_len,
softmax_scale=sm_scale,
causal=True,
window_size=window_size,
softcap=0.0,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=False,
**extra_kwargs,
)
return o

o = flash_attn_with_kvcache(
q=q,
k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
Expand Down
311 changes: 311 additions & 0 deletions lightllm/common/basemodel/triton_kernel/kv_cache_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,314 @@ def load_cpu_kv_to_gpu(
num_stages=1,
)
return


@triton.jit
def _offload_gpu_kv_to_cpu_for_x2i(
    token_indexes_ptr,
    gpu_kv_cache_ptr,
    gpu_stride0,
    gpu_stride1,
    gpu_stride2,
    gpu_kv_cache_scale_ptr,
    gpu_scale_stride0,
    gpu_scale_stride1,
    gpu_scale_stride2,
    cpu_kv_cache_ptr,
    cpu_stride0,
    cpu_stride1,
    cpu_stride2,
    cpu_stride3,
    cpu_kv_cache_scale_ptr,
    cpu_scale_stride0,
    cpu_scale_stride1,
    cpu_scale_stride2,
    cpu_scale_stride3,
    page_indexes_ptr,
    layer_num,
    head_dim,
    scale_head_dim,
    block_num,
    token_num,
    cpu_k_start_head_index: tl.constexpr,
    cpu_k_head_num: tl.constexpr,
    gpu_k_start_head_index: tl.constexpr,
    gpu_k_head_num: tl.constexpr,
    cpu_v_start_head_index: tl.constexpr,
    cpu_v_head_num: tl.constexpr,
    gpu_v_start_head_index: tl.constexpr,
    gpu_v_head_num: tl.constexpr,
    BLOCK_HEAD_DIM: tl.constexpr,
    TOKEN_BLOCK: tl.constexpr,
    HAS_SCALE: tl.constexpr,
):
    """Copy K/V values (and optional quantization scales) of selected tokens
    from the GPU KV cache into pages of the pinned CPU KV cache.

    Layouts (see the Python wrapper in this file):
      - GPU cache: (layer_num, token_num, head_num, head_dim); rows are gathered
        via the indices in token_indexes_ptr.
      - CPU cache: (page_num, layer_num, TOKEN_BLOCK, head_num, head_dim); rows
        are written at sequential token offsets inside the page picked from
        page_indexes_ptr.
    K heads and V heads are copied as two independent contiguous head ranges so
    a tensor-parallel rank can offload only the head slice it is responsible for
    (a range may be empty, e.g. gpu_v_head_num == 0 disables the V copy).
    """
    block_start_index = tl.program_id(0)
    block_split_size = tl.num_programs(axis=0)

    # Grid-stride loop over CPU pages: program p handles pages p, p + num_programs,
    # ... so any launch grid size covers all block_num pages.
    for block_index in tl.range(block_start_index, block_num, block_split_size):
        cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64)
        # Page `block_index` consumes a contiguous TOKEN_BLOCK slice of
        # token_indexes; the tail of the last page is masked out.
        token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
        token_range_mask = token_range < token_num
        token_indexes = tl.load(token_indexes_ptr + token_range, mask=token_range_mask).to(tl.int64)
        head_dim_range = tl.arange(0, BLOCK_HEAD_DIM)
        head_dim_mask = head_dim_range < head_dim
        scale_head_dim_mask = head_dim_range < scale_head_dim

        # (TOKEN_BLOCK, BLOCK_HEAD_DIM) masks for the value and scale tiles.
        token_head_mask = token_range_mask[:, None] & head_dim_mask[None, :]
        token_scale_mask = token_range_mask[:, None] & scale_head_dim_mask[None, :]
        for layer_index in range(layer_num):
            # K heads: gpu head (gpu_k_start_head_index + i) goes to
            # cpu head (cpu_k_start_head_index + i).
            for k_head_index in range(gpu_k_head_num):
                gpu_k_head_index = k_head_index + gpu_k_start_head_index
                cpu_k_head_index = k_head_index + cpu_k_start_head_index

                gpu_ptr = (
                    gpu_kv_cache_ptr
                    + layer_index.to(tl.int64) * gpu_stride0
                    + token_indexes[:, None] * gpu_stride1
                    + gpu_k_head_index.to(tl.int64) * gpu_stride2
                    + head_dim_range[None, :]
                )
                gpu_data = tl.load(gpu_ptr, mask=token_head_mask, other=0.0)
                cpu_ptr = (
                    cpu_kv_cache_ptr
                    + cpu_page_index * cpu_stride0
                    + layer_index.to(tl.int64) * cpu_stride1
                    + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
                    + cpu_k_head_index * cpu_stride3
                    + head_dim_range[None, :]
                )
                # ".wt" = write-through cache modifier, presumably so the host
                # observes the data promptly — confirm against consumer code.
                tl.store(cpu_ptr, gpu_data, mask=token_head_mask, cache_modifier=".wt")

                if HAS_SCALE:
                    # Mirror the copy for the K quantization scales.
                    gpu_scale_ptr = (
                        gpu_kv_cache_scale_ptr
                        + layer_index.to(tl.int64) * gpu_scale_stride0
                        + token_indexes[:, None] * gpu_scale_stride1
                        + gpu_k_head_index.to(tl.int64) * gpu_scale_stride2
                        + head_dim_range[None, :]
                    )
                    gpu_scale_data = tl.load(gpu_scale_ptr, mask=token_scale_mask, other=0.0)
                    cpu_scale_ptr = (
                        cpu_kv_cache_scale_ptr
                        + cpu_page_index * cpu_scale_stride0
                        + layer_index.to(tl.int64) * cpu_scale_stride1
                        + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_scale_stride2
                        + cpu_k_head_index * cpu_scale_stride3
                        + head_dim_range[None, :]
                    )
                    tl.store(
                        cpu_scale_ptr,
                        gpu_scale_data,
                        mask=token_scale_mask,
                        cache_modifier=".wt",
                    )

            # V heads: same structure as the K copy, over the V head range.
            for v_head_index in range(gpu_v_head_num):
                gpu_v_head_index = v_head_index + gpu_v_start_head_index
                cpu_v_head_index = v_head_index + cpu_v_start_head_index

                gpu_ptr = (
                    gpu_kv_cache_ptr
                    + layer_index.to(tl.int64) * gpu_stride0
                    + token_indexes[:, None] * gpu_stride1
                    + gpu_v_head_index.to(tl.int64) * gpu_stride2
                    + head_dim_range[None, :]
                )
                gpu_data = tl.load(gpu_ptr, mask=token_head_mask, other=0.0)
                cpu_ptr = (
                    cpu_kv_cache_ptr
                    + cpu_page_index * cpu_stride0
                    + layer_index.to(tl.int64) * cpu_stride1
                    + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
                    + cpu_v_head_index * cpu_stride3
                    + head_dim_range[None, :]
                )
                tl.store(cpu_ptr, gpu_data, mask=token_head_mask, cache_modifier=".wt")

                if HAS_SCALE:
                    gpu_scale_ptr = (
                        gpu_kv_cache_scale_ptr
                        + layer_index.to(tl.int64) * gpu_scale_stride0
                        + token_indexes[:, None] * gpu_scale_stride1
                        + gpu_v_head_index.to(tl.int64) * gpu_scale_stride2
                        + head_dim_range[None, :]
                    )
                    gpu_scale_data = tl.load(gpu_scale_ptr, mask=token_scale_mask, other=0.0)
                    cpu_scale_ptr = (
                        cpu_kv_cache_scale_ptr
                        + cpu_page_index * cpu_scale_stride0
                        + layer_index.to(tl.int64) * cpu_scale_stride1
                        + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_scale_stride2
                        + cpu_v_head_index * cpu_scale_stride3
                        + head_dim_range[None, :]
                    )
                    tl.store(
                        cpu_scale_ptr,
                        gpu_scale_data,
                        mask=token_scale_mask,
                        cache_modifier=".wt",
                    )


@torch.no_grad()
def offload_gpu_kv_to_cpu_for_x2i(
    token_indexes: torch.Tensor,
    gpu_kv_cache: torch.Tensor,
    gpu_kv_cache_scale: Optional[torch.Tensor],
    cpu_kv_cache: torch.Tensor,
    cpu_kv_cache_scale: Optional[torch.Tensor],
    page_indexes: torch.Tensor,
    tp_index: int,
    tp_world_size: int,
    grid_num: int,
    _cache_data={},
):
    """Offload the KV cache of the given token slots from GPU to paged CPU memory.

    Only ranks whose head slice is the first replica launch the copy kernel; the
    remaining tensor-parallel ranks return without doing anything (presumably
    because their gpu heads duplicate data another rank already offloads —
    confirm against the cpu cache layout).

    Args:
        token_indexes: (token_num,) GPU cache row indices to offload.
        gpu_kv_cache: (layer_num, token_num, head_num, head_dim).
        gpu_kv_cache_scale: optional quantization scales matching gpu_kv_cache.
        cpu_kv_cache: (all_page_num, layer_num, token_block_size, head_num, head_dim).
        cpu_kv_cache_scale: optional quantization scales matching cpu_kv_cache.
        page_indexes: (page_num,) destination CPU page ids.
        tp_index: this rank's tensor-parallel index.
        tp_world_size: tensor-parallel world size.
        grid_num: number of triton programs to launch (grid-stride over pages).
        _cache_data: deliberate shared mutable default memoizing the per-layout
            offload plan across calls — do not pass explicitly.
    """
    page_block_size = cpu_kv_cache.shape[2]
    token_num = token_indexes.shape[0]
    # Every token must have a destination slot in the provided pages.
    assert token_num <= page_indexes.shape[0] * page_block_size

    gpu_heads = gpu_kv_cache.shape[2]
    head_dim = gpu_kv_cache.shape[3]
    cpu_heads = cpu_kv_cache.shape[3]

    assert head_dim == cpu_kv_cache.shape[4]
    assert gpu_kv_cache.shape[0] == cpu_kv_cache.shape[1]  # same layer count

    # Number of tp ranks that hold a replica of each cpu head.
    replica_num = (tp_world_size * gpu_heads) // cpu_heads

    cache_key = (gpu_heads, cpu_heads, tp_index, tp_world_size)
    plan = _cache_data.get(cache_key)
    if plan is None:
        if cpu_heads > 1:
            assert (tp_world_size * gpu_heads) % cpu_heads == 0
            assert cpu_heads % 2 == 0
            # Arrange cpu heads as (2 = k/v, tp_world_size, heads_per_rank) so each
            # rank reads off the contiguous cpu head range it owns.
            layout = (
                torch.arange(0, cpu_heads, device="cpu", dtype=torch.int32)
                .view(cpu_heads, 1)
                .tile((1, replica_num))
                .view(2, tp_world_size, -1)
            )
            cpu_kv_index = torch.cat([layout[0][tp_index], layout[1][tp_index]], dim=0).view(2, -1).numpy()
            gpu_kv_index = torch.arange(0, gpu_heads, device="cpu", dtype=torch.int32).view(2, -1)

            # Only the rank holding the first replica of its head group offloads.
            need_offload = tp_index % replica_num == 0

            head_plan = (
                int(cpu_kv_index[0, 0]),   # cpu_k_start_head_index
                len(cpu_kv_index[0]),      # cpu_k_head_num
                int(gpu_kv_index[0, 0]),   # gpu_k_start_head_index
                len(gpu_kv_index[0]),      # gpu_k_head_num
                int(cpu_kv_index[1, 0]),   # cpu_v_start_head_index
                len(cpu_kv_index[1]),      # cpu_v_head_num
                int(gpu_kv_index[1, 0]),   # gpu_v_start_head_index
                len(gpu_kv_index[1]),      # gpu_v_head_num
            )
            assert head_plan[1] == head_plan[3]  # k head counts must match
            assert head_plan[5] == head_plan[7]  # v head counts must match
        else:
            # Single fused head: only rank 0 offloads, and the whole cache is
            # copied through the K path (v head count is zero).
            assert gpu_heads == 1
            assert cpu_heads == 1
            need_offload = tp_index == 0
            head_plan = (0, 1, 0, 1, 0, 0, 0, 0)

        plan = (need_offload, head_plan)
        _cache_data[cache_key] = plan

    need_offload, head_plan = plan
    if not need_offload:
        return

    (
        cpu_k_start_head_index,
        cpu_k_head_num,
        gpu_k_start_head_index,
        gpu_k_head_num,
        cpu_v_start_head_index,
        cpu_v_head_num,
        gpu_v_start_head_index,
        gpu_v_head_num,
    ) = head_plan

    # TOKEN_BLOCK feeds tl.arange in the kernel, which needs a power of two.
    assert page_block_size == triton.next_power_of_2(page_block_size)

    has_scale = gpu_kv_cache_scale is not None and cpu_kv_cache_scale is not None
    if has_scale:
        scale_head_dim = gpu_kv_cache_scale.shape[-1]
        gpu_scale_strides = gpu_kv_cache_scale.stride()
        cpu_scale_strides = cpu_kv_cache_scale.stride()
    else:
        # Dummy strides; the kernel never dereferences the scale pointers
        # when HAS_SCALE is False.
        scale_head_dim = 0
        gpu_scale_strides = (0, 0, 0, 0, 0)
        cpu_scale_strides = (0, 0, 0, 0, 0)

    _offload_gpu_kv_to_cpu_for_x2i[(grid_num,)](
        token_indexes_ptr=token_indexes,
        gpu_kv_cache_ptr=gpu_kv_cache,
        gpu_stride0=gpu_kv_cache.stride(0),
        gpu_stride1=gpu_kv_cache.stride(1),
        gpu_stride2=gpu_kv_cache.stride(2),
        gpu_kv_cache_scale_ptr=gpu_kv_cache_scale,
        gpu_scale_stride0=gpu_scale_strides[0],
        gpu_scale_stride1=gpu_scale_strides[1],
        gpu_scale_stride2=gpu_scale_strides[2],
        cpu_kv_cache_ptr=cpu_kv_cache,
        cpu_stride0=cpu_kv_cache.stride(0),
        cpu_stride1=cpu_kv_cache.stride(1),
        cpu_stride2=cpu_kv_cache.stride(2),
        cpu_stride3=cpu_kv_cache.stride(3),
        cpu_kv_cache_scale_ptr=cpu_kv_cache_scale,
        cpu_scale_stride0=cpu_scale_strides[0],
        cpu_scale_stride1=cpu_scale_strides[1],
        cpu_scale_stride2=cpu_scale_strides[2],
        cpu_scale_stride3=cpu_scale_strides[3],
        page_indexes_ptr=page_indexes,
        layer_num=gpu_kv_cache.shape[0],
        head_dim=head_dim,
        scale_head_dim=scale_head_dim,
        block_num=page_indexes.shape[0],
        token_num=token_num,
        cpu_k_start_head_index=cpu_k_start_head_index,
        cpu_k_head_num=cpu_k_head_num,
        gpu_k_start_head_index=gpu_k_start_head_index,
        gpu_k_head_num=gpu_k_head_num,
        cpu_v_start_head_index=cpu_v_start_head_index,
        cpu_v_head_num=cpu_v_head_num,
        gpu_v_start_head_index=gpu_v_start_head_index,
        gpu_v_head_num=gpu_v_head_num,
        BLOCK_HEAD_DIM=triton.next_power_of_2(head_dim),
        TOKEN_BLOCK=page_block_size,
        HAS_SCALE=has_scale,
        num_warps=4,
        num_stages=1,
    )
2 changes: 2 additions & 0 deletions lightllm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
)
from lightllm.models.gpt_oss.model import GptOssTpPartModel
from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel
from lightllm.models.neo_chat_moe.model import NeoTpMOEPartModel
from lightllm.models.neo_chat.model import NeoTpPartModel
from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel
from lightllm.models.qwen3_5_moe.model import Qwen3_5MOETpPartModel
from .registry import get_model, get_model_class
Loading