Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
7c7bd61
one pass
WANDY666 Jun 3, 2026
d790ad2
Optimization
WANDY666 Jun 5, 2026
a161244
add prompt cache
WANDY666 Jun 5, 2026
61eed87
support cudagraph
WANDY666 Jun 5, 2026
19866d0
refact tokenizer
WANDY666 Jun 8, 2026
29c6082
add statement
WANDY666 Jun 8, 2026
ffafdbf
format
WANDY666 Jun 8, 2026
e8009cb
pass gsm8k but need review
WANDY666 Jun 11, 2026
b3b8123
fix
WANDY666 Jun 11, 2026
6002866
fix rope
WANDY666 Jun 11, 2026
6bc34ad
dsv4: enable decode cudagraph; fix warmup-baked FlashMLASchedMeta
WANDY666 Jun 11, 2026
e78e0d4
dsv4: enable prefill cudagraph; zero pad-row attention output
Jun 11, 2026
c09dc6a
fix profile
WANDY666 Jun 11, 2026
c07e38c
support fp8
WANDY666 Jun 12, 2026
ff71706
optimize
WANDY666 Jun 12, 2026
d7dd6e0
fix
WANDY666 Jun 12, 2026
3a5dcdc
compress infer
WANDY666 Jun 14, 2026
d76450f
add c128 to mem_manager
WANDY666 Jun 14, 2026
07d2308
refact
WANDY666 Jun 15, 2026
d4dcd8a
opt
WANDY666 Jun 15, 2026
62c16d5
opt
WANDY666 Jun 15, 2026
69824d0
delete launch.sh
WANDY666 Jun 15, 2026
df70ecb
fix
WANDY666 Jun 15, 2026
1ad981d
restore
WANDY666 Jun 16, 2026
7b17bb5
support parser
WANDY666 Jun 16, 2026
6837abd
fix
WANDY666 Jun 16, 2026
e1376fe
Merge branch 'main' of https://github.com/ModelTC/LightLLM into suppo…
WANDY666 Jun 16, 2026
02a24ce
add c4 paged indexes
WANDY666 Jun 18, 2026
52a1528
fix chunk_size and page_size
WANDY666 Jun 18, 2026
0dbc90b
add sglang third_party
WANDY666 Jun 18, 2026
e8c49d1
fix tpsp
WANDY666 Jun 18, 2026
88309b5
fix profile
WANDY666 Jun 21, 2026
cf433fb
fix swa insufficient
WANDY666 Jun 22, 2026
40f5810
fix
WANDY666 Jun 22, 2026
f527ca2
rename
WANDY666 Jun 22, 2026
255e90d
tune config
WANDY666 Jun 22, 2026
d88dc71
prepare opt
WANDY666 Jun 22, 2026
a56c79b
delete
WANDY666 Jun 22, 2026
e286943
item1: wire fused_q_indexer_rope_hadamard_quant (rope+hadamard+fp8qua…
WANDY666 Jun 23, 2026
58b145b
item3: lazy-cache layer-independent c4 paged metadata (page_table/ctx…
WANDY666 Jun 23, 2026
a0379bb
gate-bf16 (flag) + drop redundant attn_sink fp32 copy + lazy gen_nsa_…
WANDY666 Jun 23, 2026
b796d48
cache prefill FlashMLA sched-meta per compress-ratio (was rebuilt eve…
WANDY666 Jun 23, 2026
e07b85e
2-stream
WANDY666 Jun 23, 2026
51f5c84
fix parser
WANDY666 Jun 24, 2026
2f12a07
fix multi-invoke
WANDY666 Jun 24, 2026
da3fec2
speed up prepare
WANDY666 Jun 24, 2026
82cb6d6
fix arguments
WANDY666 Jun 24, 2026
bc22591
tune H100
WANDY666 Jun 24, 2026
77baa05
add encoding_dsv4
WANDY666 Jun 25, 2026
2225e1a
fix c4 error
WANDY666 Jun 26, 2026
112247b
fuse wq_a+wkv & indexer wkv+wgate GEMMs; fp8 wo_a at tp8 (1 group/rank)
WANDY666 Jun 26, 2026
0d52fde
reduce alloc fragment
WANDY666 Jun 28, 2026
c711fe9
set _C4_PREFILL_LOGITS_BUDGET_BYTES reduce max memory usage
WANDY666 Jun 28, 2026
fced38b
fix(deepseek-v4): align fp8 serving numerics with reference
WANDY666 Jul 1, 2026
17a0d54
fix stride bug (if-inverse 73)
WANDY666 Jul 2, 2026
c19b853
default topk from huggingface's 50 to -1, if_inverse 70.6 -> 73
WANDY666 Jul 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
291 changes: 289 additions & 2 deletions lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py

Large diffs are not rendered by default.

62 changes: 62 additions & 0 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,15 @@ def _init_custom(self):

@torch.no_grad()
def forward(self, model_input: ModelInput):
# decode 槽位 prep: 放在 to_cuda 前, 优先使用 b_req_idx/b_seq_len 的 CPU mirror,
# 且此刻已在 forward 的 CUDA stream 上 -> 与后续 attention 同流, 无跨流竞态、无 D2H。
# mem_indexes_cpu is None 时跳过: cudagraph warmup 的输入全在 CUDA 且 b_req_idx 全为 HOLD, prep 本就是 no-op。
if not model_input.is_prefill and model_input.mem_indexes_cpu is not None:
self.req_manager.prepare_decode(
model_input.b_req_idx_cpu,
model_input.b_seq_len_cpu,
model_input.mem_indexes_cpu,
)
model_input.to_cuda()
assert model_input.mem_indexes.is_cuda

Expand Down Expand Up @@ -371,6 +380,15 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0
)
new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2)
new_model_input.b_req_idx_cpu = F.pad(
new_model_input.b_req_idx_cpu,
(0, padded_batch_size),
mode="constant",
value=self.req_manager.HOLD_REQUEST_ID,
)
new_model_input.b_seq_len_cpu = F.pad(
new_model_input.b_seq_len_cpu, (0, padded_batch_size), mode="constant", value=2
)
new_model_input.mem_indexes = F.pad(
new_model_input.mem_indexes,
(0, padded_batch_size),
Expand Down Expand Up @@ -428,6 +446,15 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle
new_model_input.b_mtp_index = F.pad(new_model_input.b_mtp_index, (0, 1), mode="constant", value=0)
new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, 1), mode="constant", value=padded_token_num)
new_model_input.b_ready_cache_len = F.pad(new_model_input.b_ready_cache_len, (0, 1), mode="constant", value=0)
new_model_input.b_req_idx_cpu = F.pad(
new_model_input.b_req_idx_cpu, (0, 1), mode="constant", value=self.req_manager.HOLD_REQUEST_ID
)
new_model_input.b_seq_len_cpu = F.pad(
new_model_input.b_seq_len_cpu, (0, 1), mode="constant", value=padded_token_num
)
new_model_input.b_ready_cache_len_cpu = F.pad(
new_model_input.b_ready_cache_len_cpu, (0, 1), mode="constant", value=0
)
b_q_seq_len = new_model_input.b_seq_len - new_model_input.b_ready_cache_len
new_model_input.b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
# 构建新的list, 使用 append 可能会让外面使用的数组引用发生变化,导致错误。
Expand Down Expand Up @@ -521,6 +548,15 @@ def _prefill(
alloc_mem_index=infer_state.mem_index,
max_q_seq_len=infer_state.max_q_seq_len,
)
if model_input.b_req_idx_cpu is not None:
self.req_manager.prepare_prefill(
b_req_idx=infer_state.b_req_idx,
b_ready_cache_len=infer_state.b_ready_cache_len,
b_seq_len=infer_state.b_seq_len,
b_req_idx_cpu=model_input.b_req_idx_cpu,
b_ready_cache_len_cpu=model_input.b_ready_cache_len_cpu,
b_seq_len_cpu=model_input.b_seq_len_cpu,
)
prefill_mem_indexes_ready_event = torch.cuda.Event()
prefill_mem_indexes_ready_event.record()

Expand Down Expand Up @@ -741,6 +777,15 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
alloc_mem_index=infer_state0.mem_index,
max_q_seq_len=infer_state0.max_q_seq_len,
)
if model_input0.b_req_idx_cpu is not None:
self.req_manager.prepare_prefill(
b_req_idx=infer_state0.b_req_idx,
b_ready_cache_len=infer_state0.b_ready_cache_len,
b_seq_len=infer_state0.b_seq_len,
b_req_idx_cpu=model_input0.b_req_idx_cpu,
b_ready_cache_len_cpu=model_input0.b_ready_cache_len_cpu,
b_seq_len_cpu=model_input0.b_seq_len_cpu,
)
infer_state0.init_some_extra_state(self)
infer_state0.init_att_state()

Expand All @@ -754,6 +799,15 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
alloc_mem_index=infer_state1.mem_index,
max_q_seq_len=infer_state1.max_q_seq_len,
)
if model_input1.b_req_idx_cpu is not None:
self.req_manager.prepare_prefill(
b_req_idx=infer_state1.b_req_idx,
b_ready_cache_len=infer_state1.b_ready_cache_len,
b_seq_len=infer_state1.b_seq_len,
b_req_idx_cpu=model_input1.b_req_idx_cpu,
b_ready_cache_len_cpu=model_input1.b_ready_cache_len_cpu,
b_seq_len_cpu=model_input1.b_seq_len_cpu,
)
infer_state1.init_some_extra_state(self)
infer_state1.init_att_state()

Expand Down Expand Up @@ -781,6 +835,14 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod

@torch.no_grad()
def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
# decode 槽位 prep: 在 to_cuda 前使用 CPU mirror, 且已在 forward 的 CUDA stream 上 (见 forward 注释)。
for mi in (model_input0, model_input1):
if mi.mem_indexes_cpu is not None:
self.req_manager.prepare_decode(
mi.b_req_idx_cpu,
mi.b_seq_len_cpu,
mi.mem_indexes_cpu,
)
model_input0.to_cuda()
model_input1.to_cuda()
assert self.args.enable_tpsp_mix_mode
Expand Down
16 changes: 16 additions & 0 deletions lightllm/common/basemodel/batch_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class ModelInput:
multimodal_params: list = None
# cpu 变量
mem_indexes_cpu: torch.Tensor = None
b_req_idx_cpu: torch.Tensor = None
b_seq_len_cpu: torch.Tensor = None
b_ready_cache_len_cpu: torch.Tensor = None
# prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理
# 的一些变量
b_prefill_has_output_cpu: List[bool] = None # 标记进行prefill的请求是否具有输出
Expand All @@ -53,6 +56,18 @@ class ModelInput:
# 的 draft 模型的输入
mtp_draft_input_hiddens: Optional[torch.Tensor] = None

def _capture_cpu_mirror(self, tensor_name: str, mirror_name: str):
    """Expose a non-CUDA field under a second ("mirror") attribute name.

    Args:
        tensor_name: Name of the attribute on ``self`` to inspect.
        mirror_name: Name of the attribute that receives the mirror.

    The mirror is the very same object as the source attribute (no copy is
    made). When the source attribute is ``None`` or reports ``is_cuda`` as
    true, the mirror attribute is left exactly as it was.
    """
    source = getattr(self, tensor_name)
    if source is None or source.is_cuda:
        # Nothing host-side to remember; keep whatever mirror already exists.
        return
    setattr(self, mirror_name, source)
    return

def capture_cpu_mirrors(self):
    """Record ``*_cpu`` mirrors for the per-request index/length fields.

    Delegates to ``_capture_cpu_mirror`` once per (field, mirror) pair, in
    the order: ``b_req_idx``, ``b_seq_len``, ``b_ready_cache_len``.
    """
    mirror_pairs = (
        ("b_req_idx", "b_req_idx_cpu"),
        ("b_seq_len", "b_seq_len_cpu"),
        ("b_ready_cache_len", "b_ready_cache_len_cpu"),
    )
    for field_name, mirror_field in mirror_pairs:
        self._capture_cpu_mirror(field_name, mirror_field)
    return

def to_cuda(self):
if self.input_ids is not None:
self.input_ids = self.input_ids.cuda(non_blocking=True)
Expand Down Expand Up @@ -82,6 +97,7 @@ def to_cuda(self):
self.b_shared_seq_len = self.b_shared_seq_len.cuda(non_blocking=True)

def __post_init__(self):
    """Finish construction: capture CPU mirrors, then validate the inputs.

    NOTE(review): presumably invoked automatically by the dataclass
    machinery after ``__init__`` -- the decorator is not visible here.
    """
    # Mirrors are recorded before validation runs.
    self.capture_cpu_mirrors()
    self.check_input()

def check_input(self):
Expand Down
19 changes: 19 additions & 0 deletions lightllm/common/basemodel/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@
logger = init_logger(__name__)


def _reset_att_state_sched_meta(infer_state: InferStateInfo):
    """Reset lazily-built scheduling metadata on the decode attention states.

    Intended to be called right before CUDA-graph capture. Per the original
    author's note: the warmup pass shares ``decode_att_state`` through a
    ``copy.copy`` shallow copy, so scheduling objects that are initialized
    lazily inside it (e.g. ``FlashMLASchedMeta``, which plans on the data seen
    at the first kernel call and writes the plan back) end up locked to the
    dummy warmup workload. Without a reset, the capture pass would bind
    scheduling tensors planned for that dummy workload and every replay would
    use the wrong tile schedule (DSV4 measured gsm8k 0.96 -> 0.74). After the
    reset, planning happens inside the captured region and is recomputed on
    replay.

    Args:
        infer_state: Object exposing ``decode_att_state`` and
            ``decode_att_state1``; either may be ``None``.

    States that are ``None`` or that do not provide a
    ``reset_sched_meta_for_capture`` method are skipped.
    """
    candidates = (infer_state.decode_att_state, infer_state.decode_att_state1)
    for state in candidates:
        # The reset hook is optional: not every attention state defines it.
        hook = None if state is None else getattr(state, "reset_sched_meta_for_capture", None)
        if hook is not None:
            hook()
    return


class CudaGraph:
# CudaGraph forward pass for the decoding stage.

Expand Down Expand Up @@ -94,6 +108,8 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo):
if param_name not in pure_para_set:
delattr(infer_state, param_name)

_reset_att_state_sched_meta(infer_state)

with torch.cuda.graph(graph_obj, pool=self.mempool):
model_output = decode_func(infer_state)
self.graph[batch_size] = (graph_obj, infer_state, model_output)
Expand Down Expand Up @@ -128,6 +144,9 @@ def _capture_decode_overlap(
if para_name not in pure_para_set1:
delattr(infer_state1, para_name)

_reset_att_state_sched_meta(infer_state)
_reset_att_state_sched_meta(infer_state1)

with torch.cuda.graph(graph_obj, pool=self.mempool):
model_output, model_output1 = decode_func(infer_state, infer_state1)
self.graph[batch_size] = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,13 @@ def __init__(
auto_update_redundancy_expert=self.auto_update_redundancy_expert,
)
self.lock = threading.Lock()
self._moe_weight_finalized = False
self._create_weight()

def _init_config(self, network_config: Dict[str, Any]):
self.n_group = network_config.get("n_group", 0)
self.use_grouped_topk = self.n_group > 0
self.norm_topk_prob = network_config["norm_topk_prob"]
self.norm_topk_prob = network_config.get("norm_topk_prob", False)
self.topk_group = network_config.get("topk_group", 0)
self.num_experts_per_tok = network_config["num_experts_per_tok"]
self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0)
Expand Down Expand Up @@ -136,6 +137,7 @@ def experts(
is_prefill: Optional[bool] = None,
) -> torch.Tensor:
"""Backward compatible method that routes to platform-specific implementation."""
self._finalize_moe_weight()
return self.fuse_moe_impl(
input_tensor=input_tensor,
router_logits=router_logits,
Expand All @@ -152,6 +154,25 @@ def experts(
per_expert_scale=self.per_expert_scale,
)

def experts_with_preselected(
    self,
    input_tensor: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    is_prefill: Optional[bool] = None,
    clamp_limit: Optional[float] = None,
) -> torch.Tensor:
    """Run the fused experts with routing already chosen by the caller.

    Args:
        input_tensor: Activations forwarded to the fused-MoE implementation.
        topk_weights: Caller-selected routing weights.
        topk_ids: Caller-selected expert ids.
        is_prefill: Forwarded unchanged to the implementation.
        clamp_limit: Forwarded unchanged to the implementation.

    Returns:
        Whatever ``self.fuse_moe_impl.fused_experts_with_topk`` returns.
    """
    # Make sure any one-time weight post-processing has run before compute.
    self._finalize_moe_weight()
    call_kwargs = dict(
        input_tensor=input_tensor,
        w13=self.w13,
        w2=self.w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        is_prefill=is_prefill,
        clamp_limit=clamp_limit,
    )
    return self.fuse_moe_impl.fused_experts_with_topk(**call_kwargs)

def low_latency_dispatch(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -280,7 +301,18 @@ def verify_load(self):
e_score_correction_bias_load_ok = (
True if self.e_score_correction_bias is None else getattr(self.e_score_correction_bias, "load_ok", False)
)
return weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok
load_ok = weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok
if load_ok:
self._finalize_moe_weight()
return load_ok

def _finalize_moe_weight(self):
    """Run the quant method's optional ``finalize_moe_weight`` hook once.

    The hook, when the quant method defines one, is called with this object
    as its only argument. ``_moe_weight_finalized`` is set afterwards, so
    later calls return immediately; if the hook raises, the flag stays unset.
    """
    if not self._moe_weight_finalized:
        hook = getattr(self.quant_method, "finalize_moe_weight", None)
        if hook is not None:
            hook(self)
        # Mark as done even when the quant method has no hook at all.
        self._moe_weight_finalized = True

def _create_weight(self):
intermediate_size = self.split_inter_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
from .triton_impl import FuseMoeTriton
from .marlin_impl import FuseMoeMarlin
from .deepgemm_impl import FuseMoeDeepGEMM
from .mxfp4_impl import FuseMoeMXFP4


def select_fuse_moe_impl(quant_method: QuantizationMethod, enable_ep_moe: bool):
if quant_method.method_name == "marlin-mxfp4w4a16-b32":
if enable_ep_moe:
raise RuntimeError("marlin-mxfp4w4a16-b32 does not support enable_ep_moe yet")
return FuseMoeMXFP4

if enable_ep_moe:
return FuseMoeDeepGEMM

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = None,
clamp_limit: Optional[float] = None,
):
assert clamp_limit is None, "EP deepgemm fused MoE does not support clamp_limit yet"
output = fused_experts(
hidden_states=input_tensor,
w13=w13,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = None,
clamp_limit: Optional[float] = None,
):
assert clamp_limit is None, "awq_marlin fused MoE does not support clamp_limit yet"

w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point
w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
from typing import Optional

from lightllm.common.quantization.quantize_method import WeightPack
from .triton_impl import FuseMoeTriton


class FuseMoeMXFP4(FuseMoeTriton):
    """Fused-MoE implementation that hands expert compute to vLLM's ``fused_marlin_moe``."""

    def create_workspace(self):
        """Return ``None``; this implementation creates no workspace."""
        return None

    def _fused_experts(
        self,
        input_tensor: torch.Tensor,
        w13: WeightPack,
        w2: WeightPack,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        router_logits: Optional[torch.Tensor] = None,
        is_prefill: Optional[bool] = None,
        clamp_limit: Optional[float] = None,
    ):
        """Run the experts through vLLM's ``fused_marlin_moe`` kernel.

        Args:
            input_tensor: Activations; made contiguous before the call.
            w13: Weight pack supplying ``weight`` and ``weight_scale`` for ``w1``.
            w2: Weight pack supplying ``weight`` and ``weight_scale`` for ``w2``.
            topk_weights: Routing weights; cast to float32 and made contiguous.
            topk_ids: Expert ids; cast to int64 and made contiguous.
            router_logits: Accepted for signature compatibility; unused here.
            is_prefill: Accepted for signature compatibility; unused here.
            clamp_limit: Forwarded to the kernel as ``clamp_limit``.

        Returns:
            The value returned by ``fused_marlin_moe``.

        Raises:
            RuntimeError: If the vLLM kernel modules cannot be imported.
        """
        # vLLM is imported lazily so the module loads even when vLLM is absent.
        try:
            from vllm.model_executor.layers.fused_moe.activation import MoEActivation
            from vllm.model_executor.layers.fused_moe.experts.marlin_moe import fused_marlin_moe
            from vllm.scalar_type import scalar_types
        except Exception as e:
            raise RuntimeError(f"MXFP4 fused MoE requires vLLM fused kernels, error={repr(e)}") from e

        hidden = input_tensor.contiguous()
        routing_weights = topk_weights.to(torch.float32).contiguous()
        expert_ids = topk_ids.to(torch.long).contiguous()

        kernel_kwargs = dict(
            hidden_states=hidden,
            w1=w13.weight,
            w2=w2.weight,
            bias1=None,
            bias2=None,
            w1_scale=w13.weight_scale,
            w2_scale=w2.weight_scale,
            topk_weights=routing_weights,
            topk_ids=expert_ids,
            quant_type_id=scalar_types.float4_e2m1f.id,
            global_num_experts=self.n_routed_experts,
            activation=MoEActivation.SILU,
            clamp_limit=clamp_limit,
        )
        return fused_marlin_moe(**kernel_kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: bool = False,
clamp_limit: Optional[float] = None,
):
w13_weight, w13_scale = w13.weight, w13.weight_scale
w2_weight, w2_scale = w2.weight, w2.weight_scale
Expand All @@ -111,9 +112,30 @@ def _fused_experts(
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w13_scale,
w2_scale=w2_scale,
limit=clamp_limit,
)
return input_tensor

def fused_experts_with_topk(
    self,
    input_tensor: torch.Tensor,
    w13: WeightPack,
    w2: WeightPack,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    is_prefill: Optional[bool] = None,
    clamp_limit: Optional[float] = None,
):
    """Run the fused experts with caller-supplied top-k routing.

    Thin entry point over ``self._fused_experts``: every argument is
    forwarded by keyword. ``router_logits`` is not passed, so the callee's
    own default for it applies.

    Returns:
        Whatever ``self._fused_experts`` returns.
    """
    forwarded = {
        "input_tensor": input_tensor,
        "w13": w13,
        "w2": w2,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "is_prefill": is_prefill,
        "clamp_limit": clamp_limit,
    }
    return self._fused_experts(**forwarded)

def __call__(
self,
input_tensor: torch.Tensor,
Expand Down
Loading