Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
89a27fd
feat(mtp): MTP verify-decode infrastructure
sufubao Jun 9, 2026
61e4dcb
feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP draft models
sufubao Jun 9, 2026
89a163a
feat(qwen3next): GDN spec-decode verify path + linear-att cache split
sufubao Jun 9, 2026
47ddb6c
feat(scheduler): MTP verify backend + accept-len transport
sufubao Jun 9, 2026
085a185
test(mtp): MTP unit tests + static benchmark
sufubao Jun 9, 2026
db50f25
Fix Qwen3Next MTP linear-att page moves
sufubao Jun 9, 2026
45ec253
revert formatting churn on pre-existing code
sufubao Jun 15, 2026
5883b41
revert(mtp): drop eagle reduced-batch draft optimization
sufubao Jun 15, 2026
82522e6
revert(mtp): run the MTP draft on upstream's grouped verify layout
sufubao Jun 15, 2026
cd6b918
clean code
sufubao Jun 16, 2026
fe9ac22
clean code
sufubao Jun 16, 2026
10473dd
refactor(mtp): GPU-resident req_to_accept_len + simplify verify-decod…
sufubao Jun 16, 2026
45831a2
revert: drop all test/ and unit_tests/ changes from this branch
sufubao Jun 16, 2026
31fa641
style: black-format fp8.py k/v_descale lines (pre-commit)
sufubao Jun 16, 2026
c4c3c2f
clean code
sufubao Jun 16, 2026
6f78b54
Merge upstream/main into qw35_mtp_feature
sufubao Jun 23, 2026
7871295
Merge remote-tracking branch 'upstream/main' into qw35_mtp_feature
sufubao Jun 29, 2026
0c2f7d0
fix
sufubao Jun 29, 2026
f75dfaf
fix
sufubao Jun 29, 2026
01447ec
Merge remote-tracking branch 'upstream/main' into qw35_mtp_feature
sufubao Jun 29, 2026
5814653
fix
sufubao Jun 29, 2026
694fbe6
fix format
sufubao Jun 29, 2026
39d3822
clean code: mtp_verify_extra_state.py
sufubao Jun 29, 2026
efda16d
clean code
sufubao Jun 29, 2026
8d51682
clean code
sufubao Jun 29, 2026
a4c79d6
clean code
sufubao Jun 29, 2026
7bc84fc
clean code
sufubao Jun 29, 2026
ebe6ae8
fix
sufubao Jun 30, 2026
fb98c5c
restore cudagraph
shihaobai Jul 1, 2026
145fb32
update infer_struct
shihaobai Jul 1, 2026
47f0e99
clean code
shihaobai Jul 1, 2026
4f85a3f
fix
sufubao Jul 1, 2026
6a05570
clean transformers layer_weight
shihaobai Jul 1, 2026
c51bc4b
Merge branch 'qw35_mtp_feature' of https://github.com/sufubao/lightll…
shihaobai Jul 1, 2026
0ee2150
clean req_manager.py
shihaobai Jul 1, 2026
d2e46e4
fix
shihaobai Jul 1, 2026
e119a68
fix
shihaobai Jul 1, 2026
c6ed3ea
clean code
shihaobai Jul 1, 2026
40a143d
clean model.py
shihaobai Jul 1, 2026
e2aab74
fix
hiworldwzj Jul 2, 2026
5671f1e
fix
hiworldwzj Jul 2, 2026
e7552cf
fix
hiworldwzj Jul 2, 2026
2beb876
fix
hiworldwzj Jul 2, 2026
e06bce6
fix
hiworldwzj Jul 2, 2026
36d65f8
fix
hiworldwzj Jul 2, 2026
104abc7
fix
hiworldwzj Jul 2, 2026
740cb2d
fix
hiworldwzj Jul 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@ dist
.vscode
tmp/
requirements-musa.txt
logs/
logs/

/benchmark/
artifacts/
22 changes: 14 additions & 8 deletions lightllm/common/basemodel/attention/fa3/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ def init_state(self):
torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
)
# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)

self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)

def prefill_att(
self,
Expand Down Expand Up @@ -120,16 +123,19 @@ def init_state(self):
att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0

device = self.infer_state.input_ids.device
batch_size = att_batch_size
mem_manager = self.backend.model.mem_manager

offline_scales: torch.Tensor = mem_manager.scales
head_num = mem_manager.head_num

# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)

return

Expand Down Expand Up @@ -180,11 +186,11 @@ def _fp8_decode_att(
k_cache=cache_k,
v_cache=cache_v,
page_table=self.page_table,
cache_seqlens=self.infer_state.b_seq_len,
cache_seqlens=self.b_att_seq_len,
cu_seqlens_q=self.cu_seqlens_q,
cu_seqlens_k_new=self.cu_seqlens_k,
max_seqlen_q=self.decode_max_q_seq_len,
causal=False,
causal=True,
window_size=(-1, -1),
softcap=0.0,
q_descale=q_scale.view(self.infer_state.batch_size, k_head_num),
Expand Down
7 changes: 1 addition & 6 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,12 +1171,7 @@ def _init_padded_req(self):
def _gen_special_model_input(self, token_num: int):
special_model_input = {}

is_mtp_draft_model = (
"Deepseek3MTPModel" in str(self.__class__)
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
or "Glm4MoeLiteMTPModel" in str(self.__class__)
)
is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False)
if is_mtp_draft_model:
special_model_input["mtp_draft_input_hiddens"] = torch.randn(
token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda"
Expand Down
74 changes: 38 additions & 36 deletions lightllm/common/basemodel/triton_kernel/linear_att_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

@triton.jit
def _copy_linear_att_state_to_kv_buffer(
gpu_conv_ptr, # [linear_layer_num, size_num, xdim]
gpu_ssm_ptr, # [linear_layer_num, size_num, xxdim]
cpu_kv_conv_ptr, # [size, linear_layer_num, xdim]
cpu_kv_ssm_ptr, # [size, linear_layer_num, xxdim]
gpu_conv_ptr, # uint8 view: [linear_layer_num, req_num, conv_dim, gpu_conv_row_bytes]
gpu_ssm_ptr, # uint8 view: [linear_layer_num, req_num * (mtp_step + 1), ssm_bytes]
cpu_kv_conv_ptr, # uint8 view: [buffer_num, linear_layer_num, conv_dim, cpu_conv_row_bytes]
cpu_kv_ssm_ptr, # uint8 view: [buffer_num, linear_layer_num, ssm_bytes]
b_req_idx, # [batch_size,]
big_page_buffer_ids, # [batch_size,]
gpu_conv_stride_l,
Expand All @@ -24,7 +24,8 @@ def _copy_linear_att_state_to_kv_buffer(
cpu_kv_ssm_stride_l,
cpu_kv_ssm_stride_d,
mtp_step,
gpu_conv_tail_dim,
conv_dim, # number of conv rows
cpu_conv_row_bytes, # bytes copied per conv row; equals the CPU/cache row width
gpu_ssm_tail_dim,
BLOCK: tl.constexpr,
):
Expand All @@ -33,7 +34,10 @@ def _copy_linear_att_state_to_kv_buffer(
cpu_kv_conv_stride_s = tl.cast(cpu_kv_conv_stride_s, dtype=tl.int64)
cpu_kv_ssm_stride_s = tl.cast(cpu_kv_ssm_stride_s, dtype=tl.int64)
gpu_conv_stride_s = tl.cast(gpu_conv_stride_s, dtype=tl.int64)
gpu_conv_stride_d = tl.cast(gpu_conv_stride_d, dtype=tl.int64)
gpu_ssm_stride_s = tl.cast(gpu_ssm_stride_s, dtype=tl.int64)
cpu_kv_conv_stride_d = tl.cast(cpu_kv_conv_stride_d, dtype=tl.int64)
cpu_conv_row_bytes = tl.cast(cpu_conv_row_bytes, dtype=tl.int64)

big_page_buffer_idx = tl.load(big_page_buffer_ids + cur_batch)
if big_page_buffer_idx == -1:
Expand All @@ -42,20 +46,16 @@ def _copy_linear_att_state_to_kv_buffer(
cur_req_idx = tl.load(b_req_idx + cur_batch).to(tl.int64)
cur_state_req_idx = (cur_req_idx * (mtp_step + 1)).to(tl.int64)

for i in range(tl.cdiv(gpu_conv_tail_dim, BLOCK)):
gpu_start_off = i * BLOCK + tl.arange(0, BLOCK)
mask = gpu_start_off < gpu_conv_tail_dim
conv_data = tl.load(
gpu_conv_ptr + cur_layer * gpu_conv_stride_l + cur_state_req_idx * gpu_conv_stride_s + gpu_start_off,
mask=mask,
)
dest_conv_ptr = (
cpu_kv_conv_ptr
+ big_page_buffer_idx * cpu_kv_conv_stride_s
+ cur_layer * cpu_kv_conv_stride_l
+ gpu_start_off
)
tl.store(dest_conv_ptr, conv_data, mask=mask)
gpu_conv_base = gpu_conv_ptr + cur_layer * gpu_conv_stride_l + cur_req_idx * gpu_conv_stride_s
cpu_conv_base = cpu_kv_conv_ptr + big_page_buffer_idx * cpu_kv_conv_stride_s + cur_layer * cpu_kv_conv_stride_l
conv_tail_dim = conv_dim * cpu_conv_row_bytes
for i in range(tl.cdiv(conv_tail_dim, BLOCK)):
conv_start = i * BLOCK + tl.arange(0, BLOCK)
conv_row = conv_start // cpu_conv_row_bytes
conv_col = conv_start - conv_row * cpu_conv_row_bytes
mask = conv_row < conv_dim
conv_data = tl.load(gpu_conv_base + conv_row * gpu_conv_stride_d + conv_col, mask=mask)
tl.store(cpu_conv_base + conv_row * cpu_kv_conv_stride_d + conv_col, conv_data, mask=mask)

for i in range(tl.cdiv(gpu_ssm_tail_dim, BLOCK)):
gpu_start_off = i * BLOCK + tl.arange(0, BLOCK)
Expand All @@ -75,32 +75,33 @@ def _copy_linear_att_state_to_kv_buffer(
def copy_linear_att_state_to_kv_buffer(
b_req_idx: torch.Tensor,
big_page_buffer_ids: torch.Tensor,
gpu_conv_state: torch.Tensor, # [linear_layer_num, s, ...]
gpu_ssm_state: torch.Tensor, # [linear_layer_num, s, ...]
cpu_kv_conv_state: torch.Tensor, # [s, linear_layer_num, ...]
cpu_kv_ssm_state: torch.Tensor, # [s, linear_layer_num, ...]
gpu_conv_state: torch.Tensor, # [linear_layer_num, req_num, conv_dim, widened_width]
gpu_ssm_state: torch.Tensor, # [linear_layer_num, req_num * (mtp_step + 1), ...]
cpu_kv_conv_state: torch.Tensor, # [buffer_num, linear_layer_num, conv_dim, base_width]
cpu_kv_ssm_state: torch.Tensor, # [buffer_num, linear_layer_num, ...]
mtp_step: int,
):
assert len(b_req_idx) == big_page_buffer_ids.shape[0]
BLOCK = 4096
gpu_conv_state = gpu_conv_state.view(gpu_conv_state.shape[0], gpu_conv_state.shape[1], -1).view(dtype=torch.uint8)

assert gpu_conv_state.dim() >= 4, "gpu_conv_state must be [layer, s, conv_dim, widened_width]"
assert cpu_kv_conv_state.dim() >= 4, "cpu_kv_conv_state must be [size, layer, conv_dim, width_narrow]"
gpu_conv_state = gpu_conv_state.view(
gpu_conv_state.shape[0], gpu_conv_state.shape[1], gpu_conv_state.shape[2], -1
).view(dtype=torch.uint8)
cpu_kv_conv_state = cpu_kv_conv_state.view(
cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], cpu_kv_conv_state.shape[2], -1
).view(dtype=torch.uint8)
gpu_ssm_state = gpu_ssm_state.view(gpu_ssm_state.shape[0], gpu_ssm_state.shape[1], -1).view(dtype=torch.uint8)
cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], -1).view(
dtype=torch.uint8
)
cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], cpu_kv_ssm_state.shape[1], -1).view(
dtype=torch.uint8
)
assert gpu_conv_state.shape[-1] == cpu_kv_conv_state.shape[-1]
assert gpu_conv_state.shape[2] == cpu_kv_conv_state.shape[2], "conv_dim mismatch between gpu and cpu conv buffers"
assert gpu_ssm_state.shape[-1] == cpu_kv_ssm_state.shape[-1]
assert (
gpu_conv_state.stride(-1)
== gpu_ssm_state.stride(-1)
== cpu_kv_conv_state.stride(-1)
== cpu_kv_ssm_state.stride(-1)
)

gpu_conv_tail_dim = gpu_conv_state.shape[-1]
conv_dim = gpu_conv_state.shape[2]
cpu_conv_row_bytes = cpu_kv_conv_state.shape[-1]
assert cpu_conv_row_bytes <= gpu_conv_state.shape[-1]
gpu_ssm_tail_dim = gpu_ssm_state.shape[-1]

layer_num = gpu_conv_state.shape[0]
Expand All @@ -127,7 +128,8 @@ def copy_linear_att_state_to_kv_buffer(
cpu_kv_ssm_stride_l=cpu_kv_ssm_state.stride(1),
cpu_kv_ssm_stride_d=cpu_kv_ssm_state.stride(2),
mtp_step=mtp_step,
gpu_conv_tail_dim=gpu_conv_tail_dim,
conv_dim=conv_dim,
cpu_conv_row_bytes=cpu_conv_row_bytes,
gpu_ssm_tail_dim=gpu_ssm_tail_dim,
BLOCK=BLOCK,
)
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,7 @@ def copy_kv_buffer_to_cpu_cache(
cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1]
full_att_layer_num = gpu_kv_full_att_state.shape[-2]

assert (
full_att_layer_num
== (linear_config.all_layer_num // linear_config.full_attention_interval)
== (linear_config.all_layer_num - linear_config.linear_layer_num)
)
assert full_att_layer_num == linear_config.get_full_att_kv_layer_num_with_draft_model()
assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1]
assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1]
assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1]
Expand Down Expand Up @@ -388,7 +384,6 @@ def copy_cpu_cache_to_kv_buffer(
linear_config: LinearAttCacheConfig,
grid_num: int = 12,
):

assert len(mem_indexes) % len(page_indexes) == 0

BLOCK = 4096
Expand Down
66 changes: 35 additions & 31 deletions lightllm/common/basemodel/triton_kernel/mtp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def _fwd_kernel_mtp_scatter_next_token_ids(
mtp_step,
BLOCK_SIZE: tl.constexpr,
):

cur_index = tl.program_id(0)
req_start_loc = tl.load(b_req_mtp_start_loc + cur_index)
accept_len = tl.load(mtp_accept_len + cur_index)
Expand Down Expand Up @@ -149,35 +148,48 @@ def mtp_scatter_next_token_ids(


@triton.jit
def _fwd_kernel_gen_b_req_mtp_start_loc(
b_mtp_index,
def _fwd_kernel_scatter_accept_len(
req_to_accept_len,
b_req_mtp_start_loc,
num_reqs: tl.constexpr,
batch_size: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
b_req_idx,
mtp_accept_len,
):
offset = tl.arange(0, BLOCK_SIZE)
cur_mtp_index = tl.load(b_mtp_index + offset, mask=offset < batch_size, other=-1)
non_zero_mask = tl.where(cur_mtp_index == 0, 1, 0) # 1 0 1 0 0
output_offset = tl.cumsum(non_zero_mask) - 1
tl.store(b_req_mtp_start_loc + output_offset, offset, mask=non_zero_mask == 1)
cur_index = tl.program_id(0)
req_start_loc = tl.load(b_req_mtp_start_loc + cur_index)
cur_req_idx = tl.load(b_req_idx + req_start_loc)
accept_len = tl.load(mtp_accept_len + cur_index)
tl.store(req_to_accept_len + cur_req_idx, accept_len)
return


def gen_b_req_mtp_start_loc(b_mtp_index: torch.Tensor, num_reqs: int):
b_req_mtp_start_loc = torch.empty((num_reqs,), dtype=torch.int32, device=b_mtp_index.device)
BLOCK_SIZE = triton.next_power_of_2(b_mtp_index.shape[0])
batch_size = b_mtp_index.shape[0]
grid = (1,)
_fwd_kernel_gen_b_req_mtp_start_loc[grid](
b_mtp_index=b_mtp_index,
def scatter_mtp_accept_len(
req_to_accept_len: torch.Tensor,
b_req_mtp_start_loc: torch.Tensor,
b_req_idx: torch.Tensor,
mtp_accept_len: torch.Tensor,
):
"""
将本步每个真实请求(组首)的 accept 数量写入 GPU 常驻的 req_to_accept_len[req_idx]。
融合 `req_to_accept_len[b_req_idx[b_req_mtp_start_loc]] = mtp_accept_len` 的 gather+scatter
为单次 launch、无中间张量。每个 program 处理一个真实请求。
Args:
req_to_accept_len: (max_req_num + 1,)
b_req_mtp_start_loc: (num_reqs,) 每组首行在 batch 中的偏移
b_req_idx: (batch_size,) grouped 布局的 req_idx(组首即该请求的 req_idx)
mtp_accept_len: (num_reqs,)
"""
num_reqs = mtp_accept_len.shape[0]
if num_reqs == 0:
return
grid = (num_reqs,)
_fwd_kernel_scatter_accept_len[grid](
req_to_accept_len=req_to_accept_len,
b_req_mtp_start_loc=b_req_mtp_start_loc,
num_reqs=num_reqs,
batch_size=batch_size,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=8,
b_req_idx=b_req_idx,
mtp_accept_len=mtp_accept_len,
num_warps=1,
num_stages=1,
)
return b_req_mtp_start_loc


def test_mtp_verify():
Expand All @@ -201,13 +213,5 @@ def test_mtp_verify():
print(accepted_index)


def test_gen_b_req_mtp_start_loc():
b_mtp_index = torch.tensor([0, 1, 0, 1, 2], dtype=torch.int32, device="cuda")
gt_output = torch.where(b_mtp_index == 0)[0]
b_req_mtp_start_loc = gen_b_req_mtp_start_loc(b_mtp_index, 2)
print(b_req_mtp_start_loc, gt_output)


if __name__ == "__main__":
test_mtp_verify()
# test_gen_b_req_mtp_start_loc()
25 changes: 15 additions & 10 deletions lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,8 @@ def write_req_to_page(
dp_mems: List["Qwen3NextMemManager"],
):
conv_page, ssm_page = self.view_page_to_linear_att_state(page_index)
req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1)
for tp_index, mem in enumerate(dp_mems):
self._write_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page)
self._write_one_rank(mem, tp_index, req_idx, conv_page, ssm_page)
return

def read_page_to_req(
Expand All @@ -220,21 +219,26 @@ def read_page_to_req(
dp_mems: List["Qwen3NextMemManager"],
):
conv_page, ssm_page = self.view_page_to_linear_att_state(page_index)
req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1)
for tp_index, mem in enumerate(dp_mems):
self._read_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page)
self._read_one_rank(mem, tp_index, req_idx, conv_page, ssm_page)
return

def _get_req_state_indexes(self, req_idx: int):
mtp_size = get_env_start_args().mtp_step + 1
# Conv is one widened slot per request; SSM keeps the historical S+1 block layout.
return req_idx, req_idx * mtp_size

def _write_one_rank(
self,
mem: "Qwen3NextMemManager",
tp_index: int,
req_buffer_idx: int,
req_idx: int,
conv_page: torch.Tensor,
ssm_page: torch.Tensor,
):
conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...]
ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...]
conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx)
conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]]
ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...]
self._copy_conv_state_to_page(conv_state, conv_page, mem, tp_index)
self._copy_ssm_state_to_page(ssm_state, ssm_page, mem, tp_index)
return
Expand Down Expand Up @@ -408,12 +412,13 @@ def _read_one_rank(
self,
mem: "Qwen3NextMemManager",
tp_index: int,
req_buffer_idx: int,
req_idx: int,
conv_page: torch.Tensor,
ssm_page: torch.Tensor,
):
conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...]
ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...]
conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx)
conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]]
ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...]
self._copy_page_to_conv_state(conv_page, conv_state, mem, tp_index)
self._copy_page_to_ssm_state(ssm_page, ssm_state, mem, tp_index)
return
Loading
Loading