Changes from all commits
196 commits
b636310
add /flush_cache (#1108)
shihaobai Nov 14, 2025
60c379e
Aborted reqs (#1113)
shihaobai Nov 18, 2025
4095831
flush cache mulit node (#1116)
shihaobai Nov 19, 2025
ca9325f
[bugfix]: flush cache in single node (#1118)
shihaobai Nov 19, 2025
9948925
add pause and continue (#1120)
shihaobai Nov 19, 2025
4b32287
add launch_server and StartArgs (#1119)
sufubao Nov 21, 2025
27abcf5
Update weight (#1127)
kingder Dec 1, 2025
c210c82
release and resume (#1122)
shihaobai Dec 1, 2025
094df8c
use portpicker (#1142)
sufubao Dec 8, 2025
560be02
Rl weight (#1143)
shihaobai Dec 8, 2025
3d225d7
add_cli
sufubao Nov 25, 2025
499074a
add 30b moe configs
shihaobai Dec 8, 2025
f737585
update requirement
shihaobai Dec 9, 2025
8a67a47
add-neo-chat
Dec 26, 2025
fdc1369
add-neo-chat
Dec 30, 2025
e8e7416
add-neo-chat
Dec 31, 2025
ba44983
add-neo-chat
Dec 31, 2025
4d41a33
add-neo-chat
Dec 31, 2025
0e8845c
fix-neo-chat
Jan 1, 2026
b48cd49
fix-neo-chat-position-ids-h
Jan 5, 2026
7a904f3
add-neo-chat-dense
Jan 6, 2026
4b757dd
add-neo-chat-dense
Jan 6, 2026
e208733
support verl.
Jan 8, 2026
245357c
improve0108
Jan 8, 2026
6503ac8
add min/max pixels sampling parameters
Jan 8, 2026
07df460
fix fused_moe not installed use pip.
Jan 12, 2026
a6f00fb
add visual nccl port alloc
shihaobai Jan 15, 2026
9360197
fix0115
Jan 15, 2026
920a741
fix0115
Jan 15, 2026
3aa5e18
fp8 online quant for moe
shihaobai Jan 16, 2026
7cb890b
hotfix for fa3 of llama
shihaobai Jan 16, 2026
c242a75
fp8w8a8 triton config
shihaobai Jan 19, 2026
a0195aa
fp16 config
shihaobai Jan 19, 2026
7f0c437
release ipc tensor early.
Jan 21, 2026
5738d9e
bugfix: fix flattened_bucket update weights
yqyao Jan 21, 2026
e11bf58
bugfix: fix update_weights from tensor
yqyao Jan 22, 2026
f767609
merge main
shihaobai Jan 28, 2026
ce76f8a
fix start
shihaobai Jan 29, 2026
45259ec
add-merge-kv-mode
Jan 29, 2026
da3b53d
add-neo-chat0129
Jan 29, 2026
1e066d0
Merge branch 'add-neo-chat-rebase' into rl_verl
Jan 29, 2026
043e898
moe fused weight
shihaobai Jan 30, 2026
52085a4
Merge branch 'rl_verl_rebase' of https://github.com/ModelTC/lightllm …
shihaobai Jan 30, 2026
80cfcc4
fix neo
shihaobai Jan 30, 2026
6bbdb4f
fix launch
shihaobai Jan 30, 2026
e436ba5
fix launch
shihaobai Jan 30, 2026
aef65bc
fix tp slice for merged moe weight
shihaobai Jan 30, 2026
bc87692
fix fusemoe weight
shihaobai Jan 30, 2026
cf5bcbf
fa3 for neo
shihaobai Jan 30, 2026
a23288b
fix dead visual process
shihaobai Jan 30, 2026
f558540
auto visual dp
shihaobai Jan 30, 2026
12c6c6b
fix format
shihaobai Jan 30, 2026
fd91cad
fix decode scale
Feb 2, 2026
2681263
add new mode support text_ids+image_ids
Feb 2, 2026
fd17aa0
add new mode support text_ids+image_ids
Feb 2, 2026
e516bd9
add cuda empty cache
shihaobai Feb 2, 2026
81a0c12
add invalid token ids to sampling_param for rl training
shihaobai Feb 2, 2026
14132d5
add unitest for apply_invalid_tokens
shihaobai Feb 2, 2026
ed41960
add gc collect
shihaobai Feb 3, 2026
706ae2e
logit_bias
shihaobai Feb 3, 2026
f432f5a
logit_bias
shihaobai Feb 3, 2026
92bf83a
Merge branch 'main' into rl_verl_rebase
shihaobai Feb 3, 2026
8f8ed44
merge main
shihaobai Feb 4, 2026
cac2edf
neo moe inferece speedup
shihaobai Feb 6, 2026
02078ad
port random generate
shihaobai Feb 9, 2026
68954b0
feat: add MoE expert routing capture for R3 rollout replay
sufubao Feb 9, 2026
3569d53
fix
sufubao Feb 9, 2026
fe54253
add node-id for env_utils
shihaobai Feb 9, 2026
92470f7
Merge branch 'rl_verl_rebase' of https://github.com/ModelTC/lightllm …
shihaobai Feb 9, 2026
8eead2b
Revert "add node-id for env_utils"
sufubao Feb 9, 2026
27f9e87
Revert "port random generate"
sufubao Feb 9, 2026
6fa8f74
add assert none
shihaobai Feb 9, 2026
bf83078
set_unique_server_name
shihaobai Feb 10, 2026
3eab5a7
fix return_routed_experts
sufubao Feb 10, 2026
14cfc95
fix r3
sufubao Feb 10, 2026
e8ed8b5
add-neo++
Feb 12, 2026
77b73c2
feat: add Qwen3Next linear attention model support
sufubao Feb 19, 2026
a4ab210
refactor: simplify mamba buffer copy and integrate Triton kernels
sufubao Feb 20, 2026
1686d34
fix conv3d
sufubao Feb 21, 2026
dd9b611
[draft] qwen3.5 dense
sufubao Feb 26, 2026
6a3a17c
split dense and moe
sufubao Feb 26, 2026
e1cdfb4
feat: add mamba_cache_ratio for automatic memory allocation
sufubao Feb 26, 2026
f2e148e
refactor: simplify mamba_cache_ratio to direct percentage
sufubao Feb 26, 2026
b4fe201
add H100 config
sufubao Feb 26, 2026
e2ce9c0
refactor: align radix_cache_class with infer_state_class style
sufubao Feb 27, 2026
b1adbf3
fix: add missing attention_chunk param to flashattention_nopad.py
sufubao Feb 27, 2026
c744ebd
refactor: clarify naming in mamba_buffer_copy
sufubao Feb 27, 2026
9def697
clean
sufubao Feb 27, 2026
2b3deb8
fix
sufubao Feb 27, 2026
61f8945
clean
sufubao Feb 27, 2026
f7280a3
split
sufubao Feb 27, 2026
86d3bfb
Merge origin/qwen3.5_clean into rl_verl_qwen35
sufubao Feb 28, 2026
c05838e
fix: lazy-initialize SHM name constants to avoid import-time crash
sufubao Feb 28, 2026
243c6a0
fix: revert weight slicing and rmsnorm precision regressions
sufubao Feb 28, 2026
711e30c
fix
sufubao Feb 28, 2026
7734c21
feat: add Qwen3Next linear attention model support
sufubao Feb 19, 2026
c757b06
refactor: simplify mamba buffer copy and integrate Triton kernels
sufubao Feb 20, 2026
340d11c
fix conv3d
sufubao Feb 21, 2026
a6a2435
[draft] qwen3.5 dense
sufubao Feb 26, 2026
054035d
split dense and moe
sufubao Feb 26, 2026
01b112a
feat: add mamba_cache_ratio for automatic memory allocation
sufubao Feb 26, 2026
174757d
refactor: simplify mamba_cache_ratio to direct percentage
sufubao Feb 26, 2026
dd2516e
add H100 config
sufubao Feb 26, 2026
326ae22
refactor: align radix_cache_class with infer_state_class style
sufubao Feb 27, 2026
e996cd2
fix: add missing attention_chunk param to flashattention_nopad.py
sufubao Feb 27, 2026
5e5cdbe
refactor: clarify naming in mamba_buffer_copy
sufubao Feb 27, 2026
9cf783c
clean
sufubao Feb 27, 2026
e120edb
fix
sufubao Feb 27, 2026
f3330cf
clean
sufubao Feb 27, 2026
d030a67
split
sufubao Feb 27, 2026
e1f6129
style: apply black formatting to mamba_buffer_copy
sufubao Mar 1, 2026
74f82d1
perf: add autotune configs for mamba_buffer_copy/fork kernels on H200
sufubao Mar 1, 2026
c1ea769
refactor: rename buffer copy methods for clarity
sufubao Mar 2, 2026
b81baaa
clean the code
sufubao Mar 2, 2026
52b422a
vlm tokenizer support token list
shihaobai Mar 2, 2026
aa442a4
fix
shihaobai Mar 2, 2026
0fd0202
clean code
sufubao Mar 2, 2026
eed0a9c
qwen35 qkv improve
shihaobai Mar 6, 2026
b9a386e
code simplify
shihaobai Mar 9, 2026
86f17b6
clean code
shihaobai Mar 9, 2026
a1849e6
fix
shihaobai Mar 13, 2026
61f74ac
remove contiguous
shihaobai Mar 16, 2026
bf0f254
remove gemma rms norm config
shihaobai Mar 16, 2026
76782c2
clean code
sufubao Mar 17, 2026
fdd2052
add get_radix_class
sufubao Mar 17, 2026
733e851
fix acc of mamba cache
shihaobai Mar 17, 2026
b1f8233
fix acc of mamba cache
shihaobai Mar 17, 2026
90120b0
fix warmup
shihaobai Mar 17, 2026
4ef6091
merge main
shihaobai Mar 18, 2026
13edba2
simplify the qwen3next layer_infer
shihaobai Mar 18, 2026
ec499ce
openai api simplify
shihaobai Mar 18, 2026
3c8597d
simplify mem manager
shihaobai Mar 18, 2026
20edcc1
slime code
shihaobai Mar 19, 2026
eed9863
remove mtp of base_backend
shihaobai Mar 19, 2026
90df4f1
slime mode_backend
shihaobai Mar 19, 2026
3b832af
merge qwen3.5 and main
shihaobai Mar 19, 2026
91edf3b
fix invalid memory of release_memory
shihaobai Mar 19, 2026
711667a
flush_cache for hybrid cache
shihaobai Mar 19, 2026
b181c0a
fix rpyc
shihaobai Mar 19, 2026
c6a6dda
fix: node is None
sufubao Mar 20, 2026
ee3a7d5
fix resume invalid memory
shihaobai Mar 20, 2026
32d795d
fix reqs queue
shihaobai Mar 20, 2026
a0937a9
fix
shihaobai Mar 20, 2026
1de0e53
fix
shihaobai Mar 20, 2026
2dbd2f7
pop weight after load
shihaobai Mar 20, 2026
33bbfda
async update weight
shihaobai Mar 20, 2026
6017484
model.norm.weight: add 1 during runtime
shihaobai Mar 23, 2026
b98f6d7
fix r3
sufubao Mar 23, 2026
c0cebba
fix qwen35 nrom
sufubao Mar 23, 2026
5f4fa78
Revert "fix qwen35 nrom"
sufubao Mar 23, 2026
64506b3
fix
shihaobai Mar 23, 2026
9da13c1
Merge branch 'qwen3.5_clean' of https://github.com/ModelTC/lightllm i…
shihaobai Mar 23, 2026
73b10ca
remove unused log
shihaobai Mar 23, 2026
1d46601
Merge branch 'rl_verl_qwen35' of https://github.com/ModelTC/lightllm …
shihaobai Mar 23, 2026
b2ab0bf
Merge remote-tracking branch 'origin/main' into qwen3.5_clean
sufubao Mar 23, 2026
a93509f
fix mamba_len
shihaobai Mar 23, 2026
1115543
fix
sufubao Mar 23, 2026
d562b7b
fix
sufubao Mar 23, 2026
8f11f08
fix and remove unused code
shihaobai Mar 23, 2026
709075a
Merge branch 'qwen3.5_clean' of https://github.com/ModelTC/lightllm i…
shihaobai Mar 23, 2026
267412d
fix format
shihaobai Mar 23, 2026
f7bee08
gatermsnorm weight and mamba profile_size
shihaobai Mar 23, 2026
b85b6ca
simpliy code
shihaobai Mar 24, 2026
3965845
update tp param
shihaobai Mar 24, 2026
ef41d77
fix: restore tool_calls arguments JSON string to dict conversion
sufubao Mar 24, 2026
7d0458f
fix: restore tool_calls arguments JSON string to dict conversion
sufubao Mar 24, 2026
8f1212a
fix build_prompt too
sufubao Mar 25, 2026
77bfcba
fix
sufubao Mar 25, 2026
3585432
fix buffer idx
shihaobai Mar 26, 2026
334e3c4
fix
shihaobai Mar 26, 2026
2f34bac
merge the update of qwen3.5_clean
shihaobai Mar 26, 2026
0974ba9
fix
shihaobai Mar 27, 2026
fe91aa3
add instance_id with improved robustness and code quality
sufubao Mar 31, 2026
f4a0cb7
fix: occasional accuracy drop in rollout
shihaobai Apr 1, 2026
f4caa8f
reset req manager
shihaobai Apr 1, 2026
8794f43
fix typo
shihaobai Apr 2, 2026
1abf95a
add fp8 rl for qwen35
shihaobai Apr 10, 2026
901bd13
fix abort
shihaobai Apr 15, 2026
8de8baf
add logs for detoken
shihaobai Apr 16, 2026
2dc39fa
fix decode overflow
shihaobai Apr 17, 2026
8c20369
fix bytes decode
shihaobai Apr 18, 2026
6cd300c
merge main
shihaobai May 8, 2026
1f466c7
remove neo
shihaobai May 8, 2026
9e54f20
remove unused code
shihaobai May 8, 2026
46d2ee2
remove unused code
shihaobai May 8, 2026
f2c1a3e
remove unused code
shihaobai May 8, 2026
e7c1475
slime code
shihaobai May 8, 2026
1ecf015
slim code
shihaobai May 8, 2026
a93dcb6
slime code
shihaobai May 9, 2026
ccc8832
slime radix cache
shihaobai May 9, 2026
11ea37a
slime radixcache
shihaobai May 9, 2026
f446e5b
slim code
shihaobai May 9, 2026
998020a
remove unused code
shihaobai May 9, 2026
5a745e5
fix
shihaobai May 9, 2026
90ed556
lazy init cache dir
shihaobai May 9, 2026
eb42e5b
fix linear flush_cache
shihaobai May 9, 2026
1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ dist
.vscode
tmp/
requirements-musa.txt
CLAUDE.md
104 changes: 91 additions & 13 deletions lightllm/common/basemodel/basemodel.py
@@ -7,7 +7,7 @@
import torch
import torch.nn.functional as F
import triton
from typing import final, List
from typing import final, List, Optional
from tqdm import tqdm

from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
@@ -33,6 +33,10 @@
from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel
from lightllm.common.triton_utils.autotuner import Autotuner
from lightllm.utils.infer_utils import post_empty_cache
from lightllm.utils.torch_memory_saver_utils import (
TorchMemorySaverWrapper,
MemoryTag,
)
from .attention import get_prefill_att_backend_class, get_decode_att_backend_class
from .attention import BaseAttBackend

@@ -91,6 +95,7 @@ def __init__(self, kvargs):
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)
self.is_mtp_mode = self.args.mtp_mode in [
"vanilla_with_att",
"eagle_with_att",
@@ -104,19 +109,21 @@
self._verify_params()
self._init_quant()

self._init_weights()
self._init_req_manager()
self._init_mem_manager()
with self.torch_memory_saver.region(tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup):
self._init_weights()
with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE):
self._init_req_manager()
self._init_mem_manager()
self._init_kv_move_buffer()

# Linear-attention models such as qwen3.5 keep a large amount of runtime linear state in the req_manager,
# which can consume a lot of GPU memory, so req_manager.mem_manager is assigned only after the mem manager is initialized.
self.req_manager.mem_manager = self.mem_manager

self._init_kv_move_buffer()
self._check_mem_size()
self._init_infer_layer()
self._init_some_value()
self._init_custom()
self._load_hf_weights()
self.load_weights(self.weight_dict)
# The wait must come before cuda graph init to avoid erroneous capture into the graph.
self._wait_other_modules_ready()

@@ -181,17 +188,15 @@ def _init_weights(self, start_layer_index=0):
]
return

def _load_hf_weights(self):
def load_weights(self, weight_dict: dict):
assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None"
load_hf_weights(
self.data_type,
weight_dir=self.weight_dir_,
self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
weight_dict=self.weight_dict,
weight_dict=weight_dict,
)
self.pre_post_weight.verify_load()
[weight.verify_load() for weight in self.trans_layers_weight]
return

def _init_mem_manager(self):
assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
@@ -999,6 +1004,7 @@ def _check_max_len_infer(self):
)
logger.error(exception_str)
raise Exception(exception_str)
torch.cuda.empty_cache()
return

def autotune_layers(self):
@@ -1133,6 +1139,9 @@ def _init_padded_req(self):
del b_seq_len
del b_ready_cache_len
del model_output
del b_mtp_index
del b_prefill_start_loc
del b_q_seq_len
torch.cuda.empty_cache()
return

@@ -1153,3 +1162,72 @@ def _gen_special_model_input(self, token_num: int):
special_model_input["mtp_draft_input_hiddens"] = None

return special_model_input

def release_memory_occupation(self, tags: Optional[List[MemoryTag]]):
torch.cuda.synchronize()
if tags is None:
self.release_all()
return
if MemoryTag.WEIGHT in tags:
self.release_weight()
if MemoryTag.KV_CACHE in tags:
self.release_kv_cache()
if MemoryTag.GRAPH in tags:
self.release_graph()
return

def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]):
if tags is None:
self.resume_all()
return
if MemoryTag.WEIGHT in tags:
self.resume_weight()
if MemoryTag.KV_CACHE in tags:
self.resume_kv_cache()
if MemoryTag.GRAPH in tags:
self.resume_graph()
return

def release_weight(self):
self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
torch.cuda.empty_cache()
gc.collect()

def release_kv_cache(self):
self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
torch.cuda.empty_cache()
gc.collect()

def release_graph(self):
self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
torch.cuda.empty_cache()
gc.collect()

def release_all(self):
self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
torch.cuda.empty_cache()
gc.collect()

def resume_weight(self):
torch.cuda.empty_cache()
gc.collect()
self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)

def resume_kv_cache(self):
torch.cuda.empty_cache()
gc.collect()
self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)

def resume_graph(self):
torch.cuda.empty_cache()
gc.collect()
self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)

def resume_all(self):
torch.cuda.empty_cache()
gc.collect()
self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
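The release/resume API above is what the RL integration uses to hand GPU memory back and forth between rollout and training. Below is a minimal sketch (not part of the diff) of how it could be driven, assuming `model` is an instance of the base model class patched above; `load_new_weights` is a hypothetical callback supplied by the trainer and not part of this PR.

```python
from typing import Callable, Dict

import torch

from lightllm.utils.torch_memory_saver_utils import MemoryTag


def rl_weight_sync_step(model, load_new_weights: Callable[[], Dict[str, torch.Tensor]]):
    # Hand weights and KV cache back to the allocator so the trainer can use the GPU.
    model.release_memory_occupation(tags=[MemoryTag.WEIGHT, MemoryTag.KV_CACHE])

    # Hypothetical: obtain updated weights, e.g. broadcast from the training process.
    new_weight_dict = load_new_weights()

    # Map the paused regions back in, then reload weights through the new public hook.
    model.resume_memory_occupation(tags=[MemoryTag.WEIGHT, MemoryTag.KV_CACHE])
    model.load_weights(new_weight_dict)
```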
9 changes: 7 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -8,6 +8,10 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
from lightllm.utils.torch_memory_saver_utils import (
TorchMemorySaverWrapper,
MemoryTag,
)
from .infer_struct import InferStateInfo


@@ -26,6 +30,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int =
self.max_batch_size = max_batch_size
self.graph_max_len_in_batch = max_len_in_batch
self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)

# gen cuda graph batch_sizes
# cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
@@ -96,7 +101,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo):
delattr(infer_state, param_name)

with lightllm_capture_graph(dist_group):
with torch.cuda.graph(graph_obj, pool=self.mempool):
with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output = decode_func(infer_state)
self.graph[batch_size] = (graph_obj, infer_state, model_output)
graph_obj.replay()
@@ -134,7 +139,7 @@ def _capture_decode_overlap(

with lightllm_capture_graph(dist_group1):
with lightllm_capture_graph(dist_group):
with torch.cuda.graph(graph_obj, pool=self.mempool):
with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output, model_output1 = decode_func(infer_state, infer_state1)
self.graph[batch_size] = (
graph_obj,
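`TorchMemorySaverWrapper.cuda_graph` replaces the direct `torch.cuda.graph(...)` call sites above; the wrapper itself lives in `lightllm/utils/torch_memory_saver_utils.py`, which is not part of this diff. The following is a minimal sketch of the fallback behaviour those call sites imply; everything here other than the call signature is an assumption.

```python
import contextlib

import torch


class MemorySaverWrapperSketch:
    """Illustrative stand-in for TorchMemorySaverWrapper, not the real implementation."""

    def __init__(self, enabled: bool):
        self.enabled = enabled

    @contextlib.contextmanager
    def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, pool=None):
        if not self.enabled:
            # Fall back to the behaviour of the code this PR replaces.
            with torch.cuda.graph(graph_obj, pool=pool):
                yield
            return
        # When the saver is enabled, the capture would presumably also be registered under
        # MemoryTag.GRAPH (hypothetical here) so release_graph()/resume_graph() can later
        # pause and restore the graph's memory.
        with torch.cuda.graph(graph_obj, pool=pool):
            yield
```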
Changes in another file (file path not shown in this view)
@@ -33,6 +33,7 @@ def __init__(
num_fused_shared_experts: int = 0,
layer_num: int = 0,
network_config: Dict[str, Any] = None,
moe_layer_index: int = 0,
) -> None:
super().__init__(data_type=data_type)
self.w1_weight_name = gate_proj_name
@@ -50,6 +51,7 @@
self.enable_ep_moe = get_env_start_args().enable_ep_moe
self.n_routed_experts = n_routed_experts
self.num_fused_shared_experts = num_fused_shared_experts
self.moe_layer_index = moe_layer_index
self._init_config(network_config)
self._init_redundancy_expert_params()
self._init_parallel_params()
@@ -130,6 +132,7 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
microbatch_index: int = 0,
) -> torch.Tensor:
"""Backward compatible method that routes to platform-specific implementation."""
return self.fuse_moe_impl(
@@ -145,6 +148,8 @@
topk_group=topk_group,
num_expert_group=num_expert_group,
is_prefill=is_prefill,
moe_layer_index=self.moe_layer_index,
microbatch_index=microbatch_index,
)

def low_latency_dispatch(
@@ -295,6 +300,7 @@ def _create_weight(self):
device_id=self.device_id_,
num_experts=self.local_n_routed_experts,
)
self.w1, self.w3 = w13_param_list
self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0])
self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1])
self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2)
@@ -307,7 +313,8 @@ def _get_expert_weight_list(self, weight_pack: WeightPack):
return weight_list

def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]):

# for merged weights
self._load_merge_weight(weights)
# Load each expert with TP slicing
for expert_idx, local_expert_idx in expert_idx_to_local_idx.items():
with self.lock:
@@ -332,6 +339,7 @@ def _load_expert(
w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}"
w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}"
w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}"

row_slice_func = self.row_slicer._slice_weight
col_slice_func = self.col_slicer._slice_weight
if w1_weight in weights:
@@ -341,6 +349,19 @@
if w2_weight in weights:
self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx])

def _load_merge_weight(self, weights: Dict[str, torch.Tensor]):
w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}"
w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}"
w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}"
row_slice_func = self.row_slicer._slice_weight
col_slice_func = self.col_slicer._slice_weight
if w1_merge_weight in weights:
self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1)
if w2_merge_weight in weights:
self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2)
if w3_merge_weight in weights:
self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3)

def _load_expert_scale(
self,
expert_idx: int,
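The new `_load_merge_weight` path accepts checkpoints where each layer's expert projections are already fused into a single tensor, alongside the existing per-expert layout handled by `_load_expert`. A small illustration of the two key patterns the loader now matches; the concrete prefix below is a hypothetical example, not taken from this diff.

```python
# Per-expert layout, handled by _load_expert (note the expert index and quant weight suffix):
#   f"{weight_prefix}.{expert_idx}.{w1_weight_name}.{quant_method.weight_suffix}"
per_expert_key = "model.layers.3.mlp.experts.7.gate_proj.weight"

# Merged layout, handled by _load_merge_weight (no expert index, no suffix):
#   f"{weight_prefix}.{w1_weight_name}"
merged_key = "model.layers.3.mlp.experts.gate_proj"
```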
Changes in another file (file path not shown in this view)
@@ -8,6 +8,7 @@
from lightllm.common.quantization import Quantcfg
from lightllm.common.quantization.quantize_method import QuantizationMethod
from lightllm.utils.log_utils import init_logger
from lightllm.common.basemodel import routing_manager as _routing_mgr

logger = init_logger(__name__)

@@ -46,6 +47,7 @@ def __init__(
num_fused_shared_experts: int = 0,
layer_num: int = 0,
network_config: Dict[str, Any] = None,
moe_layer_index: int = 0,
) -> None:
network_config["norm_topk_prob"] = None
super().__init__(
@@ -62,6 +64,7 @@
num_fused_shared_experts=num_fused_shared_experts,
layer_num=layer_num,
network_config=network_config,
moe_layer_index=moe_layer_index,
)

self.hidden_size = network_config["hidden_size"]
@@ -144,10 +147,15 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
microbatch_index: int = 0,
):

topk_weights, topk_ids = self._router(router_logits, top_k)

# Rollout router replay
if _routing_mgr.g_routing_capture_manager is not None:
_routing_mgr.g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)

w1, w1_scale = self.w1
w2, w2_scale = self.w2
use_fp8_w8a8 = self.quant_method is not None
Changes in another file (file path not shown in this view)
@@ -62,5 +62,7 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
moe_layer_index: Optional[int] = None,
microbatch_index: int = 0,
) -> torch.Tensor:
pass
Changes in another file (file path not shown in this view)
@@ -3,6 +3,7 @@
from lightllm.common.quantization.no_quant import WeightPack
from lightllm.common.quantization.quantize_method import QuantizationMethod
from .base_impl import FuseMoeBaseImpl
from lightllm.common.basemodel import routing_manager as _routing_mgr


class FuseMoeTriton(FuseMoeBaseImpl):
@@ -125,6 +126,8 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
moe_layer_index: Optional[int] = None,
microbatch_index: int = 0,
):
topk_weights, topk_ids = self._select_experts(
input_tensor=input_tensor,
@@ -137,6 +140,10 @@
num_expert_group=num_expert_group,
scoring_func=scoring_func,
)

if _routing_mgr.g_routing_capture_manager is not None and moe_layer_index is not None:
_routing_mgr.g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index)

output = self._fused_experts(
input_tensor=input_tensor,
w13=w13,
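`routing_manager` and its global `g_routing_capture_manager` are added elsewhere in this PR and are not shown in this diff. The following is a minimal sketch of the capture contract the call sites above imply for R3 rollout replay: a process-global object that records the selected expert ids per MoE layer and microbatch. Everything below except the `capture(moe_layer_index, topk_ids, microbatch_index)` signature is an assumption.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import torch


class RoutingCaptureManagerSketch:
    """Illustrative stand-in for the routing capture manager used by the hooks above."""

    def __init__(self):
        # (moe_layer_index, microbatch_index) -> captured top-k expert id tensors, in call order.
        self._records: Dict[Tuple[int, int], List[torch.Tensor]] = defaultdict(list)

    def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0):
        # Copy off-device so the record survives past the forward pass.
        self._records[(moe_layer_index, microbatch_index)].append(topk_ids.detach().cpu())

    def export(self) -> Dict[Tuple[int, int], List[torch.Tensor]]:
        # Hypothetical accessor the rollout replay side could read from.
        return dict(self._records)
```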