Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
21c3eeb
support 31B
WANDY666 Apr 30, 2026
99b790c
fix
WANDY666 May 6, 2026
4c30c73
Merge branch 'main' of https://github.com/ModelTC/LightLLM into suppo…
WANDY666 May 6, 2026
15a5379
support moe
WANDY666 May 7, 2026
83f4983
support e4b (PLE and shared_kv)
WANDY666 May 9, 2026
d969a5f
support visual module
WANDY666 May 11, 2026
08f066d
optimize sliding window
WANDY666 May 12, 2026
7678de8
fix
WANDY666 May 12, 2026
63c658a
simplify
WANDY666 May 13, 2026
300e577
minor improvements
WANDY666 May 13, 2026
50822f0
fix
WANDY666 May 13, 2026
b4b13cc
fix attention cuda graph
WANDY666 May 13, 2026
f19074b
fused gelu gate up
WANDY666 May 14, 2026
5b61450
add out_dtype
WANDY666 May 14, 2026
c0ca212
minor improvements
WANDY666 May 14, 2026
9499a00
fix eos_token_ids
WANDY666 May 14, 2026
de7e220
for HF format
WANDY666 May 14, 2026
bfc59ff
Merge branch 'main' of https://github.com/ModelTC/LightLLM into suppo…
WANDY666 May 14, 2026
109d27c
fix window_size
WANDY666 May 14, 2026
2ea258e
fix window_size
WANDY666 May 14, 2026
b297af5
fix
WANDY666 May 14, 2026
7a81e85
add reasoning_parser for gemma4
WANDY666 May 15, 2026
d619534
[fix]ple support cudagraph
WANDY666 May 16, 2026
c2578c0
fix PLE illegal memory access
WANDY666 May 18, 2026
d744cbc
support sliding_window_right
WANDY666 May 18, 2026
05a0db8
fix notes
WANDY666 May 18, 2026
6f1bd2e
tune in H200
WANDY666 May 19, 2026
90643db
fix
hiworldwzj May 19, 2026
a2b74ab
fix
hiworldwzj May 19, 2026
e606e05
fix
hiworldwzj May 19, 2026
7354da2
fix
hiworldwzj May 20, 2026
afa0194
fix
hiworldwzj May 20, 2026
46ce6af
fix
hiworldwzj May 20, 2026
0188c10
fix
hiworldwzj May 20, 2026
393ec69
fix
hiworldwzj May 20, 2026
e96c2b7
fix
WANDY666 May 20, 2026
f806326
fix
hiworldwzj May 20, 2026
c5b2b81
fix
hiworldwzj May 20, 2026
3bd46d7
fix
hiworldwzj May 20, 2026
91051f0
Merge branch 'support_gemma4' of https://github.com/ModelTC/LightLLM …
WANDY666 May 20, 2026
fb75045
fix
WANDY666 May 20, 2026
7c664c3
fix
hiworldwzj May 20, 2026
0d35e8b
fix
hiworldwzj May 20, 2026
74a4b1f
fix
hiworldwzj May 20, 2026
d2df0a0
fix
hiworldwzj May 20, 2026
3491641
fix
hiworldwzj May 20, 2026
c8812f2
fix
hiworldwzj May 20, 2026
8f160b5
fix
hiworldwzj May 20, 2026
ee92fee
fix
hiworldwzj May 21, 2026
131a163
fix
hiworldwzj May 21, 2026
6d7729f
fix
hiworldwzj May 21, 2026
87da477
fix
WANDY666 May 21, 2026
819497c
Merge branch 'main' of https://github.com/ModelTC/LightLLM into suppo…
WANDY666 May 21, 2026
c57e062
format
WANDY666 May 21, 2026
6ebf9db
finish
WANDY666 May 21, 2026
e682f9b
fix
WANDY666 May 21, 2026
c28f085
Merge https://github.com/ModelTC/LightLLM into gemma4_mtp
WANDY666 May 22, 2026
6099413
format
WANDY666 May 22, 2026
ba47045
format
WANDY666 May 22, 2026
33d7ceb
format
WANDY666 May 22, 2026
bca6990
refactor: unify gemma4_mtp post forward with ModelOutput
zhangts20 Jul 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 102 additions & 59 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import torch.nn.functional as F
import triton
from typing import final, List
from typing import final, List, Optional
from tqdm import tqdm

from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
Expand Down Expand Up @@ -96,6 +96,7 @@ def __init__(self, kvargs):
"eagle_with_att",
"vanilla_no_att",
"eagle_no_att",
"eagle_frozen_kv",
]
self.prefill_graph: PrefillCudaGraph = None

Expand Down Expand Up @@ -494,6 +495,67 @@ def _create_unpad_prefill_model_output(

return new_model_output

def _gather_last_input_embs(
self,
input_embs: torch.Tensor,
infer_state: InferStateInfo,
) -> torch.Tensor:
"""Gather last input embs from post layer."""
last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
if infer_state.is_prefill and infer_state.need_dp_prefill_balance:
last_input_embs = infer_state._all_to_all_unbalance_get(data=last_input_embs)
return last_input_embs

def _get_mtp_main_output_hiddens(
self,
post_out: ModelOutput,
input_embs: torch.Tensor,
infer_state: InferStateInfo,
) -> Optional[torch.Tensor]:
"""Get mtp main output hiddens from post layer output."""
if not self.is_mtp_mode:
return None

mtp_hiddens = post_out.mtp_main_output_hiddens
if mtp_hiddens is None:
mtp_hiddens = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
if infer_state.is_prefill and infer_state.need_dp_prefill_balance:
mtp_hiddens = infer_state._all_to_all_unbalance_get(data=mtp_hiddens)

return mtp_hiddens.contiguous()

def _build_model_output(
self,
post_out: torch.Tensor | ModelOutput,
input_embs: torch.Tensor,
infer_state: InferStateInfo,
) -> ModelOutput:
"""Build model output from post layer output."""
if isinstance(post_out, torch.Tensor):
post_out = ModelOutput(logits=post_out.contiguous())
return ModelOutput(
logits=post_out.logits,
mtp_main_output_hiddens=self._get_mtp_main_output_hiddens(
post_out=post_out,
input_embs=input_embs,
infer_state=infer_state,
),
)

def _post_forward_to_model_output(
self,
input_embs: torch.Tensor,
infer_state: InferStateInfo,
) -> ModelOutput:
"""Run post layer forward and build model output."""
last_input_embs = self._gather_last_input_embs(input_embs, infer_state)
post_out = self.post_infer.token_forward(last_input_embs, infer_state, self.pre_post_weight)
return self._build_model_output(
post_out=post_out,
input_embs=input_embs,
infer_state=infer_state,
)

def _prefill(
self,
model_input: ModelInput,
Expand Down Expand Up @@ -653,20 +715,7 @@ def prefill_func(input_tensors, infer_state):
g_cache_manager.cache_env_out()

input_embs = output_tensors[0]

last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
if infer_state.need_dp_prefill_balance:
last_input_embs = infer_state._all_to_all_unbalance_get(data=last_input_embs)

predict_logits = self.post_infer.token_forward(last_input_embs, infer_state, self.pre_post_weight)
model_output = ModelOutput(logits=predict_logits)

# 特殊模型特殊模式的额外输出
if self.is_mtp_mode:
input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
if infer_state.need_dp_prefill_balance:
input_embs = infer_state._all_to_all_unbalance_get(data=input_embs)
model_output.mtp_main_output_hiddens = input_embs.contiguous()
model_output = self._post_forward_to_model_output(input_embs, infer_state)

# 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候
# 该调用没有实际意义
Expand All @@ -684,17 +733,7 @@ def _token_forward(self, infer_state: InferStateInfo):
layer = self.layers_infer[i]
input_embs: torch.Tensor = layer.token_forward(input_embs, infer_state, self.trans_layers_weight[i])

last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
predict_logits: torch.Tensor = self.post_infer.token_forward(
last_input_embs, infer_state=infer_state, layer_weight=self.pre_post_weight
)

model_output = ModelOutput(logits=predict_logits.contiguous())

# 特殊模型特殊模式的额外输出
if self.is_mtp_mode:
input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
model_output.mtp_main_output_hiddens = input_embs.contiguous()
model_output = self._post_forward_to_model_output(input_embs, infer_state)

# 在 cuda graph 模式下,输出需要转为 no ref tensor, 加强mem pool 的复用,降低显存的使用。
if infer_state.is_cuda_graph:
Expand Down Expand Up @@ -921,28 +960,24 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state
infer_state.call_overlap_hook()
infer_state1.call_overlap_hook()

last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
last_input_embs1 = self.post_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1)
if infer_state.need_dp_prefill_balance:
last_input_embs = infer_state._all_to_all_unbalance_get(data=last_input_embs)
last_input_embs1 = infer_state1._all_to_all_unbalance_get(data=last_input_embs1)
last_input_embs = self._gather_last_input_embs(input_embs, infer_state)
last_input_embs1 = self._gather_last_input_embs(input_embs1, infer_state1)

predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward(
post_out0, post_out1 = self.post_infer.overlap_tpsp_token_forward(
last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight
)
g_cache_manager.cache_env_out()

model_output = ModelOutput(logits=predict_logits.contiguous())
model_output1 = ModelOutput(logits=predict_logits1.contiguous())

if self.is_mtp_mode:
input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
input_embs1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1)
if infer_state.need_dp_prefill_balance:
input_embs = infer_state._all_to_all_unbalance_get(data=input_embs)
input_embs1 = infer_state1._all_to_all_unbalance_get(data=input_embs1)
model_output.mtp_main_output_hiddens = input_embs.contiguous()
model_output1.mtp_main_output_hiddens = input_embs1.contiguous()
model_output = self._build_model_output(
post_out=post_out0,
input_embs=input_embs,
infer_state=infer_state,
)
model_output1 = self._build_model_output(
post_out=post_out1,
input_embs=input_embs1,
infer_state=infer_state1,
)

return model_output, model_output1

Expand All @@ -963,21 +998,23 @@ def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1:
infer_state.call_overlap_hook()
infer_state1.call_overlap_hook()

last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
last_input_embs1 = self.post_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1)
last_input_embs = self._gather_last_input_embs(input_embs, infer_state)
last_input_embs1 = self._gather_last_input_embs(input_embs1, infer_state1)

predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward(
post_out0, post_out1 = self.post_infer.overlap_tpsp_token_forward(
last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight
)

model_output = ModelOutput(logits=predict_logits.contiguous())
model_output1 = ModelOutput(logits=predict_logits1.contiguous())

if self.is_mtp_mode:
input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
input_embs1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1)
model_output.mtp_main_output_hiddens = input_embs.contiguous()
model_output1.mtp_main_output_hiddens = input_embs1.contiguous()
model_output = self._build_model_output(
post_out=post_out0,
input_embs=input_embs,
infer_state=infer_state,
)
model_output1 = self._build_model_output(
post_out=post_out1,
input_embs=input_embs1,
infer_state=infer_state1,
)

if infer_state.is_cuda_graph:
model_output.to_no_ref_tensor()
Expand Down Expand Up @@ -1185,15 +1222,21 @@ def _init_padded_req(self):
def _gen_special_model_input(self, token_num: int):
special_model_input = {}

cls_name = str(self.__class__)
is_mtp_draft_model = (
"Deepseek3MTPModel" in str(self.__class__)
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
or "Glm4MoeLiteMTPModel" in str(self.__class__)
"Deepseek3MTPModel" in cls_name
or "Qwen3MOEMTPModel" in cls_name
or "MistralMTPModel" in cls_name
or "Glm4MoeLiteMTPModel" in cls_name
or "Gemma4MTPModel" in cls_name
)
if is_mtp_draft_model:
# Gemma-4's drafter consumes the recurrent hidden state in backbone
# width (the target's hidden size), not its own draft width; the other
# MTP drafters have draft width == backbone width so hidden_size fits.
hidden_size = self.config.get("backbone_hidden_size", self.config["hidden_size"])
special_model_input["mtp_draft_input_hiddens"] = torch.randn(
token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda"
token_num, hidden_size, dtype=self.data_type, device="cuda"
)
else:
special_model_input["mtp_draft_input_hiddens"] = None
Expand Down
52 changes: 49 additions & 3 deletions lightllm/models/gemma4/layer_infer/post_layer_infer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
import torch
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.common.basemodel.batch_objs import ModelOutput
from lightllm.distributed.communication_op import all_gather


class Gemma4PostLayerInfer(LlamaPostLayerInfer):
Expand All @@ -10,11 +13,54 @@ class Gemma4PostLayerInfer(LlamaPostLayerInfer):

def __init__(self, network_config):
super().__init__(network_config)
self.final_logit_softcapping = float(network_config.get("final_logit_softcapping"))
cap = network_config.get("final_logit_softcapping")
self.final_logit_softcapping = float(cap) if cap is not None else None

def token_forward(self, input_embdings, infer_state, layer_weight):
logits = super().token_forward(input_embdings, infer_state, layer_weight)
def _get_normed_last_hidden(self, input_embdings, infer_state, layer_weight):
last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
input_embdings = None
last_input = self._norm(last_input, infer_state, layer_weight)
return last_input, token_num

def _dense_logits_from_normed(self, normed, token_num, infer_state, layer_weight):
input_embdings_dtype = normed.dtype
lm_input = normed.permute(1, 0).view(-1, token_num)
logic_batch = layer_weight.lm_head_weight_(input=lm_input, alloc_func=self.alloc_tensor)
vocab_size = layer_weight.lm_head_weight_.vocab_size
if self.tp_world_size_ == 1:
gather_data = logic_batch
else:
gather_data = self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype)
split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64)
all_gather(
[gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)],
logic_batch,
group=infer_state.dist_group,
async_op=False,
)
logic_batch = None
logits = self.alloc_tensor((token_num, vocab_size), dtype=torch.float32)
logits[:, :] = gather_data.permute(1, 0)
gather_data = None
return logits

def _apply_logit_softcapping(self, logits: torch.Tensor) -> torch.Tensor:
if self.final_logit_softcapping is not None and self.final_logit_softcapping > 0:
cap = self.final_logit_softcapping
logits = torch.tanh(logits / cap) * cap
return logits

def _logits_from_normed(self, normed, token_num, infer_state, layer_weight):
return self._dense_logits_from_normed(normed, token_num, infer_state, layer_weight)

def _mtp_hiddens_from_normed(self, normed):
return normed

def token_forward(self, input_embdings, infer_state, layer_weight):
normed, token_num = self._get_normed_last_hidden(input_embdings, infer_state, layer_weight)
logits = self._logits_from_normed(normed, token_num, infer_state, layer_weight)
logits = self._apply_logit_softcapping(logits)
return ModelOutput(
logits=logits,
mtp_main_output_hiddens=self._mtp_hiddens_from_normed(normed),
)
Empty file.
Empty file.
61 changes: 61 additions & 0 deletions lightllm/models/gemma4_mtp/layer_infer/post_layer_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import torch
from lightllm.models.gemma4.layer_infer.post_layer_infer import Gemma4PostLayerInfer


class Gemma4MTPPostLayerInfer(Gemma4PostLayerInfer):
def __init__(self, network_config):
super().__init__(network_config)

self.use_ordered_embeddings_ = bool(network_config.get("use_ordered_embeddings"))
self._post_projection_weight_ = None
if self.use_ordered_embeddings_:
self.num_centroids_ = network_config["num_centroids"]
self.centroid_top_k_ = network_config["centroid_intermediate_top_k"]
self.vocab_size_ = network_config["vocab_size"]
assert (
self.vocab_size_ % self.num_centroids_ == 0
), f"vocab_size={self.vocab_size_} must be divisible by num_centroids={self.num_centroids_}"
self._vocab_per_centroid_ = self.vocab_size_ // self.num_centroids_
# token -> centroid mapping is derived lazily from the loaded
# token_ordering buffer (weights are not loaded yet at __init__).
self._centroid_of_token_ = None

def _centroid_logits(self, last_hidden, token_num, layer_weight):
"""Gather lm_head rows for the per-token top-K centroid blocks,
dot with the post-norm hidden, scatter into a [N, vocab] -inf tensor
at the original vocab positions. Mathematically equivalent to
dense logits + mask but avoids the [N, vocab] bool tensor and matches
the reference implementations exactly.
"""
centroid_scores = layer_weight.centroids_weight_.mm(last_hidden) # [N, num_centroids]
topk_centroids = torch.topk(centroid_scores, k=self.centroid_top_k_, dim=-1).indices # [N, K]
# token_ordering[i] = original vocab id at reordered position i;
# row c of the (C, vpc) view holds the vocab ids of centroid c.
token_ordering = layer_weight.token_ordering_.weight # [vocab] int64
clusters = token_ordering.view(self.num_centroids_, self._vocab_per_centroid_) # [C, vpc]
selected_vocab = clusters[topk_centroids] # [N, K, vpc] - original vocab ids
num_selected = self.centroid_top_k_ * self._vocab_per_centroid_
selected_vocab = selected_vocab.reshape(token_num, num_selected) # [N, num_selected]
# Gather lm_head rows for the selected vocab ids.
lm_head_w = layer_weight.lm_head_weight_.weight # [vocab, draft_hidden]
selected_embeddings = lm_head_w[selected_vocab] # [N, num_selected, H]
# Sparse logits: dot product per token vs its selected rows.
selected_logits = torch.einsum("nh,nsh->ns", last_hidden, selected_embeddings)
# Scatter to [N, vocab] with -inf elsewhere.
output = torch.full(
(token_num, self.vocab_size_),
float("-inf"),
dtype=selected_logits.dtype,
device=selected_logits.device,
)
output.scatter_(-1, selected_vocab, selected_logits)
return output

def _logits_from_normed(self, normed, token_num, infer_state, layer_weight):
if self.use_ordered_embeddings_:
return self._centroid_logits(normed, token_num, layer_weight)
return super()._logits_from_normed(normed, token_num, infer_state, layer_weight)

def _mtp_hiddens_from_normed(self, normed):
assert self._post_projection_weight_ is not None, "post_projection weight is not initialized"
return self._post_projection_weight_.mm(normed)
Loading
Loading