From b636310df008c61dce2096835dfd668daa76999c Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Fri, 14 Nov 2025 13:38:31 +0800 Subject: [PATCH 001/180] add /flush_cache (#1108) --- lightllm/server/api_http.py | 26 +++++++++++ lightllm/server/api_start.py | 8 ++-- lightllm/server/core/objs/start_args_type.py | 1 + lightllm/server/httpserver/manager.py | 11 +++++ lightllm/server/io_struct.py | 7 +++ .../router/dynamic_prompt/radix_cache.py | 24 +++++++++++ lightllm/server/router/manager.py | 34 +++++++++++++++ lightllm/server/router/mananger_rpc.py | 43 +++++++++++++++++++ .../model_infer/mode_backend/base_backend.py | 5 +++ .../server/router/model_infer/model_rpc.py | 19 ++++++++ .../router/dynamic_prompt/test_radix_cache.py | 27 ++++++++++++ 11 files changed, 202 insertions(+), 3 deletions(-) create mode 100644 lightllm/server/io_struct.py create mode 100644 lightllm/server/router/mananger_rpc.py diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 8bda50fb76..2ef01ea903 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -58,6 +58,7 @@ CompletionRequest, CompletionResponse, ) +from .io_struct import AbortReq from .build_prompt import build_prompt, init_tokenizer logger = init_logger(__name__) @@ -291,6 +292,30 @@ async def metrics() -> Response: return response +@app.post("/abort_req") +async def abort_req(request: AbortReq, raw_request: Request): + """Abort a request.""" + try: + await g_objs.httpserver_manager.abort_req(request) + return Response(status_code=200) + except Exception as e: + return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") + + +@app.post("/flush_cache") +@app.get("/flush_cache") +async def flush_cache(): + """Flush the radix cache.""" + ret = await g_objs.httpserver_manager.flush_cache() + return Response( + content="Cache flushed successfully." + if ret + else "Cache flush failed. 
" + + "When there are running or waiting requests, the operation will not be performed.", + status_code=200 if ret else 500, + ) + + @app.websocket("/pd_register") async def register_and_keep_alive(websocket: WebSocket): await websocket.accept() @@ -357,6 +382,7 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) + g_objs.httpserver_manager.connect_router_rpc() loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index f73be30dbc..138b0a599b 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -225,11 +225,12 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=8 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + num=9 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( router_port, + router_rpc_port, detokenization_port, http_server_port, visual_port, @@ -237,8 +238,8 @@ def normal_or_p_d_start(args): cache_port, metric_port, multi_level_kv_cache_port, - ) = can_use_ports[0:8] - can_use_ports = can_use_ports[8:] + ) = can_use_ports[0:9] + can_use_ports = can_use_ports[9:] visual_model_tp_ports = [] for _ in range(args.visual_dp): @@ -248,6 +249,7 @@ def normal_or_p_d_start(args): # 将申请好的端口放入args参数中 args.router_port = router_port + args.router_rpc_port = router_rpc_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port args.visual_port = visual_port diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 69d907fff5..659aab1dc7 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -113,6 +113,7 @@ class StartArgs: disk_cache_storage_size: float = field(default=10) # zmp ports router_port: int = field(default=None) + router_rpc_port: int = field(default=None) detokenization_port: int = field(default=None) http_server_port: int = field(default=None) visual_port: int = field(default=None) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 11919398e3..7158b89235 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -255,6 +255,13 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): assert False, "dead code path" return group_request_id + def connect_router_rpc(self): + from lightllm.server.router.mananger_rpc import connect_router_rpc + + self.router_rpc_client = connect_router_rpc(self.args.router_rpc_port) + logger.info("HttpServerManager connected to Router RPC service successfully") + return + async def generate( self, prompt: Union[str, List[int]], @@ -763,6 +770,10 @@ async def handle_loop(self): self.recycle_event.set() return + async def flush_cache(self): + ret = await self.router_rpc_client.flush_cache() + return ret + class ReqStatus: def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None: diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py new file mode 100644 index 0000000000..68d32816f1 --- /dev/null +++ b/lightllm/server/io_struct.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + + +@dataclass +class 
AbortReq: + request_id: int = -1 + abort_all: bool = False diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 2bf0a4d5ab..9c207ec30c 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -424,6 +424,30 @@ def clear_tree_nodes(self): self.refed_tokens_num.arr[0] = 0 return + def flush_cache(self): + nodes_to_clear = collections.deque(self.root_node.children.values()) + self.root_node.children.clear() + while nodes_to_clear: + node = nodes_to_clear.popleft() + nodes_to_clear.extend(node.children.values()) + node.parent = None + node.children.clear() + + self.root_node.token_id_key[:] = 0 + self.root_node.token_mem_index_value[:] = 0 + self.root_node.ref_counter = 1 # 保持为1,确保不会被evict + self.root_node.time_id = time_gen.generate_time_id() + self.root_node.node_value_len = 0 + self.root_node.node_prefix_total_len = 0 + + self.evict_tree_set.clear() + self.evict_tree_set.add(self.root_node) + + self.tree_total_tokens_num.arr[0] = 0 + self.refed_tokens_num.arr[0] = 0 + + return + def dec_node_ref_counter(self, node: TreeNode): if node is None: return diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 3c8ca2399a..b6132b12f8 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -5,6 +5,7 @@ import pickle import inspect import setproctitle +import rpyc asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) import zmq @@ -151,6 +152,9 @@ async def wait_to_model_ready(self): rpc_finished_event=self.rpc_finished_event, ) + # 启动 rpyc 服务,供 HTTP Server 远程调用 + self._start_router_rpc_service() + kvargs = { "args": self.args, "rank_id": None, # 由后续处理填充真实数据 @@ -231,6 +235,25 @@ async def wait_to_model_ready(self): return + def _start_router_rpc_service(self): + """launch a rpyc service for httpserver to call RouterManager""" + import threading + from rpyc.utils.server import ThreadedServer + import lightllm.utils.rpyc_fix_utils as _ + from .mananger_rpc import RouterRpcService + + service = RouterRpcService(self) + port = self.args.router_rpc_port + + def start_server(): + t = ThreadedServer(service, port=port, protocol_config={"allow_pickle": True}) + t.start() + + rpc_thread = threading.Thread(target=start_server, daemon=True) + rpc_thread.start() + logger.info(f"Router RPC service started successfully on port {port}") + return + def _get_schedule_time_interval(self): # dp 模式,为了更好的配平,需要更长的调度间隔,以便于能收到更多的请求 return self.schedule_time_interval @@ -535,6 +558,17 @@ async def _recv_new_reqs_and_schedule(self): self._generate_new_batch() return + def flush_cache(self) -> bool: + if self.running_batch is not None: + return False + if self.req_queue.get_wait_req_num() > 0: + return False + # if radix cache client is not initialized, just return True + if self.radix_cache_client is None: + return True + # only flush cache when no running batch and no waiting requests + return self.model_rpc_client.flush_radix_cache() + def clean_up(self): return diff --git a/lightllm/server/router/mananger_rpc.py b/lightllm/server/router/mananger_rpc.py new file mode 100644 index 0000000000..60f9e0458b --- /dev/null +++ b/lightllm/server/router/mananger_rpc.py @@ -0,0 +1,43 @@ +import rpyc +import asyncio +import socket +from .manager import RouterManager + + +class RouterRpcService(rpyc.Service): + def __init__(self, router_manager: "RouterManager"): + super().__init__() + self.router_manager = router_manager + 
return + + def exposed_flush_cache(self) -> bool: + return self.router_manager.flush_cache() + + +class RouterRpcClient: + def __init__(self, router_rpc_conn): + self.router_rpc_conn = router_rpc_conn + + def async_wrap(f): + f = rpyc.async_(f) + + async def _func(*args, **kwargs): + ans = f(*args, **kwargs) + await asyncio.to_thread(ans.wait) + # raise if exception + return ans.value + + return _func + + self._flush_cache = async_wrap(self.router_rpc_conn.root.flush_cache) + return + + async def flush_cache(self) -> bool: + ans = await self._flush_cache() + return ans + + +def connect_router_rpc(port: int) -> RouterRpcClient: + router_rpc_conn = rpyc.connect("localhost", port, config={"allow_pickle": True}) + router_rpc_conn._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return RouterRpcClient(router_rpc_conn) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 95f0c99515..db708f3cf5 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -288,6 +288,11 @@ def init_mtp_draft_model(self, main_kvargs: dict): self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return + def flush_radix_cache(self): + if self.radix_cache is not None: + self.radix_cache.flush_cache() + return + def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor): """ 这个函数会把next token id和logprobs保存到pinned memory中 diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 1bb625db09..b7797a7628 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -181,6 +181,15 @@ def init_model(self, kvargs): def get_max_total_token_num(self): return self.backend.get_max_total_token_num() + def flush_radix_cache(self): + try: + if self.backend is not None: + self.backend.flush_radix_cache() + return True + except BaseException as e: + logger.exception(f"flush radix cache failed: {str(e)}") + return False + class ModelRpcClient: def __init__(self, rpc_event, rpc_finished_event): @@ -211,6 +220,16 @@ async def get_max_total_token_num(self): assert func_name == "get_max_total_token_num" return ret + def flush_radix_cache(self) -> bool: + self.rpc_shm_params.write_func_params("flush_radix_cache", ()) + self.rpc_event.set() + + self.rpc_finished_event.wait() + self.rpc_finished_event.clear() + func_name, ret = self.rpc_shm_results.read_func_result() + assert func_name == "flush_radix_cache" + return ret + def _init_env( args, diff --git a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py index 605433e9d8..dfeda0b6f7 100644 --- a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py +++ b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py @@ -230,5 +230,32 @@ def test_case9(): assert torch.equal(unmerged_node_d.token_id_key, torch.tensor([6], dtype=torch.int64)) +def test_case10(): + """ + 测试场景:测试 flush_cache 函数 + """ + print("\nTest Case 10: Testing flush_cache function\n") + tree = RadixCache("unique_name", 100, 0) + tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64)) + tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)) + tree_node, size, values = tree.match_prefix( + torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), 
update_refs=True + ) + assert tree_node is not None + assert size == 3 + tree.flush_cache() + tree_node, size, values = tree.match_prefix( + torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True + ) + assert tree_node is None + assert size == 0 + assert tree.get_tree_total_tokens_num() == 0 + assert tree.get_refed_tokens_num() == 0 + assert len(tree.root_node.children) == 0 + assert tree.root_node.token_id_key.numel() == 0 + assert tree.root_node.token_mem_index_value.numel() == 0 + assert tree.root_node.ref_counter == 1 + + if __name__ == "__main__": pytest.main() From 60c379ed7e9a25a8db867e0190ef7656d8d6dd7a Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Tue, 18 Nov 2025 14:51:43 +0800 Subject: [PATCH 002/180] Aborted reqs (#1113) --- lightllm/server/api_http.py | 6 +- lightllm/server/core/objs/io_objs/__init__.py | 2 +- .../server/core/objs/io_objs/group_req.py | 25 +--- lightllm/server/core/objs/req.py | 12 +- lightllm/server/detokenization/decode_req.py | 8 +- lightllm/server/detokenization/manager.py | 7 +- lightllm/server/httpserver/manager.py | 118 ++++++++++++------ lightllm/server/io_struct.py | 63 +++++++++- .../server/multi_level_kv_cache/manager.py | 8 +- lightllm/server/multimodal_params.py | 15 ++- lightllm/server/router/manager.py | 63 ++++++---- .../server/router/model_infer/infer_batch.py | 2 + .../server/router/req_queue/base_queue.py | 11 ++ .../req_queue/chunked_prefill/beam_impl.py | 5 +- .../router/req_queue/chunked_prefill/impl.py | 5 +- .../chunked_prefill/impl_for_nixl_pd.py | 3 +- .../chunked_prefill/impl_for_pd_decode.py | 3 +- .../server/router/req_queue/dp_base_queue.py | 6 + lightllm/server/visualserver/manager.py | 11 +- 19 files changed, 254 insertions(+), 119 deletions(-) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 2ef01ea903..2c85488736 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -292,11 +292,11 @@ async def metrics() -> Response: return response -@app.post("/abort_req") -async def abort_req(request: AbortReq, raw_request: Request): +@app.post("/abort_request") +async def abort_request(request: AbortReq, raw_request: Request): """Abort a request.""" try: - await g_objs.httpserver_manager.abort_req(request) + await g_objs.httpserver_manager.abort_request(request) return Response(status_code=200) except Exception as e: return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") diff --git a/lightllm/server/core/objs/io_objs/__init__.py b/lightllm/server/core/objs/io_objs/__init__.py index c9b806c47d..10386b70e6 100644 --- a/lightllm/server/core/objs/io_objs/__init__.py +++ b/lightllm/server/core/objs/io_objs/__init__.py @@ -1 +1 @@ -from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd +from .group_req import AbortedReqCmd, StopStrMatchedReqCmd diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py index dfcbdd2562..d644c0c316 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -1,33 +1,10 @@ from dataclasses import dataclass from lightllm.server.multimodal_params import MultimodalParams +from lightllm.server.core.objs.sampling_params import SamplingParams from typing import List from ..req import Req -@dataclass -class GroupReqIndexes: - group_req_id: int - multimodal_params: MultimodalParams - shm_req_indexes: List[int] - time_mark: float - - 
-@dataclass -class GroupReqObjs: - group_req_id: int - multimodal_params: MultimodalParams - shm_req_objs: List[Req] - time_mark: float - - def to_group_req_index(self): - return GroupReqIndexes( - group_req_id=self.group_req_id, - multimodal_params=self.multimodal_params, - shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], - time_mark=self.time_mark, - ) - - @dataclass class AbortedReqCmd: req_id: int diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 947f24644d..e6f878b25c 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -24,19 +24,20 @@ class FinishStatus(ctypes.Structure): NO_FINISH = 0 FINISHED_STOP = 1 FINISHED_LENGTH = 2 + FINISHED_ABORTED = 3 def __init__(self, init_state=NO_FINISH): self.status = init_state def set_status(self, new_status): - assert 0 <= new_status <= 2 + assert 0 <= new_status <= 3 self.status = new_status def get_status(self): return self.status def is_finished(self): - return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH + return self.FINISHED_STOP <= self.status <= self.FINISHED_ABORTED def is_stopped(self): return self.status == self.FINISHED_STOP @@ -49,6 +50,8 @@ def get_finish_reason(self): return "stop" elif self.status == self.FINISHED_LENGTH: return "length" + elif self.status == self.FINISHED_ABORTED: + return "abort" return None @@ -247,9 +250,8 @@ def can_release(self): ref_count_ok = self.ref_count == 1 can_released_mark = self.can_released_mark - if self.is_aborted and can_released_mark and ref_count_ok: - return True - + # if self.is_aborted and can_released_mark and ref_count_ok: + # return True ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty(): diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py index 9aa3a8effc..c77379986c 100644 --- a/lightllm/server/detokenization/decode_req.py +++ b/lightllm/server/detokenization/decode_req.py @@ -62,11 +62,7 @@ def stop_sequences_str_match(self) -> bool: return False def need_detoken(self): - if ( - (not self.req.is_aborted) - and (not self.req.stop_str_matched) - and len(self.output_ids) < self.req.candetoken_out_len - ): + if (not self.req.stop_str_matched) and len(self.output_ids) < self.req.candetoken_out_len: return True return False @@ -83,8 +79,6 @@ def get_decode_tokens(self): return prefix_tokens, read_tokens def can_set_release_mark(self): - if self.req.is_aborted: - return True if self.req.stop_str_matched: return True if ( diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 389171ba8a..ab5f706b97 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -6,7 +6,6 @@ import zmq import inspect from lightllm.server.core.objs import ShmReqManager, StartArgs -from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.utils.graceful_utils import graceful_registry from typing import Union, Dict, List from .decode import decode_token @@ -17,6 +16,7 @@ import time from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import BaseReq logger = init_logger(__name__) @@ -46,7 +46,7 @@ def _init_get_token_id_to_token_str(self): self.token_id_to_token = {token_id: token for token, token_id in 
self.tokenizer.get_vocab().items()} return - def _add_new_group_req_index(self, recv_obj: GroupReqIndexes): + def _add_new_group_req_index(self, recv_obj: BaseReq): for req_index in recv_obj.shm_req_indexes: req = self.shm_req_manager.get_req_obj_by_index(req_index) req.link_prompt_ids_shm_array() @@ -74,8 +74,7 @@ def handle_loop(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): - recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - assert isinstance(recv_obj, GroupReqIndexes) + recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) self._add_new_group_req_index(recv_obj=recv_obj) # 当队列中存在较多的请求时,将一次接受的数量上调 diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 7158b89235..5254e2097d 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -25,12 +25,18 @@ from lightllm.server.core.objs import Req, FinishStatus, StartArgs from lightllm.server.core.objs import SamplingParams from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE -from lightllm.server.core.objs.io_objs import GroupReqObjs from lightllm.server.core.objs.shm_req_manager import ShmReqManager from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient +from lightllm.server.io_struct import ( + AbortReq, + BaseReq, + GenerateReq, + GenerateReqMeta, + GenerateReqIndex, +) from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name @@ -74,7 +80,7 @@ def __init__( self.multinode_req_manager = context.socket(zmq.PULL) self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}") logger.info( - f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}" + f"HttpServerManager listening for master node requests on *:{args.multinode_httpmanager_port}" ) self.enable_multimodal = args.enable_multimodal @@ -218,18 +224,32 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar async def loop_for_request(self): assert self.args.node_rank > 0 while True: - ( - prompt, - sampling_params, - multimodal_params, - ) = await self.multinode_req_manager.recv_pyobj() - results_generator = self.generate(prompt, sampling_params, multimodal_params, None) + req_obj = await self.multinode_req_manager.recv_pyobj() + if req_obj is None: + continue + if isinstance(req_obj, GenerateReqMeta): + self.process_generate_request(req_obj) + elif isinstance(req_obj, AbortReq): + self.process_abort_request(req_obj) + else: + assert False, f"Unknown request type: {type(req_obj)}" + return + + def process_generate_request(self, req_meta: GenerateReqMeta): + prompt = req_meta.prompt + sampling_params = req_meta.sampling_params + multimodal_params = req_meta.multimodal_params + results_generator = self.generate(prompt, sampling_params, multimodal_params, None) + + async def generate_wrapper(results_generator): + async for _, _, _, _ in results_generator: + pass - async def generate_wrapper(results_generator): - async for _, _, _, _ in results_generator: - pass + asyncio.create_task(generate_wrapper(results_generator)) + return - 
asyncio.create_task(generate_wrapper(results_generator)) + def process_abort_request(self, request: AbortReq): + asyncio.create_task(self.abort_request(request)) return def alloc_req_id(self, sampling_params, is_health_req: bool = False): @@ -279,10 +299,6 @@ async def generate( group_request_id = self.alloc_req_id(sampling_params, is_health_req) try: - original_multimodal_params = None - if self.is_multinode_tp_master: - original_multimodal_params = copy.deepcopy(multimodal_params) - if self.pd_mode.is_P_or_NORMAL(): await multimodal_params.verify_and_preload(request) @@ -346,12 +362,17 @@ async def generate( ) req_objs.append(req_obj) - req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time) + req_status = ReqStatus( + group_request_id=group_request_id, + prompt=prompt, + sampling_params=sampling_params, + multimodal_params=multimodal_params, + req_objs=req_objs, + start_time=start_time, + ) self.req_id_to_out_inf[group_request_id] = req_status - await self.transfer_to_next_module_or_node( - prompt, sampling_params, original_multimodal_params, req_status.group_req_objs - ) + await self.transfer_to_next_module_or_node(req_status.group_req_objs) results_generator = self._wait_to_token_package( start_time, @@ -482,44 +503,49 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params: async def transfer_to_next_module_or_node( self, - prompt: str, - sampling_params: SamplingParams, - original_multimodal_params: MultimodalParams, - group_req_objs: Optional[GroupReqObjs] = None, + req_obj: Optional["BaseReq"] = None, ): # 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点. + req_to_next_node = req_obj.get_req_to_next_node() + self.transfer_to_next_node(req_to_next_node) + req_to_next_module = req_obj.get_req_to_next_module() + await self.transfer_to_next_module(req_to_next_module) + return + + def transfer_to_next_node( + self, + req_to_next_node: Optional["BaseReq"] = None, + ): if self.is_multinode_tp_master: for sender in self.multinode_req_manager: sender.send_pyobj( - (prompt, sampling_params, original_multimodal_params), + req_to_next_node, protocol=pickle.HIGHEST_PROTOCOL, ) - - await self.transfer_to_next_module(group_req_objs) return async def transfer_to_next_module( self, - group_req_objs: Optional[GroupReqObjs] = None, + req_to_next_module: Optional["GenerateReqIndex"] = None, ): if self.pd_mode.is_P_or_NORMAL(): if self.enable_multimodal: self.send_to_visual.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return if self.args.enable_cpu_cache: self.send_to_multi_level_kv_cache.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return @@ -527,7 +553,7 @@ async def transfer_to_next_module( if self.pd_mode.is_D(): # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了 self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return @@ -643,12 +669,24 @@ async def abort(self, group_req_id: int) -> bool: logger.warning(f"aborted group_request_id {group_req_id} not exist") return False - group_req_objs: GroupReqObjs = req_status.group_req_objs + group_req_objs: GenerateReq = req_status.group_req_objs for req in group_req_objs.shm_req_objs: req.is_aborted = True logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") return True + 
async def abort_request(self, request: AbortReq): + request_id = request.request_id + abort_all = request.abort_all + if self.is_multinode_tp_master: + self.transfer_to_next_node(req_to_next_node=request) + if request_id is not None and not abort_all: + await self.abort(request_id) + if abort_all: + for group_req_id in list(self.req_id_to_out_inf.keys()): + await self.abort(group_req_id) + pass + async def recycle_resource_loop(self): pre_time_mark = time.time() @@ -776,11 +814,21 @@ async def flush_cache(self): class ReqStatus: - def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None: + def __init__( + self, + group_request_id: int, + prompt: str, + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + req_objs: List[Req], + start_time, + ) -> None: self.lock = asyncio.Lock() self.event = asyncio.Event() - self.group_req_objs = GroupReqObjs( + self.group_req_objs = GenerateReq( group_req_id=group_request_id, + prompt=prompt, + sampling_params=sampling_params, multimodal_params=multimodal_params, shm_req_objs=req_objs, time_mark=start_time, diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py index 68d32816f1..b5adff954e 100644 --- a/lightllm/server/io_struct.py +++ b/lightllm/server/io_struct.py @@ -1,7 +1,66 @@ +from abc import ABC from dataclasses import dataclass +from lightllm.server.core.objs.req import Req +from lightllm.server.core.objs.sampling_params import SamplingParams +from lightllm.server.multimodal_params import MultimodalParams +from typing import List @dataclass -class AbortReq: - request_id: int = -1 +class BaseReq(ABC): + def get_req_to_next_node(self): + return self + + def get_req_to_next_module(self): + return self + + +# for next node +@dataclass +class GenerateReqMeta(BaseReq): + prompt: str + sampling_params: SamplingParams + multimodal_params: MultimodalParams + + +# for next module +@dataclass +class GenerateReqIndex(BaseReq): + group_req_id: int + multimodal_params: MultimodalParams + shm_req_indexes: List[int] + time_mark: float + + +@dataclass +class GenerateReq(BaseReq): + group_req_id: int + prompt: str + sampling_params: SamplingParams + multimodal_params: MultimodalParams + shm_req_objs: List[Req] + time_mark: float + + def get_req_to_next_module(self): + # 已经完成跨节点转发,可以释放图片原始资源 + self.multimodal_params.free() + return GenerateReqIndex( + group_req_id=self.group_req_id, + multimodal_params=self.multimodal_params, + shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], + time_mark=self.time_mark, + ) + + def get_req_to_next_node(self): + return GenerateReqMeta( + prompt=self.prompt, + sampling_params=self.sampling_params, + multimodal_params=self.multimodal_params, + ) + + +@dataclass +class AbortReq(BaseReq): + # 外部调用传入,等同内部的 group_req_id + request_id: int = None abort_all: bool = False diff --git a/lightllm/server/multi_level_kv_cache/manager.py b/lightllm/server/multi_level_kv_cache/manager.py index 8853e352ed..e3bbe268b5 100644 --- a/lightllm/server/multi_level_kv_cache/manager.py +++ b/lightllm/server/multi_level_kv_cache/manager.py @@ -10,7 +10,7 @@ import concurrent.futures from queue import Queue from lightllm.server.core.objs import ShmReqManager, Req, StartArgs -from lightllm.server.core.objs.io_objs import GroupReqIndexes +from lightllm.server.io_struct import GenerateReqIndex from lightllm.utils.graceful_utils import graceful_registry from .cpu_cache_client import CpuKvCacheClient from lightllm.utils.log_utils import init_logger @@ -51,7 +51,7 @@ def 
cpu_cache_hanle_loop(self): logger.exception(str(e)) return - def _handle_group_req_cpu_cache_match(self, group_req_indexes: GroupReqIndexes, start_time: float): + def _handle_group_req_cpu_cache_match(self, group_req_indexes: GenerateReqIndex, start_time: float): """ match cpu cache pages """ @@ -110,8 +110,8 @@ def recv_loop(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): - recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - assert isinstance(recv_obj, GroupReqIndexes) + recv_obj: GenerateReqIndex = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + assert isinstance(recv_obj, GenerateReqIndex) recv_objs.append(recv_obj) start_time = recv_obj.time_mark diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 066fe5cc2a..9a1529a06c 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -54,9 +54,11 @@ def read(self): assert self._preload_data is not None ans = self._preload_data self._preload_data = None - self._data = None return ans + def free(self): + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid @@ -112,9 +114,11 @@ def read(self): assert self._preload_data is not None ans = self._preload_data self._preload_data = None - self._data = None return ans + def free(self): + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid @@ -162,3 +166,10 @@ def to_origin_dict(self): ret = {} ret["images"] = [i.to_origin_dict() for i in self.images] return ret + + def free(self): + for image in self.images: + image.free() + for audio in self.audios: + audio.free() + return diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index b6132b12f8..ee3b7a957e 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -18,7 +18,6 @@ from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue from lightllm.server.core.objs.io_objs import ( - GroupReqIndexes, AbortedReqCmd, StopStrMatchedReqCmd, ) @@ -31,6 +30,7 @@ from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.common.mem_manager import ReadOnlyStaticsMemoryManager +from lightllm.server.io_struct import BaseReq, GenerateReqIndex from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name @@ -385,8 +385,13 @@ def _get_aborted_reqs_from_running_batch(self) -> List[Req]: ans = [] if self.running_batch is None: return ans - for req in self.running_batch.reqs: - if req.is_aborted and req._router_aborted is False: + aborted_req_mask = torch.tensor( + [req.is_aborted for req in self.running_batch.reqs], dtype=torch.bool, device="cpu" + ) + if self.is_multinode_tp: + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) + for req, is_aborted in zip(self.running_batch.reqs, aborted_req_mask.numpy()): + if is_aborted and req._router_aborted is False: req._router_aborted = True ans.append(req) return ans @@ -435,7 +440,7 @@ def get_used_tokens(self, dp_index): else: return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index) - def _add_req(self, group_req_indexes: GroupReqIndexes): + def _add_req(self, group_req_indexes: BaseReq): req_group = [] for req_index in 
group_req_indexes.shm_req_indexes: req = self.shm_req_manager.get_req_obj_by_index(req_index) @@ -481,9 +486,22 @@ def _multinode_tp_generate_new_batch(self): dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group) req_id_select_mark = [1 for _ in range(len(req_ids))] req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu") + # TODO: 这里可以合成一个 allreudce,req_id_select_mark + aborted_req_mask dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group) + aborted_req_mask = torch.tensor( + [req.is_aborted for req in new_batch.reqs], dtype=torch.bool, device="cpu" + ) + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) back_req_list = [] - for req_id, select in zip(req_ids, req_id_select_mark.numpy()): + for req_id, select, is_aborted in zip( + req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy() + ): + # 释放多节点abort 请求,如果select == 0, is_aborted 一定为False + if is_aborted and select == 1: + req = new_batch.pop_req(req_id) + self.req_queue.free_aborted_req(req) + self.shm_req_manager.put_back_req_obj(req) + continue if select == 0: req = new_batch.pop_req(req_id) back_req_list.append(req) @@ -499,23 +517,28 @@ def _multinode_tp_generate_new_batch(self): else: req_ids = [None for _ in range(req_num)] dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group) - all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list]) + # all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list]) + id_to_req_obj = {req.request_id: req for req in self.req_queue.waiting_req_list} req_id_select_mark = [] + aborted_req_mask = [] for req_id in req_ids: - req_id_select_mark.append(1 if req_id in all_req_id_set else 0) + req_id_select_mark.append(1 if req_id in id_to_req_obj else 0) + aborted_req_mask.append(id_to_req_obj[req_id].is_aborted if req_id in id_to_req_obj else False) req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu") dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group) - select_req_ids = [] - for req_id, select in zip(req_ids, req_id_select_mark.numpy()): - if select == 1: - select_req_ids.append(req_id) - + aborted_req_mask = torch.tensor(aborted_req_mask, dtype=torch.bool, device="cpu") + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) select_reqs = [] - for req_id in select_req_ids: - for req in self.req_queue.waiting_req_list: - if req.request_id == req_id: - select_reqs.append(req) - + for req_id, select, is_aborted in zip( + req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy() + ): + if select == 1: + req = id_to_req_obj[req_id] + if is_aborted: + self.req_queue.free_aborted_req(req) + self.shm_req_manager.put_back_req_obj(req) + continue + select_reqs.append(req) for req in select_reqs: self.req_queue.waiting_req_list.remove(req) if select_reqs: @@ -538,11 +561,9 @@ async def _recv_new_reqs_and_schedule(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(self.recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - if isinstance(recv_req, GroupReqIndexes): + recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if isinstance(recv_req, GenerateReqIndex): self._add_req(recv_req) - else: - assert False, f"Error Req Inf {recv_req}" # 当队列中存在较多的请求时,将一次接受的数量上调 self.recv_max_count = min(int(self.recv_max_count * 1.3), 256) diff --git 
a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 3fe3f5136d..7bb01538d4 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -477,6 +477,8 @@ def update_finish_status(self, eos_ids, output_len: int):
             self.finish_status.set_status(FinishStatus.FINISHED_STOP)
         elif output_len >= self.sampling_param.shm_param.max_new_tokens:
             self.finish_status.set_status(FinishStatus.FINISHED_LENGTH)
+        elif self.infer_aborted:
+            self.finish_status.set_status(FinishStatus.FINISHED_ABORTED)
         return
 
     def _stop_sequences_matched(self, output_len: int):
diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py
index 36aefae6e7..d7ef06828b 100644
--- a/lightllm/server/router/req_queue/base_queue.py
+++ b/lightllm/server/router/req_queue/base_queue.py
@@ -34,6 +34,17 @@ def free_aborted_req_cpu_cache_pages(self, req: Req):
             req.cpu_cache_match_page_indexes.clear()
             self.router.cpu_cache_client.lock.release()
 
+    def free_aborted_req(self, req: Req):
+        # 为了让http server 能正常返回请求,还没有开始推理的请求,直接设置结束,返回空字符串
+        input_len = req.input_len
+        req.link_prompt_ids_shm_array()
+        req.link_logprobs_shm_array()
+        req.candetoken_out_len = 1
+        req.finish_token_index = input_len
+        req.shm_prompt_ids.arr[input_len] = self.args.eos_id[0]
+        req.shm_logprobs.arr[input_len] = 0
+        req.finish_status.set_status(FinishStatus.FINISHED_ABORTED)
+
     def extend(self, req_group: List[Req]):
         for req in req_group:
             req.sample_params.suggested_dp_index = self.dp_index
diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
index ae7c90b335..ed2a5dbb12 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
@@ -89,7 +89,7 @@ def generate_new_batch(self, current_batch: Batch):
         aborted_count = 0
         cur_group_reqs = []
         for req in self.waiting_req_list:
-            if req.is_aborted:
+            if req.is_aborted and not self.router.is_multinode_tp:
                 aborted_count += 1
                 abort_req_list.append(req)
                 continue
@@ -111,7 +111,7 @@ def generate_new_batch(self, current_batch: Batch):
                 ok_insert, new_batch_first_router_need_tokens = self._can_add_new_group_reqs(
                     cur_group_reqs, is_busy, new_batch_first_router_need_tokens
                 )
-            if ok_insert:
+                if ok_insert:
                     can_run_list.extend(cur_group_reqs)
 
         new_batch = None
@@ -120,6 +120,7 @@ def generate_new_batch(self, current_batch: Batch):
 
         for req in abort_req_list:
             self.free_aborted_req_cpu_cache_pages(req)
+            self.free_aborted_req(req)
             self.router.shm_req_manager.put_back_req_obj(req)
         self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
         return new_batch
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py
index 0d870b55d8..9449798e9c 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -79,8 +79,8 @@ def generate_new_batch(self, current_batch: Batch):
         waiting_queue = self.waiting_req_list
 
         for req in waiting_queue:
-            if req.is_aborted:
-                # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉.
+            if req.is_aborted and not self.router.is_multinode_tp:
+                # 由于管理的复杂性,只有没有被调度运行过的单节点请求可以因为abort直接在队列中忽略掉.
                 
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 aborted_count += 1 abort_req_list.append(req) @@ -97,6 +97,7 @@ def generate_new_batch(self, current_batch: Batch): new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py index f2658159b4..842b93648b 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py @@ -70,7 +70,7 @@ def generate_new_batch(self, current_batch: Batch): waiting_queue = self.waiting_req_list for req in waiting_queue: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 aborted_count += 1 @@ -88,6 +88,7 @@ def generate_new_batch(self, current_batch: Batch): new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py index e0da134875..3dea3cf955 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py @@ -38,7 +38,7 @@ def generate_new_batch(self, current_batch: Batch): abort_req_list = [] aborted_count = 0 for req in self.waiting_req_list: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. 
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token和管理req对象的泄漏 aborted_count += 1 @@ -53,6 +53,7 @@ def generate_new_batch(self, current_batch: Batch): new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index a73823b8b7..e5f731df5f 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -27,6 +27,12 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None: self.reqs_waiting_for_dp_index: List[List[Req]] = [] return + def free_aborted_req(self, req: Req): + dp_index = req.sample_params.suggested_dp_index + assert dp_index >= 0 and dp_index < self.dp_size_in_node + self.inner_queues[dp_index].free_aborted_req(req) + return + def get_dp_queue(self, dp_index: int): assert dp_index < self.dp_size_in_node, "dp index out of range" return self.inner_queues[dp_index] diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index b7e1ac10cd..fa3ac0c991 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -8,7 +8,6 @@ import inspect import setproctitle from typing import List -from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -18,6 +17,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import BaseReq, GenerateReqIndex from rpyc.utils.classic import obtain @@ -48,7 +48,7 @@ def __init__( self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.cache_port = args.cache_port - self.waiting_reqs: List[GroupReqIndexes] = [] + self.waiting_reqs: List[BaseReq] = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp self.vit_dp = args.visual_dp @@ -171,11 +171,12 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - if isinstance(recv_req, GroupReqIndexes): + recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + # 目前只有 GenerateReqIndex 会进入这个队列,判断是否需要推理图片 + if isinstance(recv_req, GenerateReqIndex): self.waiting_reqs.append(recv_req) else: - assert False, f"Error Req Inf {recv_req}" + self.send_to_next_module.send_pyobj(recv_req, protocol=pickle.HIGHEST_PROTOCOL) self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 From 4095831bd165b8a605ddb0031c9ec2df754a7769 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Wed, 19 Nov 2025 14:11:51 +0800 Subject: [PATCH 003/180] flush cache mulit node (#1116) --- lightllm/server/api_http.py | 1 - lightllm/server/detokenization/manager.py | 19 ++- lightllm/server/httpserver/manager.py | 148 
++++++++++++---------- lightllm/server/io_struct.py | 15 +++ lightllm/server/router/manager.py | 77 ++++++----- lightllm/server/router/mananger_rpc.py | 43 ------- 6 files changed, 161 insertions(+), 142 deletions(-) delete mode 100644 lightllm/server/router/mananger_rpc.py diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 2c85488736..0a8841f946 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -382,7 +382,6 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) - g_objs.httpserver_manager.connect_router_rpc() loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index ab5f706b97..7548342cd5 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -16,7 +16,11 @@ import time from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.server.io_struct import BaseReq +from lightllm.server.io_struct import ( + BaseReq, + GenerateResp, + FlushCacheResp, +) logger = init_logger(__name__) @@ -31,9 +35,9 @@ def __init__( self.zmq_recv_socket = context.socket(zmq.PULL) self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.detokenization_port}") - self.pub_to_httpserver = context.socket(zmq.PUB) - self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") - logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}") + self.send_to_httpserver = context.socket(zmq.PUSH) + self.send_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") + logger.info(f"send_to_httpserver sendhwm {self.send_to_httpserver.getsockopt(zmq.SNDHWM)}") self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) self.all_special_ids = set(self.tokenizer.all_special_ids) self.req_id_to_out: Dict[int, DecodeReq] = {} @@ -75,6 +79,11 @@ def handle_loop(self): # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if isinstance(recv_obj, FlushCacheResp): + print("Detokenization receive flush cache request", flush=True) + self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) + print("Detokenization send flush cache request to httpserver", flush=True) + continue self._add_new_group_req_index(recv_obj=recv_obj) # 当队列中存在较多的请求时,将一次接受的数量上调 @@ -145,7 +154,7 @@ def gen_token_out(self): # 通知 httpserver 进程 if exist_decode: - self.pub_to_httpserver.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_httpserver.send_pyobj(GenerateResp(), protocol=pickle.HIGHEST_PROTOCOL) self.remove_finished_reqs() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 5254e2097d..0dab8fc8cc 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -33,7 +33,10 @@ from lightllm.server.io_struct import ( AbortReq, BaseReq, + FlushCacheReq, + FlushCacheResp, GenerateReq, + GenerateResp, GenerateReqMeta, GenerateReqIndex, ) @@ -96,9 +99,8 @@ def __init__( self.shm_req_manager = ShmReqManager() # recv from detokenization - self.zmq_recv_socket = context.socket(zmq.SUB) + self.zmq_recv_socket = 
context.socket(zmq.PULL) self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") - self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"") self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) @@ -120,6 +122,9 @@ def __init__( # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend. self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + + # 交互式请求 event + self.flush_cache_event: Optional[asyncio.Event] = None return async def _alloc_resource(self, items, md5sums, token_nums, datas): @@ -275,13 +280,6 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): assert False, "dead code path" return group_request_id - def connect_router_rpc(self): - from lightllm.server.router.mananger_rpc import connect_router_rpc - - self.router_rpc_client = connect_router_rpc(self.args.router_rpc_port) - logger.info("HttpServerManager connected to Router RPC service successfully") - return - async def generate( self, prompt: Union[str, List[int]], @@ -743,64 +741,16 @@ async def handle_loop(self): while True: try: - await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) + recv_obj = await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) except asyncio.TimeoutError: - pass + recv_obj = None try: - for group_req_id_ in list(self.req_id_to_out_inf.keys()): - req_status = self.req_id_to_out_inf.get(group_req_id_, None) - if req_status is None: - continue + if recv_obj is None or isinstance(recv_obj, GenerateResp): + await self._handle_recv_generate_request(recv_obj) + elif isinstance(recv_obj, FlushCacheResp): + await self._handle_recv_flush_cache_request(recv_obj) - token_list = [] - for req in req_status.group_req_objs.shm_req_objs: - req_id = req.request_id - read_token_count = 1 - if req.out_tokens_queue.is_full(): - read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE - - for _ in range(read_token_count): - if not req.out_tokens_queue.is_empty(): - - text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() - req.cumlogprob += float(req.shm_logprobs.arr[src_index]) - metadata = { - "id": int(req.shm_prompt_ids.arr[src_index]), - "logprob": float(req.shm_logprobs.arr[src_index]), - "cumlogprob": float(req.cumlogprob) / count_output_tokens, - "special": special, - "count_output_tokens": count_output_tokens, - "prompt_cache_len": req.prompt_cache_len, - "cpu_prompt_cache_len": req.cpu_prompt_cache_len, - "mtp_accepted_token_num": req.mtp_accepted_token_num, - } - if self.args.return_all_prompt_logprobs: - metadata.update(req.get_all_prompt_metadata()) - if self.args.use_reward_model: - metadata["score"] = float(req.reward_score) - - req.out_tokens_queue.pop_no_ret() - - finished_token_index = ( - req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index - ) - - if finished_token_index != src_index: - token_list.append((req_id, text, metadata, FinishStatus())) - else: - if req.stop_str_matched: - finish_status = FinishStatus(FinishStatus.FINISHED_STOP) - else: - finish_status = FinishStatus(req.finish_status.status) - - token_list.append((req_id, text, metadata, finish_status)) - else: - break - - async with req_status.lock: - req_status.out_token_info_list.extend(token_list) - req_status.event.set() except BaseException as e: logger.exception(str(e)) raise e @@ -808,8 +758,78 @@ async def 
handle_loop(self): self.recycle_event.set() return + async def _handle_recv_generate_request(self, recv_obj: GenerateReqMeta): + for group_req_id_ in list(self.req_id_to_out_inf.keys()): + req_status = self.req_id_to_out_inf.get(group_req_id_, None) + if req_status is None: + continue + + token_list = [] + for req in req_status.group_req_objs.shm_req_objs: + req_id = req.request_id + read_token_count = 1 + if req.out_tokens_queue.is_full(): + read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE + + for _ in range(read_token_count): + if not req.out_tokens_queue.is_empty(): + + text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() + req.cumlogprob += float(req.shm_logprobs.arr[src_index]) + metadata = { + "id": int(req.shm_prompt_ids.arr[src_index]), + "logprob": float(req.shm_logprobs.arr[src_index]), + "cumlogprob": float(req.cumlogprob) / count_output_tokens, + "special": special, + "count_output_tokens": count_output_tokens, + "prompt_cache_len": req.prompt_cache_len, + "cpu_prompt_cache_len": req.cpu_prompt_cache_len, + "mtp_accepted_token_num": req.mtp_accepted_token_num, + } + if self.args.return_all_prompt_logprobs: + metadata.update(req.get_all_prompt_metadata()) + if self.args.use_reward_model: + metadata["score"] = float(req.reward_score) + + req.out_tokens_queue.pop_no_ret() + + finished_token_index = ( + req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index + ) + + if finished_token_index != src_index: + token_list.append((req_id, text, metadata, FinishStatus())) + else: + if req.stop_str_matched: + finish_status = FinishStatus(FinishStatus.FINISHED_STOP) + else: + finish_status = FinishStatus(req.finish_status.status) + + token_list.append((req_id, text, metadata, finish_status)) + else: + break + + async with req_status.lock: + req_status.out_token_info_list.extend(token_list) + req_status.event.set() + + async def _handle_recv_flush_cache_request(self, recv_obj: FlushCacheResp): + assert self.flush_cache_event is not None + self.flush_cache_event.success = recv_obj.success + self.flush_cache_event.set() + return + async def flush_cache(self): - ret = await self.router_rpc_client.flush_cache() + if self.flush_cache_event is None: + self.flush_cache_event = asyncio.Event() + await self.transfer_to_next_module(FlushCacheReq()) + try: + await asyncio.wait_for(self.flush_cache_event.wait(), timeout=30) + ret = self.flush_cache_event.success + except asyncio.TimeoutError: + # 超时直接返回失败 + ret = False + self.flush_cache_event.clear() return ret diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py index b5adff954e..2b4b3cef49 100644 --- a/lightllm/server/io_struct.py +++ b/lightllm/server/io_struct.py @@ -59,6 +59,21 @@ def get_req_to_next_node(self): ) +@dataclass +class GenerateResp(BaseReq): + pass + + +@dataclass +class FlushCacheReq(BaseReq): + pass + + +@dataclass +class FlushCacheResp(BaseReq): + success: bool + + @dataclass class AbortReq(BaseReq): # 外部调用传入,等同内部的 group_req_id diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index ee3b7a957e..64b36a1239 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -30,7 +30,12 @@ from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.common.mem_manager import ReadOnlyStaticsMemoryManager -from lightllm.server.io_struct import BaseReq, GenerateReqIndex +from lightllm.server.io_struct import ( + BaseReq, + GenerateReqIndex, 
+ FlushCacheReq, + FlushCacheResp, +) from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name @@ -152,9 +157,6 @@ async def wait_to_model_ready(self): rpc_finished_event=self.rpc_finished_event, ) - # 启动 rpyc 服务,供 HTTP Server 远程调用 - self._start_router_rpc_service() - kvargs = { "args": self.args, "rank_id": None, # 由后续处理填充真实数据 @@ -235,25 +237,6 @@ async def wait_to_model_ready(self): return - def _start_router_rpc_service(self): - """launch a rpyc service for httpserver to call RouterManager""" - import threading - from rpyc.utils.server import ThreadedServer - import lightllm.utils.rpyc_fix_utils as _ - from .mananger_rpc import RouterRpcService - - service = RouterRpcService(self) - port = self.args.router_rpc_port - - def start_server(): - t = ThreadedServer(service, port=port, protocol_config={"allow_pickle": True}) - t.start() - - rpc_thread = threading.Thread(target=start_server, daemon=True) - rpc_thread.start() - logger.info(f"Router RPC service started successfully on port {port}") - return - def _get_schedule_time_interval(self): # dp 模式,为了更好的配平,需要更长的调度间隔,以便于能收到更多的请求 return self.schedule_time_interval @@ -559,11 +542,15 @@ async def _recv_new_reqs_and_schedule(self): self.recv_max_count = 64 try: + # 多机tp需要广播给其他node的请求 + special_reqs = [] # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(self.recv_max_count): recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GenerateReqIndex): self._add_req(recv_req) + elif isinstance(recv_req, FlushCacheReq): + special_reqs.append(recv_req) # 当队列中存在较多的请求时,将一次接受的数量上调 self.recv_max_count = min(int(self.recv_max_count * 1.3), 256) @@ -572,6 +559,8 @@ async def _recv_new_reqs_and_schedule(self): # 当队列已经开始清空的时候,将一次接受的数量下调 self.recv_max_count = 64 + self._process_special_reqs(special_reqs) + if self.is_multinode_tp: self._multinode_tp_generate_new_batch() else: @@ -579,16 +568,46 @@ async def _recv_new_reqs_and_schedule(self): self._generate_new_batch() return + def _process_special_reqs(self, special_reqs: List[BaseReq]): + if self.is_multinode_tp: + special_reqs = self.broadcast_reqs_to_other_nodes(special_reqs) + for req in special_reqs: + if isinstance(req, FlushCacheReq): + self.flush_cache() + + def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): + req_num = len(reqs) + if self.node_rank == 0: + req_nums = [len(reqs)] + dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group) + req_num = req_nums[0] + if req_num > 0: + dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) + else: + req_nums = [None] + dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group) + req_num = req_nums[0] + if req_num > 0: + reqs = [None for _ in range(req_num)] + dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) + return reqs + def flush_cache(self) -> bool: - if self.running_batch is not None: - return False - if self.req_queue.get_wait_req_num() > 0: - return False # if radix cache client is not initialized, just return True if self.radix_cache_client is None: - return True + success = True # only flush cache when no running batch and no waiting requests - return self.model_rpc_client.flush_radix_cache() + elif self.running_batch is not None or self.req_queue.get_wait_req_num() > 0: + success = False + else: + success = self.model_rpc_client.flush_radix_cache() + + if self.is_multinode_tp: + # 
等待其他节点的flush 结果 + dist.barrier(group=self.mulitnode_group) + if self.is_multinode_tp_master: + self.send_to_detokenization.send_pyobj(FlushCacheResp(success=success), protocol=pickle.HIGHEST_PROTOCOL) + return success def clean_up(self): return diff --git a/lightllm/server/router/mananger_rpc.py b/lightllm/server/router/mananger_rpc.py deleted file mode 100644 index 60f9e0458b..0000000000 --- a/lightllm/server/router/mananger_rpc.py +++ /dev/null @@ -1,43 +0,0 @@ -import rpyc -import asyncio -import socket -from .manager import RouterManager - - -class RouterRpcService(rpyc.Service): - def __init__(self, router_manager: "RouterManager"): - super().__init__() - self.router_manager = router_manager - return - - def exposed_flush_cache(self) -> bool: - return self.router_manager.flush_cache() - - -class RouterRpcClient: - def __init__(self, router_rpc_conn): - self.router_rpc_conn = router_rpc_conn - - def async_wrap(f): - f = rpyc.async_(f) - - async def _func(*args, **kwargs): - ans = f(*args, **kwargs) - await asyncio.to_thread(ans.wait) - # raise if exception - return ans.value - - return _func - - self._flush_cache = async_wrap(self.router_rpc_conn.root.flush_cache) - return - - async def flush_cache(self) -> bool: - ans = await self._flush_cache() - return ans - - -def connect_router_rpc(port: int) -> RouterRpcClient: - router_rpc_conn = rpyc.connect("localhost", port, config={"allow_pickle": True}) - router_rpc_conn._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - return RouterRpcClient(router_rpc_conn) From ca9325fd0cd590adec3032c4d2024e5337327e18 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Wed, 19 Nov 2025 14:56:32 +0800 Subject: [PATCH 004/180] [bugfix]: flush cache in single node (#1118) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- lightllm/server/router/manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 64b36a1239..35fd861d8a 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -592,7 +592,7 @@ def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) return reqs - def flush_cache(self) -> bool: + def flush_cache(self) -> None: # if radix cache client is not initialized, just return True if self.radix_cache_client is None: success = True @@ -605,9 +605,9 @@ def flush_cache(self) -> bool: if self.is_multinode_tp: # 等待其他节点的flush 结果 dist.barrier(group=self.mulitnode_group) - if self.is_multinode_tp_master: + if self.node_rank == 0: self.send_to_detokenization.send_pyobj(FlushCacheResp(success=success), protocol=pickle.HIGHEST_PROTOCOL) - return success + return def clean_up(self): return From 99489258234bb2256b5908285aa5230c70af00bf Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Wed, 19 Nov 2025 22:03:35 +0800 Subject: [PATCH 005/180] add pause and continue (#1120) --- lightllm/server/api_http.py | 12 ++++++++++++ lightllm/server/httpserver/manager.py | 24 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 0a8841f946..07fcc41396 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -316,6 +316,18 @@ async def flush_cache(): ) +@app.post("/pause_generation") +async def 
pause_generation(): + await g_objs.httpserver_manager.pause_generation() + return Response(content="Generation paused successfully.", status_code=200) + + +@app.post("/continue_generation") +async def continue_generation(): + await g_objs.httpserver_manager.continue_generation() + return Response(content="Generation continued successfully.", status_code=200) + + @app.websocket("/pd_register") async def register_and_keep_alive(websocket: WebSocket): await websocket.accept() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 0dab8fc8cc..765b44eea0 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -123,6 +123,9 @@ def __init__( self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + self.is_pause = False + self.is_pause_cond = asyncio.Condition() + # 交互式请求 event self.flush_cache_event: Optional[asyncio.Event] = None return @@ -302,6 +305,10 @@ async def generate( # 记录请求到达的相关信息 await self._log_req_header(request_headers, group_request_id) + + async with self.is_pause_cond: + await self.is_pause_cond.wait_for(lambda: not self.is_pause) + # encode prompt_ids = await self._encode(prompt, multimodal_params, sampling_params) @@ -832,6 +839,23 @@ async def flush_cache(self): self.flush_cache_event.clear() return ret + async def pause_generation(self): + # 因为请求是从master node转发到slave node的 + # 所以只要master暂停了,slave自然暂停。 + async with self.is_pause_cond: + self.is_pause = True + while True: + await self.abort_request(AbortReq(request_id=None, abort_all=True)) + running_req_num = len(list(self.req_id_to_out_inf.keys())) + if running_req_num == 0: + break + await asyncio.sleep(1.0) + + async def continue_generation(self): + async with self.is_pause_cond: + self.is_pause = False + self.is_pause_cond.notify_all() + class ReqStatus: def __init__( From 4b32287d662c8501d9c7c2a1f193c265d01fc6be Mon Sep 17 00:00:00 2001 From: sufubao <47234901+sufubao@users.noreply.github.com> Date: Fri, 21 Nov 2025 17:39:17 +0800 Subject: [PATCH 006/180] add launch_server and StartArgs (#1119) --- lightllm/server/api_http.py | 15 ++++- lightllm/server/api_server.py | 25 ++++++-- lightllm/server/api_start.py | 33 ++++++++--- lightllm/server/core/objs/start_args_type.py | 62 ++++++++++++++------ lightllm/utils/device_utils.py | 2 +- 5 files changed, 106 insertions(+), 31 deletions(-) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 07fcc41396..b96cf9306c 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -33,7 +33,7 @@ import uuid from PIL import Image import multiprocessing as mp -from typing import AsyncGenerator, Union +from typing import Any, AsyncGenerator, Union from typing import Callable from lightllm.server import TokenLoad from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect @@ -131,6 +131,19 @@ def get_model_name(): return {"model_name": g_objs.args.model_name} +@app.get("/get_server_info") +@app.post("/get_server_info") +def get_server_info(): + # 将 StartArgs 转换为字典格式 + from dataclasses import asdict + server_info: dict[str, Any] = asdict(g_objs.args) + return {**server_info} + +@app.get("/get_weight_version") +@app.post("/get_weight_version") +def get_weight_version(): + return {"weight_version": g_objs.args.weight_version} + @app.get("/healthz", summary="Check server health") @app.get("/health", summary="Check server 
health") @app.head("/health", summary="Check server health") diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index b4447d808a..dd531f58d4 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -1,11 +1,21 @@ import torch from .api_cli import make_argument_parser +from lightllm.server.core.objs.start_args_type import StartArgs +from lightllm.utils.log_utils import init_logger -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess - parser = make_argument_parser() - args = parser.parse_args() +logger = init_logger(__name__) + +def launch_server(args: StartArgs): from .api_start import pd_master_start, normal_or_p_d_start, config_server_start + + try: + # this code will not be ok for settings to fork to subprocess + torch.multiprocessing.set_start_method("spawn") + except RuntimeError as e: + logger.warning(f"Failed to set start method: {e}") + except Exception as e: + logger.error(f"Failed to set start method: {e}") + raise e if args.run_mode == "pd_master": pd_master_start(args) @@ -13,3 +23,10 @@ config_server_start(args) else: normal_or_p_d_start(args) + + +if __name__ == "__main__": + parser = make_argument_parser() + args = parser.parse_args() + + launch_server(StartArgs(**vars(args))) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 138b0a599b..6a02dda17b 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -16,6 +16,7 @@ from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip from lightllm.utils.shm_size_check import check_recommended_shm_size +from lightllm.server.core.objs.start_args_type import StartArgs logger = init_logger(__name__) @@ -51,20 +52,38 @@ def signal_handler(sig, frame): process_manager.terminate_all_processes() logger.info("All processes have been terminated gracefully.") sys.exit(0) + elif sig == signal.SIGHUP: + logger.info("Received SIGHUP (terminal closed), shutting down gracefully...") + if http_server_process and http_server_process.poll() is None: + http_server_process.send_signal(signal.SIGTERM) + + start_time = time.time() + while (time.time() - start_time) < 60: + if not is_process_active(http_server_process.pid): + logger.info("httpserver exit") + break + time.sleep(1) + + if time.time() - start_time < 60: + logger.info("HTTP server has exited gracefully") + else: + logger.warning("HTTP server did not exit in time, killing it...") + kill_recursive(http_server_process) + + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully due to terminal closure.") + sys.exit(0) signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGHUP, signal_handler) logger.info(f"start process pid {os.getpid()}") logger.info(f"http server pid {http_server_process.pid}") return -def normal_or_p_d_start(args): - from lightllm.server.core.objs.start_args_type import StartArgs - - args: StartArgs = args - +def normal_or_p_d_start(args: StartArgs): set_unique_server_name(args) if not args.disable_shm_warning: @@ -370,7 +389,7 @@ def normal_or_p_d_start(args): return -def pd_master_start(args): +def pd_master_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "pd_master": return @@ -433,7 +452,7 @@ def pd_master_start(args): http_server_process.wait() -def config_server_start(args): +def 
config_server_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "config_server": return diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 659aab1dc7..40f68a7439 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -1,48 +1,52 @@ from dataclasses import dataclass, field from typing import List, Optional, Tuple -# 只是为了更好的编程提示 +# 服务启动参数 @dataclass class StartArgs: run_mode: str = field( default="normal", - metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + metadata={"choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"]}, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) + httpserver_workers: int = field(default=1) zmq_mode: str = field( default="ipc:///tmp/", metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"}, ) - pd_master_ip: str = field(default="127.0.0.1") + pd_master_ip: str = field(default="0.0.0.0") pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) pd_decode_rpyc_port: int = field(default=42000) - select_p_d_node_strategy: str = field(default=None) + select_p_d_node_strategy: str = field( + default="round_robin", + metadata={"choices": ["random", "round_robin", "adaptive_load"]} + ) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) - tokenizer_mode: str = field(default="slow") + tokenizer_mode: str = field(default="fast") load_way: str = field(default="HF") max_total_token_num: Optional[int] = field(default=None) mem_fraction: float = field(default=0.9) batch_max_tokens: Optional[int] = field(default=None) - eos_id: List[int] = field(default_factory=list) + eos_id: Optional[List[int]] = field(default=None) tool_call_parser: Optional[str] = field( - default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]} + default=None, metadata={"choices": ["qwen25", "llama3", "mistral", "deepseekv3", "qwen"]} ) running_max_req_size: int = field(default=1000) tp: int = field(default=1) dp: int = field(default=1) nnodes: int = field(default=1) node_rank: int = field(default=0) - max_req_total_len: int = field(default=2048 + 1024) + max_req_total_len: int = field(default=16384) nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=28765) use_config_server_to_init_nccl: bool = field(default=False) - mode: List[str] = field(default_factory=list) + mode: List[str] = field(default_factory=lambda: []) trust_remote_code: bool = field(default=False) disable_log_stats: bool = field(default=False) log_stats_interval: int = field(default=10) @@ -51,11 +55,14 @@ class StartArgs: router_max_wait_tokens: int = field(default=1) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) - chunked_prefill_size: int = field(default=8192) + chunked_prefill_size: int = field(default=4096) disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) - output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]}) + output_constraint_mode: str = field( + default="none", + metadata={"choices": ["outlines", "xgrammar", "none"]} + ) 
 first_token_constraint_mode: bool = field(default=False)
     enable_multimodal: bool = field(default=False)
     enable_multimodal_audio: bool = field(default=False)
@@ -74,10 +81,10 @@
     health_monitor: bool = field(default=False)
     metric_gateway: Optional[str] = field(default=None)
     job_name: str = field(default="lightllm")
-    grouping_key: List[str] = field(default_factory=list)
+    grouping_key: List[str] = field(default_factory=lambda: [])
     push_interval: int = field(default=10)
     visual_infer_batch_size: int = field(default=1)
-    visual_gpu_ids: List[int] = field(default_factory=lambda: [0])
+    visual_gpu_ids: Optional[List[int]] = field(default=None)
     visual_tp: int = field(default=1)
     visual_dp: int = field(default=1)
     visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
@@ -86,10 +93,10 @@
     graph_max_batch_size: int = field(default=256)
     graph_split_batch_size: int = field(default=32)
     graph_grow_step_size: int = field(default=16)
-    graph_max_len_in_batch: int = field(default=8192)
-    quant_type: Optional[str] = field(default=None)
+    graph_max_len_in_batch: int = field(default=0)
+    quant_type: Optional[str] = field(default="none")
     quant_cfg: Optional[str] = field(default=None)
-    vit_quant_type: Optional[str] = field(default=None)
+    vit_quant_type: Optional[str] = field(default="none")
     vit_quant_cfg: Optional[str] = field(default=None)
     enable_flashinfer_prefill: bool = field(default=False)
     enable_flashinfer_decode: bool = field(default=False)
@@ -99,7 +106,10 @@
     )
     ep_redundancy_expert_config_path: Optional[str] = field(default=None)
     auto_update_redundancy_expert: bool = field(default=False)
-    mtp_mode: Optional[str] = field(default=None)
+    mtp_mode: Optional[str] = field(
+        default=None,
+        metadata={"choices": ["deepseekv3_vanilla", "deepseekv3_eagle", None]}
+    )
     mtp_draft_model_dir: Optional[str] = field(default=None)
     mtp_step: int = field(default=0)
     kv_quant_calibration_config_path: Optional[str] = field(default=None)
@@ -108,7 +118,7 @@
     pd_node_id: int = field(default=-1)
     enable_cpu_cache: bool = field(default=False)
     cpu_cache_storage_size: float = field(default=2)
-    cpu_cache_token_page_size: int = field(default=64)
+    cpu_cache_token_page_size: int = field(default=256)
     enable_disk_cache: bool = field(default=False)
     disk_cache_storage_size: float = field(default=10)
     # zmp ports
@@ -128,3 +138,18 @@
 
     # kernel setting
     enable_fa3: bool = field(default=False)
+
+    disable_shm_warning: bool = field(default=False)
+    dp_balancer: str = field(
+        default="bs_balancer",
+        metadata={"choices": ["round_robin", "bs_balancer"]}
+    )
+    enable_custom_allgather: bool = field(default=False)
+    enable_fused_shared_experts: bool = field(default=False)
+    enable_mps: bool = field(default=False)
+    multinode_router_gloo_port: int = field(default=20001)
+    schedule_time_interval: float = field(default=0.03)
+    use_dynamic_prompt_cache: bool = field(default=False)
+    disable_custom_allreduce: bool = field(default=False)
+
+    weight_version: str = "default"
diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py
index b4d1ba6298..d2b6d06a8a 100644
--- a/lightllm/utils/device_utils.py
+++ b/lightllm/utils/device_utils.py
@@ -85,7 +85,7 @@ def get_current_device_name():
         gpu_name = torch.cuda.get_device_name(device).replace(" ", "_")
         return gpu_name
     else:
-        return None
+        raise RuntimeError("No GPU available")
 
 
 @lru_cache(maxsize=None)
From 27abcf536894eb5e6f91871a5120debd07adabde
Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Mon, 1 Dec 2025 16:39:35 +0800 Subject: [PATCH 007/180] Update weight (#1127) Co-authored-by: Weichao Luo Co-authored-by: shihaobai <1798930569@qq.com> --- lightllm/common/basemodel/basemodel.py | 7 + lightllm/server/api_http.py | 42 +++- lightllm/server/detokenization/manager.py | 5 + lightllm/server/httpserver/manager.py | 73 +++++++ lightllm/server/io_struct.py | 85 +++++++- lightllm/server/router/manager.py | 21 +- .../model_infer/mode_backend/base_backend.py | 181 +++++++++++++++++- .../server/router/model_infer/model_rpc.py | 20 ++ lightllm/utils/dist_utils.py | 74 ++++++- lightllm/utils/patch_torch.py | 65 +++++++ lightllm/utils/serializer.py | 132 +++++++++++++ lightllm/utils/tensor_bucket.py | 108 +++++++++++ 12 files changed, 805 insertions(+), 8 deletions(-) create mode 100644 lightllm/utils/patch_torch.py create mode 100644 lightllm/utils/serializer.py create mode 100644 lightllm/utils/tensor_bucket.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 77ca299b2a..1221f1939b 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -178,6 +178,13 @@ def _init_weights(self): [weight.verify_load() for weight in self.trans_layers_weight] return + def load_weights(self, weight_dict: dict): + load_hf_weights(self.data_type, + self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=weight_dict) + def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 self.mem_manager = MemoryManager( diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index b96cf9306c..28ae93dfb2 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -58,7 +58,14 @@ CompletionRequest, CompletionResponse, ) -from .io_struct import AbortReq +from .io_struct import ( + AbortReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq, + GeneralModelToHttpRpcRsp +) from .build_prompt import build_prompt, init_tokenizer logger = init_logger(__name__) @@ -315,6 +322,39 @@ async def abort_request(request: AbortReq, raw_request: Request): return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") +async def handle_request_common(request_obj, handler): + try: + ret: GeneralModelToHttpRpcRsp = await handler(request_obj) + if ret.success: + return JSONResponse({"success": ret.success, "message": ret.msg}, status_code=200) + else: + return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg) + except Exception as e: + return create_error_response( + HTTPStatus.EXPECTATION_FAILED, + f"error: {str(e)}" + ) + +@app.post("/init_weights_update_group") +async def init_weights_update_group(request: InitWeightsUpdateGroupReq, raw_request: Request): + """Init weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.init_weights_update_group) + +@app.post("/destroy_weights_update_group") +async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupReq, raw_request: Request): + """Destroy weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.destroy_weights_update_group) + +@app.post("/update_weights_from_distributed") +async def update_weights_from_distributed(request: UpdateWeightsFromDistributedReq, raw_request: Request): + """Update model parameter from distributed 
online.""" + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_distributed) + +@app.post("/update_weights_from_tensor") +async def update_weights_from_distributed(request: UpdateWeightsFromTensorReq, raw_request: Request): + """Update model parameter from distributed online.""" + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_tensor) + @app.post("/flush_cache") @app.get("/flush_cache") async def flush_cache(): diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 7548342cd5..b7ba960258 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -20,6 +20,7 @@ BaseReq, GenerateResp, FlushCacheResp, + GeneralModelToHttpRpcRsp ) logger = init_logger(__name__) @@ -84,6 +85,10 @@ def handle_loop(self): self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) print("Detokenization send flush cache request to httpserver", flush=True) continue + elif isinstance(recv_obj, GeneralModelToHttpRpcRsp): + self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) + print(f"Detokenization send {recv_obj.func_name} request to httpserver") + continue self._add_new_group_req_index(recv_obj=recv_obj) # 当队列中存在较多的请求时,将一次接受的数量上调 diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 765b44eea0..44721cfe3b 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -10,6 +10,7 @@ import hashlib import datetime import pickle +import inspect from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -39,6 +40,17 @@ GenerateResp, GenerateReqMeta, GenerateReqIndex, + InitWeightsUpdateGroupReq, + InitWeightsUpdateGroupRsp, + DestroyWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupRsp, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromDistributedRsp, + UpdateWeightsFromTensorReq, + UpdateWeightsFromTensorRsp, + GeneralHttpToModelRpcReq, + GeneralModelToHttpRpcRsp + ) from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size @@ -128,6 +140,7 @@ def __init__( # 交互式请求 event self.flush_cache_event: Optional[asyncio.Event] = None + self.async_events_per_func: Dict[str, asyncio.Event] = {} return async def _alloc_resource(self, items, md5sums, token_nums, datas): @@ -757,6 +770,8 @@ async def handle_loop(self): await self._handle_recv_generate_request(recv_obj) elif isinstance(recv_obj, FlushCacheResp): await self._handle_recv_flush_cache_request(recv_obj) + elif isinstance(recv_obj, GeneralModelToHttpRpcRsp): + await self._handle_recv_general_model_to_http_request(recv_obj) except BaseException as e: logger.exception(str(e)) @@ -826,6 +841,13 @@ async def _handle_recv_flush_cache_request(self, recv_obj: FlushCacheResp): self.flush_cache_event.set() return + async def _handle_recv_general_model_to_http_request(self, recv_obj: GeneralModelToHttpRpcRsp): + assert recv_obj.func_name is not None + event = await self.get_event_for_func(recv_obj.func_name) + event.result = recv_obj + event.set() + return + async def flush_cache(self): if self.flush_cache_event is None: self.flush_cache_event = asyncio.Event() @@ -856,6 +878,57 @@ async def continue_generation(self): self.is_pause = False self.is_pause_cond.notify_all() + async def get_event_for_func(self, func_name: str) -> asyncio.Event: + if func_name not in self.async_events_per_func: + 
self.async_events_per_func[func_name] = asyncio.Event() + return self.async_events_per_func[func_name] + + async def http_to_model_special_request(self, request: GeneralHttpToModelRpcReq, timeout: int=300) -> GeneralModelToHttpRpcRsp: + event = await self.get_event_for_func(request.func_name) + await self.transfer_to_next_module(request) + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + ret = event.result + + except asyncio.TimeoutError: + ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response timeout", func_name=request.func_name) + except Exception as e: + ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response error: %s" % str(e), func_name=request.func_name) + return ret + + + async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + return await self.http_to_model_special_request(GeneralHttpToModelRpcReq( + func_name="init_weights_update_group", func_args=request)) + + + async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + return await self.http_to_model_special_request(GeneralHttpToModelRpcReq( + func_name="destroy_weights_update_group", func_args=request)) + + + async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + + if request.abort_all_requests: + await self.abort_request(AbortReq(abort_all=True)) + + if request.flush_cache: + await self.flush_cache() + + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request)) + + async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> Tuple[bool, str]: + if request.abort_all_requests: + await self.abort_request(AbortReq(abort_all=True)) + + if request.flush_cache: + await self.flush_cache() + + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="update_weights_from_tensor", func_args=request) + ) + class ReqStatus: def __init__( diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py index 2b4b3cef49..7947c7a581 100644 --- a/lightllm/server/io_struct.py +++ b/lightllm/server/io_struct.py @@ -3,7 +3,7 @@ from lightllm.server.core.objs.req import Req from lightllm.server.core.objs.sampling_params import SamplingParams from lightllm.server.multimodal_params import MultimodalParams -from typing import List +from typing import List, Optional, Any, Union @dataclass @@ -14,6 +14,10 @@ def get_req_to_next_node(self): def get_req_to_next_module(self): return self +@dataclass +class BaseRsp(ABC): + success: bool + msg: Optional[str] # for next node @dataclass @@ -79,3 +83,82 @@ class AbortReq(BaseReq): # 外部调用传入,等同内部的 group_req_id request_id: int = None abort_all: bool = False + + +@dataclass +class GeneralHttpToModelRpcReq(BaseReq): + func_name: str + func_args: Optional[Any] = None + +@dataclass +class GeneralModelToHttpRpcRsp(BaseRsp): + func_name: str + func_rsp: Optional[Any] = None + +@dataclass +class InitWeightsUpdateGroupReq(BaseReq): + # The master address + master_address: str + # The master port + master_port: int + # The rank offset + rank_offset: int + # The world size + world_size: int + # The group name + group_name: str = "weight_update_group" + # The backend + backend: str = "nccl" + +@dataclass +class InitWeightsUpdateGroupRsp(BaseRsp): + pass + +@dataclass +class DestroyWeightsUpdateGroupReq(BaseReq): + group_name: str = "weight_update_group" + +@dataclass +class DestroyWeightsUpdateGroupRsp(BaseRsp): + pass + +@dataclass +class 
UpdateWeightsFromDistributedReq(BaseReq):
+    names: List[str]
+    dtypes: List[str]
+    shapes: List[List[int]]
+    # The group name
+    group_name: str = "weight_update_group"
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None
+
+@dataclass
+class UpdateWeightsFromDistributedRsp(BaseRsp):
+    pass
+
+
+@dataclass
+class UpdateWeightsFromTensorReq(BaseReq):
+    """Update model weights from tensor input.
+
+    - Tensors are serialized for transmission
+    - Data is structured in JSON for easy transmission over HTTP
+    """
+
+    serialized_named_tensors: List[Union[str, bytes]]
+    # Optional format specification for loading
+    load_format: Optional[str] = None
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None
+
+@dataclass
+class UpdateWeightsFromTensorRsp(BaseRsp):
+    pass
\ No newline at end of file
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 35fd861d8a..ea1f17b90f 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -35,12 +35,13 @@
     GenerateReqIndex,
     FlushCacheReq,
     FlushCacheResp,
+    GeneralHttpToModelRpcReq,
+    GeneralModelToHttpRpcRsp
 )
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
 from lightllm.utils.envs_utils import get_unique_server_name
-
 logger = init_logger(__name__)
 
@@ -549,7 +550,7 @@ async def _recv_new_reqs_and_schedule(self):
             recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
             if isinstance(recv_req, GenerateReqIndex):
                 self._add_req(recv_req)
-            elif isinstance(recv_req, FlushCacheReq):
+            elif isinstance(recv_req, (FlushCacheReq, GeneralHttpToModelRpcReq)):
                 special_reqs.append(recv_req)
 
             # 当队列中存在较多的请求时,将一次接受的数量上调
@@ -574,6 +575,8 @@ def _process_special_reqs(self, special_reqs: List[BaseReq]):
         for req in special_reqs:
             if isinstance(req, FlushCacheReq):
                 self.flush_cache()
+            elif isinstance(req, GeneralHttpToModelRpcReq):
+                self.forward_to_model(req)
 
     def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]):
         req_num = len(reqs)
@@ -609,6 +612,21 @@ def flush_cache(self) -> None:
             self.send_to_detokenization.send_pyobj(FlushCacheResp(success=success), protocol=pickle.HIGHEST_PROTOCOL)
         return
 
+    def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> None:
+        ret = self.model_rpc_client.forward_to_model(req)
+        if self.is_multinode_tp:
+            output_list = [None for _ in range(self.nnodes)] if self.node_rank == 0 else None
+            dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group)
+            if self.node_rank == 0:
+                for res in output_list:
+                    res: GeneralModelToHttpRpcRsp
+                    if not res.success:
+                        ret = res
+                        break
+
+        if self.node_rank == 0:
+            self.send_to_detokenization.send_pyobj(ret, protocol=pickle.HIGHEST_PROTOCOL)
+
     def clean_up(self):
         return
 
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index db708f3cf5..60431c8ff1 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -4,7 +4,7 @@
 import time
 import threading
 import torch.distributed
as dist -from typing import List, Tuple, Callable, Optional +from typing import List, Tuple, Callable, Optional, Union from transformers.configuration_utils import PretrainedConfig from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger @@ -16,7 +16,7 @@ from lightllm.common.basemodel.basemodel import TpPartBaseModel from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify -from lightllm.utils.dist_utils import init_distributed_env +from lightllm.utils.dist_utils import init_distributed_env, init_custom_process_group from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs from lightllm.server.core.objs.io_objs import AbortedReqCmd, StopStrMatchedReqCmd @@ -31,6 +31,9 @@ enable_radix_tree_timer_merge, get_radix_tree_merge_update_delta, ) +from lightllm.utils.serializer import LocalSerializedTensor, MultiprocessingSerializer +from lightllm.utils.patch_torch import monkey_patch_torch_reductions +from lightllm.utils.tensor_bucket import FlattenedTensorBucket, FlattenedTensorMetadata from lightllm.distributed import dist_group_manager from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack @@ -39,6 +42,12 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet from .multi_level_kv_cache import MultiLevelKvCacheModule +from lightllm.server.io_struct import ( + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq +) class ModeBackend: @@ -112,6 +121,8 @@ def init_model(self, kvargs): ) dist_group_manager.create_groups(group_size=group_size) # set the default group + self._model_update_group = {} + self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) # 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在 @@ -293,6 +304,172 @@ def flush_radix_cache(self): self.radix_cache.flush_cache() return + def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + assert ( + torch.distributed.is_initialized() + ), "Default torch process group must be initialized" + + assert request.group_name != "", "Group name cannot be empty" + rank = request.rank_offset + self.rank_in_dp + self.logger.info( + f"init custom process group: master_address={request.master_address}, master_port={request.master_port}, " + f"rank_offset={request.rank_offset}, rank={rank}, world_size={request.world_size}, group_name={request.group_name}, " + f" backend={request.backend}" + ) + + try: + if request.group_name in self._model_update_group: + raise ValueError( + f"Process group with name {request.group_name} already exists." + ) + + self._model_update_group[request.group_name] = init_custom_process_group( + backend=request.backend, + init_method=f"tcp://{request.master_address}:{request.master_port}", + world_size=request.world_size, + rank=rank, + group_name=request.group_name, + ) + return True, "Succeeded to initialize custom process group." + + except Exception as e: + message = f"Failed to initialize custom process group: {e}." 
+ self.logger.error(message) + return False, message + + def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + try: + if request.group_name in self._model_update_group: + pg = self._model_update_group.pop(request.group_name) + torch.distributed.destroy_process_group(pg) + return True, "Succeeded to destroy custom process group." + else: + return False, "The group to be destroyed does not exist." + except Exception as e: + message = f"Failed to destroy custom process group: {e}." + self.logger.error(message) + return False, message + + def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + """ + Update specific parameter in the model weights online + through `_model_update_group` process group. + + Args: + name: the name of the parameter to be updated. + dtype: the data type of the parameter to be updated. + shape: the shape of the parameter to be updated. + """ + + assert request.group_name in self._model_update_group, ( + f"Group {request.group_name} not in {list(self._model_update_group.keys())}. " + "Please call `init_weights_update_group` first." + ) + + try: + weights = [] + handles = [] + for name, dtype, shape in zip(request.names, request.dtypes, request.shapes): + target_dtype = ( + dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) + ) + weight = torch.empty(shape, dtype=target_dtype, device='cuda') + handles.append( + torch.distributed.broadcast( + weight, + src=0, + group=self._model_update_group[request.group_name], + async_op=True, + ) + ) + weights.append((name, weight)) + for handle in handles: + handle.wait() + + self.model.load_weights(weights) + return True, "Succeeded to update parameter online from distributed." + + except Exception as e: + error_msg = ( + f"Failed to update parameter online: {e}. " + f"The full weights of the ModelRunner are partially updated. " + f"Please discard the whole weights." + ) + self.logger.error(error_msg) + return False, error_msg + + def _update_weights_from_flattened_bucket( + self, + flattened_tensor_bucket_dict, + ): + """Handle flattened bucket format for weight updates""" + flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"] + metadata = flattened_tensor_bucket_dict["metadata"] + + # Convert metadata dict to our format + converted_metadata = [] + for meta in metadata: + converted_meta = FlattenedTensorMetadata( + name=meta.name, + shape=meta.shape, + dtype=meta.dtype, + start_idx=meta.start_idx, + end_idx=meta.end_idx, + numel=meta.numel, + ) + converted_metadata.append(converted_meta) + + # Create bucket and reconstruct tensors + bucket = FlattenedTensorBucket( + flattened_tensor=flattened_tensor, metadata=converted_metadata + ) + reconstructed_tensors = bucket.reconstruct_tensors() + + # Load the reconstructed tensors using the standard method + self.model.load_weights(reconstructed_tensors) + + return True, "Succeeded to update parameter online from flattened bucket tensor." 
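+
+    # A minimal sender-side sketch of the bucket payload consumed above, assuming the
+    # FlattenedTensorBucket added later in this patch (tensor names illustrative).
+    # Note that SafeUnpickler's allowlist must cover the module the metadata objects
+    # are pickled from before the payload can cross the serializer boundary:
+    #
+    #     bucket = FlattenedTensorBucket(named_tensors=[("w", torch.randn(4, 4))])
+    #     payload = {
+    #         "flattened_tensor": bucket.get_flattened_tensor(),
+    #         "metadata": bucket.get_metadata(),
+    #     }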
+
+    def update_weights_from_tensor(
+        self,
+        request: UpdateWeightsFromTensorReq
+    ):
+        try:
+            monkey_patch_torch_reductions()
+            if request.load_format == "flattened_bucket":
+                # Handle flattened bucket format: the bucket dict itself is shipped
+                # through the serializer, one entry per tp rank.
+                flattened_tensor_bucket_dict = MultiprocessingSerializer.deserialize(
+                    request.serialized_named_tensors[self.rank_in_dp]
+                )
+                return self._update_weights_from_flattened_bucket(
+                    flattened_tensor_bucket_dict=flattened_tensor_bucket_dict
+                )
+
+            # We need to get device after patch otherwise the device would be wrong
+            self.device_module = torch.get_device_module("cuda")
+            inferred_device = self.device_module.current_device()
+
+            named_tensors = MultiprocessingSerializer.deserialize(
+                request.serialized_named_tensors[self.rank_in_dp]
+            )
+
+            def _unwrap_tensor(tensor, tp_rank, device):
+                if isinstance(tensor, LocalSerializedTensor):
+                    tensor = tensor.get(tp_rank)
+                return tensor.to(device)
+
+            named_tensors = {
+                name: _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=inferred_device)
+                for name, tensor in named_tensors
+            }
+
+            self.model.load_weights(named_tensors)
+
+            return True, "Succeeded to update parameter online from tensor."
+
+        except Exception as e:
+            message = f"Failed to update parameter online from tensor. Reason: {e}."
+            self.logger.error(message)
+
+            return False, message
+
     def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor):
         """
         这个函数会把next token id和logprobs保存到pinned memory中
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index b7797a7628..04e7495adf 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -33,6 +33,7 @@
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
 from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.io_struct import GeneralHttpToModelRpcReq, GeneralModelToHttpRpcRsp
 
 logger = init_logger(__name__)
 
@@ -190,6 +191,16 @@ def flush_radix_cache(self):
             logger.exception(f"flush radix cache failed: {str(e)}")
             return False
 
+    def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
+        try:
+            if self.backend is None or not hasattr(self.backend, req.func_name):
+                raise ValueError(f"Backend does not support function {req.func_name}")
+            success, ret = getattr(self.backend, req.func_name)(req.func_args)
+            return GeneralModelToHttpRpcRsp(success=success, msg=str(ret), func_name=req.func_name, func_rsp=ret)
+        except BaseException as e:
+            logger.exception(f"forward to model backend failed: {str(e)}")
+            return GeneralModelToHttpRpcRsp(success=False, msg=f"forward to model backend failed: {str(e)}", func_name=req.func_name)
+
 
 class ModelRpcClient:
     def __init__(self, rpc_event, rpc_finished_event):
@@ -230,6 +241,15 @@ def flush_radix_cache(self) -> bool:
         assert func_name == "flush_radix_cache"
         return ret
 
+    def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
+        self.rpc_shm_params.write_func_params("forward_to_model", (req,))
+        self.rpc_event.set()
+
+        self.rpc_finished_event.wait()
+        self.rpc_finished_event.clear()
+        func_name, ret = self.rpc_shm_results.read_func_result()
+        assert func_name == "forward_to_model"
+        return ret
 
 def _init_env(
     args,
diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py
index 65ac401d4c..28667c6d00 100644
--- a/lightllm/utils/dist_utils.py
+++ b/lightllm/utils/dist_utils.py
@@ -65,12 +65,15 @@ def init_vision_distributed_env(kvargs):
     device_id = visual_gpu_ids[kvargs["vit_rank_id"]]
     set_current_device_id(device_id)
     torch.cuda.set_device(device_id)
+    # Do not pass device_id explicitly to init_process_group: it triggers torch's
+    # device-bound split optimization, which assumes ranks that later join a new group
+    # already exist in the default group. When RL weight updates create such a group,
+    # collectives like all_reduce would hang waiting on LightLLM's default group and give wrong results.
     dist.init_process_group(
         "nccl",
         init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}',
         rank=kvargs["tp_rank_id"],
         world_size=tp_world_size,
-        device_id=torch.device(f"cuda:{device_id}"),
     )
     # warmup nccl communicator
     _a = torch.zeros([1]).to(f"cuda:{device_id}")
@@ -104,7 +107,6 @@ def init_distributed_env(kvargs):
         init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}',
         rank=kvargs["rank_id"],
         world_size=kvargs["world_size"],
-        device_id=torch.device(f"cuda:{device_id}"),
     )
     # warmup nccl communicator
     _a = torch.zeros([1]).to(f"cuda:{device_id}")
@@ -270,3 +272,71 @@ def _init_nccl_env():
         assert response.status_code == 200, f"Failed to init config server nccl tcp store: {response.status_code}"
 
     return
+
+
+# copy from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py#L1675
+def init_custom_process_group(
+    backend=None,
+    init_method=None,
+    timeout=None,
+    world_size=-1,
+    rank=-1,
+    store=None,
+    group_name=None,
+    pg_options=None,
+    device_id=None,
+):
+    from torch.distributed.distributed_c10d import (
+        Backend,
+        PrefixStore,
+        _new_process_group_helper,
+        _world,
+        default_pg_timeout,
+        rendezvous,
+    )
+
+    assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
+
+    if store is not None:
+        assert world_size > 0, "world_size must be positive if using store"
+        assert rank >= 0, "rank must be non-negative if using store"
+    elif init_method is None:
+        init_method = "env://"
+
+    if backend:
+        backend = Backend(backend)
+    else:
+        backend = Backend("undefined")
+
+    if timeout is None:
+        timeout = default_pg_timeout
+
+    # backward compatible API
+    if store is None:
+        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
+        store, rank, world_size = next(rendezvous_iterator)
+        store.set_timeout(timeout)
+
+        # Use a PrefixStore to avoid accidental overrides of keys used by
+        # different systems (e.g. RPC) in case the store is multi-tenant.
+        store = PrefixStore(group_name, store)
+
+    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
+    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
+    # We need to determine the appropriate parameter name based on PyTorch version
+    pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
+    pg, _ = _new_process_group_helper(
+        world_size,
+        rank,
+        [],
+        backend,
+        store,
+        group_name=group_name,
+        **{pg_options_param_name: pg_options},
+        timeout=timeout,
+        device_id=device_id,
+    )
+
+    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
+
+    return pg
diff --git a/lightllm/utils/patch_torch.py b/lightllm/utils/patch_torch.py
new file mode 100644
index 0000000000..c504e4bbc9
--- /dev/null
+++ b/lightllm/utils/patch_torch.py
@@ -0,0 +1,65 @@
+# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/patch_torch.py
+from typing import Callable, Union
+
+import torch
+from packaging import version
+from torch.multiprocessing import reductions
+
+
+def monkey_patch_torch_reductions():
+    """Monkey patch torch.multiprocessing reductions until https://github.com/pytorch/pytorch/pull/149248 is fixed."""
+
+    # Currently, NPU does not support UUID.
This has been temporarily commented out, with support expected in the fourth quarter. + # if _is_npu: + # return + + if hasattr(reductions, "_reduce_tensor_original"): + return + + reductions._reduce_tensor_original = reductions.reduce_tensor + reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor + + reductions.reduce_tensor = _reduce_tensor_modified + reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified + + reductions.init_reductions() + + +# The signature has not been changed for years, and we will not need this when the next version is released, +# so it looks safe to use a constant. +_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 + + +def _reduce_tensor_modified(*args, **kwargs): + output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) + output_args = _modify_tuple( + output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid + ) + return output_fn, output_args + + +def _rebuild_cuda_tensor_modified(*args): + args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid) + return reductions._rebuild_cuda_tensor_original(*args) + + +def _device_to_uuid(device: int) -> str: + return str(torch.cuda.get_device_properties(device).uuid) + + +def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int: + if isinstance(device_maybe_uuid, int): + return device_maybe_uuid + + if isinstance(device_maybe_uuid, str): + for device in range(torch.cuda.device_count()): + if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid: + return device + raise Exception("Invalid device_uuid=" + device_maybe_uuid) + + raise Exception(f"Unknown type: {device_maybe_uuid=}") + + +def _modify_tuple(t, index: int, modifier: Callable): + return *t[:index], modifier(t[index]), *t[index + 1 :] + diff --git a/lightllm/utils/serializer.py b/lightllm/utils/serializer.py new file mode 100644 index 0000000000..e0b5233032 --- /dev/null +++ b/lightllm/utils/serializer.py @@ -0,0 +1,132 @@ + +# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py + +import base64 +import pickle +import io +from dataclasses import dataclass +from multiprocessing.reduction import ForkingPickler +from typing import List + + +class MultiprocessingSerializer: + @staticmethod + def serialize(obj, output_str: bool = False): + """ + Serialize a Python object using ForkingPickler. + + Args: + obj: The object to serialize. + output_str (bool): If True, return a base64-encoded string instead of raw bytes. + + Returns: + bytes or str: The serialized object. + """ + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + output = buf.read() + + if output_str: + # Convert bytes to base64-encoded string + output = base64.b64encode(output).decode("utf-8") + + return output + + @staticmethod + def deserialize(data): + """ + Deserialize a previously serialized object. + + Args: + data (bytes or str): The serialized data, optionally base64-encoded. + + Returns: + The deserialized Python object. 
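+
+        Example (a round-trip sketch; assumes the payload only uses allowlisted types):
+
+            data = MultiprocessingSerializer.serialize({"step": 1}, output_str=True)
+            assert MultiprocessingSerializer.deserialize(data) == {"step": 1}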
+ """ + if isinstance(data, str): + # Decode base64 string to bytes + data = base64.b64decode(data, validate=True) + + return SafeUnpickler(io.BytesIO(data)).load() + + +class SafeUnpickler(pickle.Unpickler): + ALLOWED_MODULE_PREFIXES = { + # --- Python types --- + "builtins.", + "collections.", + "copyreg.", + "functools.", + "itertools.", + "operator.", + "types.", + "weakref.", + # --- PyTorch types --- + "torch.", + "torch._tensor.", + "torch.storage.", + "torch.nn.parameter.", + "torch.autograd.function.", + # --- torch distributed --- + "torch.distributed.", + "torch.distributed._shard.", + "torch.distributed._composable.", + "torch._C._distributed_c10d.", + "torch._C._distributed_fsdp.", + "torch.distributed.optim.", + # --- multiprocessing --- + "multiprocessing.resource_sharer.", + "multiprocessing.reduction.", + "pickletools.", + # --- PEFT / LoRA --- + "peft.", + "transformers.", + "huggingface_hub.", + # --- SGLang & Unitest --- + "sglang.srt.weight_sync.tensor_bucket.", + "sglang.srt.model_executor.model_runner.", + "sglang.srt.layers.", + "sglang.srt.utils.", + } + + DENY_CLASSES = { + ("builtins", "eval"), + ("builtins", "exec"), + ("builtins", "compile"), + ("os", "system"), + ("subprocess", "Popen"), + ("subprocess", "run"), + ("codecs", "decode"), + ("types", "CodeType"), + ("types", "FunctionType"), + } + + def find_class(self, module, name): + # Block deterministic attacks + if (module, name) in self.DENY_CLASSES: + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " + f"to prevent exploitation of CVE-2025-10164" + ) + # Allowlist of safe-to-load modules. + if any( + (module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES + ): + return super().find_class(module, name) + + # Block everything else. (Potential attack surface) + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " + f"to prevent exploitation of CVE-2025-10164" + ) + +@dataclass +class LocalSerializedTensor: + """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data). + The i-th element in the list corresponds to i-th rank's GPU.""" + + values: List[bytes] + + def get(self, rank: int): + return MultiprocessingSerializer.deserialize(self.values[rank]) \ No newline at end of file diff --git a/lightllm/utils/tensor_bucket.py b/lightllm/utils/tensor_bucket.py new file mode 100644 index 0000000000..762bd0dd06 --- /dev/null +++ b/lightllm/utils/tensor_bucket.py @@ -0,0 +1,108 @@ +# copy from https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/srt/weight_sync/tensor_bucket.py +from dataclasses import dataclass +from typing import List, Tuple + +import torch + + +@dataclass +class FlattenedTensorMetadata: + """Metadata for a tensor in a flattened bucket""" + + name: str + shape: torch.Size + dtype: torch.dtype + start_idx: int + end_idx: int + numel: int + + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single tensor for efficient processing + while preserving all metadata needed for reconstruction. + """ + + # This field is solely for users of to check whether the class supports this feature + supports_multi_dtypes = True + + def __init__( + self, + named_tensors: List[Tuple[str, torch.Tensor]] = None, + flattened_tensor: torch.Tensor = None, + metadata: List[FlattenedTensorMetadata] = None, + ): + """ + Initialize a tensor bucket from a list of named tensors OR from pre-flattened data. 
+ Args: + named_tensors: List of (name, tensor) tuples (for creating new bucket) + flattened_tensor: Pre-flattened tensor (for reconstruction) + metadata: Pre-computed metadata (for reconstruction) + """ + if named_tensors is not None: + # Create bucket from named tensors + self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors) + self.flattened_tensor: torch.Tensor = None + + if not named_tensors: + raise ValueError("Cannot create empty tensor bucket") + + # Collect metadata and flatten tensors + current_idx = 0 + flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors) + + for i, (name, tensor) in enumerate(named_tensors): + flattened = tensor.flatten().view(torch.uint8) + flattened_tensors[i] = flattened + + # Store metadata + + numel = flattened.numel() + metadata_obj = FlattenedTensorMetadata( + name=name, + shape=tensor.shape, + dtype=tensor.dtype, + start_idx=current_idx, + end_idx=current_idx + numel, + numel=numel, + ) + self.metadata[i] = metadata_obj + current_idx += numel + + # Concatenate all flattened tensors + self.flattened_tensor = torch.cat(flattened_tensors, dim=0) + else: + # Initialize from pre-flattened data + if flattened_tensor is None or metadata is None: + raise ValueError( + "Must provide either named_tensors or both flattened_tensor and metadata" + ) + self.flattened_tensor = flattened_tensor + self.metadata = metadata + + def get_flattened_tensor(self) -> torch.Tensor: + """Get the flattened tensor containing all bucket tensors""" + return self.flattened_tensor + + def get_metadata(self) -> List[FlattenedTensorMetadata]: + """Get metadata for all tensors in the bucket""" + return self.metadata + + def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: + """ + Reconstruct original tensors from flattened tensor with optimized performance. + Uses memory-efficient operations to minimize allocations and copies. 
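+
+        Example (illustrative round trip through the bucket):
+
+            bucket = FlattenedTensorBucket(named_tensors=[("b", torch.ones(2, 3))])
+            [(name, restored)] = bucket.reconstruct_tensors()
+            assert name == "b" and torch.equal(restored, torch.ones(2, 3))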
+ """ + # preallocate the result list + reconstructed = [None] * len(self.metadata) + + for i, meta in enumerate(self.metadata): + tensor = ( + self.flattened_tensor[meta.start_idx : meta.end_idx] + .view(meta.dtype) + .reshape(meta.shape) + ) + + reconstructed[i] = (meta.name, tensor) + + return reconstructed \ No newline at end of file From c210c82fa08f9933ab286e5b67d638da35c9a7aa Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Mon, 1 Dec 2025 20:06:06 +0800 Subject: [PATCH 008/180] release and resume (#1122) --- lightllm/common/basemodel/basemodel.py | 96 ++++++++++++++++--- lightllm/common/basemodel/cuda_graph.py | 9 +- .../basemodel/layer_weights/hf_load_utils.py | 70 +++++++++++++- lightllm/server/api_cli.py | 6 ++ lightllm/server/api_http.py | 40 +++++--- lightllm/server/core/objs/start_args_type.py | 22 ++--- lightllm/server/detokenization/manager.py | 12 +-- lightllm/server/httpserver/manager.py | 78 ++++++++------- lightllm/server/io_struct.py | 33 ++++++- lightllm/server/router/manager.py | 36 +++---- .../model_infer/mode_backend/base_backend.py | 75 ++++++++------- .../server/router/model_infer/model_rpc.py | 37 +++---- lightllm/utils/torch_memory_saver_utils.py | 92 ++++++++++++++++++ requirements.txt | 3 +- 14 files changed, 451 insertions(+), 158 deletions(-) create mode 100644 lightllm/utils/torch_memory_saver_utils.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 1221f1939b..cc50d0a085 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -6,7 +6,7 @@ import json import torch import torch.nn.functional as F -from typing import final +from typing import final, List, Optional from tqdm import tqdm from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights @@ -30,6 +30,10 @@ from lightllm.utils.envs_utils import set_model_init_status from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) logger = init_logger(__name__) @@ -88,6 +92,7 @@ def __init__(self, kvargs): self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"] + self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) self._init_datatype() self._init_config() @@ -97,20 +102,29 @@ def __init__(self, kvargs): # 更连续的显存分配可以有更好的性能 if self.max_total_token_num is None: - self._init_weights() - self._init_mem_manager() + with self.torch_memory_saver.region( + tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup + ): + self._init_weights() + with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): + self._init_mem_manager() else: - self._init_mem_manager() - self._init_weights() + with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): + self._init_mem_manager() + with self.torch_memory_saver.region( + tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup + ): + self._init_weights() self._init_kv_move_buffer() self._check_mem_size() - self._init_req_manager() + with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): + self._init_req_manager() self._init_infer_layer() self._init_some_value() self._init_custom() self._init_inferstate_cls() - self._autotune_warmup() + # self._autotune_warmup() self._init_padded_req() 
# wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() @@ -179,11 +193,13 @@ def _init_weights(self): return def load_weights(self, weight_dict: dict): - load_hf_weights(self.data_type, - self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=weight_dict) + load_hf_weights( + self.data_type, + self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=weight_dict, + ) def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 @@ -766,6 +782,7 @@ def _check_max_len_infer(self): ) logger.error(exception_str) raise Exception(exception_str) + torch.cuda.empty_cache() return def autotune_layers(self): @@ -896,6 +913,9 @@ def _init_padded_req(self): del b_seq_len del b_ready_cache_len del model_output + del b_mtp_index + del b_prefill_start_loc + del b_q_seq_len torch.cuda.empty_cache() return @@ -911,3 +931,55 @@ def _gen_special_model_input(self, token_num: int): special_model_input["deepseekv3_mtp_draft_input_hiddens"] = None return special_model_input + + def release_memory_occupation(self, tags: Optional[List[MemoryTag]]): + if tags is None: + self.release_all() + return + if MemoryTag.WEIGHT in tags: + self.release_weight() + if MemoryTag.KV_CACHE in tags: + self.release_kv_cache() + if MemoryTag.GRAPH in tags: + self.release_graph() + return + + def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]): + if tags is None: + self.resume_all() + return + if MemoryTag.WEIGHT in tags: + self.resume_weight() + if MemoryTag.KV_CACHE in tags: + self.resume_kv_cache() + if MemoryTag.GRAPH in tags: + self.resume_graph() + return + + def release_weight(self): + self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + + def release_kv_cache(self): + self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + + def release_graph(self): + self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + + def release_all(self): + self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + + def resume_weight(self): + self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) + + def resume_kv_cache(self): + self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + + def resume_graph(self): + self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) + + def resume_all(self): + self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) + self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index c754fabce0..220ae10cff 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -7,6 +7,10 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) from .infer_struct import InferStateInfo @@ -24,6 +28,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192): self.max_batch_size = max_batch_size self.graph_max_len_in_batch = max_len_in_batch self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap + self.torch_memory_saver = 
TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) # gen cuda graph batch_sizes # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] @@ -82,7 +87,7 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf torch.cuda.synchronize() with lightllm_capture_graph(dist_group): - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output = decode_func(input_ids, infer_state) self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output) graph_obj.replay() @@ -111,7 +116,7 @@ def _capture_decode_overlap( torch.cuda.synchronize() with lightllm_capture_graph(dist_group1): with lightllm_capture_graph(dist_group): - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1) self.graph[batch_size] = ( graph_obj, diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 8cf66a5ad6..2a9006efd6 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -5,6 +5,8 @@ from tqdm import tqdm import lightllm.utils.petrel_helper as utils from lightllm.utils.dist_utils import get_current_device_id +from queue import Queue +from threading import Thread def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None): @@ -28,7 +30,7 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay gc.collect() -def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): +def load_hf_weights_old(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): if isinstance(data_type, str): data_type = torch.float16 if data_type == "fp16" else torch.float32 if pre_post_layer is not None: @@ -70,3 +72,69 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye pass return + + +def _read_file(file_, use_safetensors, weight_dir): + if use_safetensors: + weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + weights = {k: weights.get_tensor(k) for k in weights.keys()} + else: + weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") + + return weights + + +def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): + if isinstance(data_type, str): + data_type = torch.float16 if data_type == "fp16" else torch.float32 + if pre_post_layer is not None: + assert pre_post_layer.data_type_ == data_type, "type is not right" + if transformer_layer_list is not None: + assert transformer_layer_list[0].data_type_ == data_type, "type is not right" + if weight_dict: + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weight_dict) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weight_dict) + del weight_dict + return + use_safetensors = True + files = utils.PetrelHelper.list(weight_dir, extension="all") + candidate_files = list(filter(lambda x: x.endswith(".safetensors"), files)) + if len(candidate_files) == 0: + use_safetensors = False + candidate_files = list(filter(lambda x: x.endswith(".bin"), files)) + assert len(candidate_files) != 0, "can only 
support pytorch tensor and safetensors format for weights." + + weight_queue = Queue(maxsize=5) # 控制内存使用 + + def producer(chunk): + for file_ in chunk: + weights = _read_file(file_, use_safetensors, weight_dir) + weight_queue.put(weights) + + LOADWORKER = int(os.environ.get("LOADWORKER", 1)) + + num_producers = min(LOADWORKER, len(candidate_files)) # 生产者数量 + chunk_size = (len(candidate_files) + num_producers - 1) // num_producers + file_chunks = [candidate_files[i : i + chunk_size] for i in range(0, len(candidate_files), chunk_size)] + + producer_threads = [] + for i, chunk in enumerate(file_chunks): + thread = Thread(target=producer, args=(chunk,), name=f"Producer-{i}") + thread.start() + producer_threads.append(thread) + + for _ in tqdm(range(len(candidate_files)), desc="Loading weights"): + weights = weight_queue.get() + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weights) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weights) + del weights + gc.collect() + + for thread in producer_threads: + thread.join() diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index ec11f8f1da..ee3f184e41 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -537,4 +537,10 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used.""" ) + parser.add_argument( + "--enable_torch_memory_saver", + action="store_true", + help="""enable torch memory saver, which is used for release_memory and resume_memory during RL training.""", + ) + parser.add_argument("--enable_weight_cpu_backup", action="store_true", help="""enable weight cpu backup.""") return parser diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 28ae93dfb2..ff9acafc94 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -49,6 +49,7 @@ from lightllm.utils.error_utils import ServerBusyError from lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import ReleaseMemoryReq, ResumeMemoryReq from dataclasses import dataclass from .api_openai import chat_completions_impl, completions_impl @@ -60,11 +61,12 @@ ) from .io_struct import ( AbortReq, + FlushCacheReq, InitWeightsUpdateGroupReq, DestroyWeightsUpdateGroupReq, UpdateWeightsFromDistributedReq, UpdateWeightsFromTensorReq, - GeneralModelToHttpRpcRsp + GeneralModelToHttpRpcRsp, ) from .build_prompt import build_prompt, init_tokenizer @@ -143,14 +145,17 @@ def get_model_name(): def get_server_info(): # 将 StartArgs 转换为字典格式 from dataclasses import asdict + server_info: dict[str, Any] = asdict(g_objs.args) return {**server_info} + @app.get("/get_weight_version") @app.post("/get_weight_version") def get_weight_version(): return {"weight_version": g_objs.args.weight_version} + @app.get("/healthz", summary="Check server health") @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") @@ -330,43 +335,38 @@ async def handle_request_common(request_obj, handler): else: return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg) except Exception as e: - return create_error_response( - HTTPStatus.EXPECTATION_FAILED, - f"error: {str(e)}" - ) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") + @app.post("/init_weights_update_group") async def 
init_weights_update_group(request: InitWeightsUpdateGroupReq, raw_request: Request): """Init weights update group.""" return await handle_request_common(request, g_objs.httpserver_manager.init_weights_update_group) + @app.post("/destroy_weights_update_group") async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupReq, raw_request: Request): """Destroy weights update group.""" return await handle_request_common(request, g_objs.httpserver_manager.destroy_weights_update_group) + @app.post("/update_weights_from_distributed") async def update_weights_from_distributed(request: UpdateWeightsFromDistributedReq, raw_request: Request): """Update model parameter from distributed online.""" return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_distributed) + @app.post("/update_weights_from_tensor") -async def update_weights_from_distributed(request: UpdateWeightsFromTensorReq, raw_request: Request): +async def update_weights_from_tensor(request: UpdateWeightsFromTensorReq, raw_request: Request): """Update model parameter from distributed online.""" return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_tensor) + @app.post("/flush_cache") @app.get("/flush_cache") async def flush_cache(): """Flush the radix cache.""" - ret = await g_objs.httpserver_manager.flush_cache() - return Response( - content="Cache flushed successfully." - if ret - else "Cache flush failed. " - + "When there are running or waiting requests, the operation will not be performed.", - status_code=200 if ret else 500, - ) + return await handle_request_common(FlushCacheReq(), g_objs.httpserver_manager.flush_cache) @app.post("/pause_generation") @@ -381,6 +381,18 @@ async def continue_generation(): return Response(content="Generation continued successfully.", status_code=200) +@app.get("/release_memory_occupation") +@app.post("/release_memory_occupation") +async def release_memory_occupation(request: ReleaseMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.release_memory_occupation) + + +@app.get("/resume_memory_occupation") +@app.post("/resume_memory_occupation") +async def resume_memory_occupation(request: ResumeMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.resume_memory_occupation) + + @app.websocket("/pd_register") async def register_and_keep_alive(websocket: WebSocket): await websocket.accept() diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 40f68a7439..eff4dfab55 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -8,7 +8,9 @@ class StartArgs: run_mode: str = field( default="normal", - metadata={"choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"]}, + metadata={ + "choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"] + }, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) @@ -23,8 +25,7 @@ class StartArgs: config_server_port: int = field(default=None) pd_decode_rpyc_port: int = field(default=42000) select_p_d_node_strategy: str = field( - default="round_robin", - metadata={"choices": ["random", "round_robin", "adaptive_load"]} + default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]} ) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) @@ -59,10 +60,7 
@@ class StartArgs: disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) - output_constraint_mode: str = field( - default="none", - metadata={"choices": ["outlines", "xgrammar", "none"]} - ) + output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]}) first_token_constraint_mode: bool = field(default=False) enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) @@ -107,8 +105,7 @@ class StartArgs: ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) mtp_mode: Optional[str] = field( - default=None, - metadata={"choices": ["deepseekv3_vanilla", "deepseekv3_eagle", None]} + default=None, metadata={"choices": ["deepseekv3_vanilla", "deepseekv3_eagle", None]} ) mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) @@ -141,10 +138,7 @@ class StartArgs: httpserver_workers: int = field(default=1) disable_shm_warning: bool = field(default=False) - dp_balancer: str = field( - default="bs_balancer", - metadata={"choices": ["round_robin", "bs_balancer"]} - ) + dp_balancer: str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]}) enable_custom_allgather: bool = field(default=False) enable_fused_shared_experts: bool = field(default=False) enable_mps: bool = field(default=False) @@ -152,5 +146,7 @@ class StartArgs: schedule_time_interval: float = field(default=0.03) use_dynamic_prompt_cache: bool = field(default=False) disable_custom_allreduce: bool = field(default=False) + enable_torch_memory_saver: bool = field(default=False) + enable_weight_cpu_backup: bool = field(default=False) weight_version: str = "default" diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index b7ba960258..17a47dfde6 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -20,7 +20,9 @@ BaseReq, GenerateResp, FlushCacheResp, - GeneralModelToHttpRpcRsp + ReleaseMemoryResp, + ResumeMemoryResp, + GeneralModelToHttpRpcRsp, ) logger = init_logger(__name__) @@ -80,14 +82,8 @@ def handle_loop(self): # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - if isinstance(recv_obj, FlushCacheResp): - print("Detokenization receive flush cache request", flush=True) + if isinstance(recv_obj, GeneralModelToHttpRpcRsp): self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) - print("Detokenization send flush cache request to httpserver", flush=True) - continue - elif isinstance(recv_obj, GeneralModelToHttpRpcRsp): - self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) - print(f"Detokenization send {recv_obj.func_name} request to httpserver") continue self._add_new_group_req_index(recv_obj=recv_obj) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 44721cfe3b..083e939bac 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -14,7 +14,7 @@ from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -from typing import Union, List, Tuple, Dict, Optional +from typing import Union, List, Tuple, Dict, Optional, Literal from websockets import ClientConnection from 
fastapi import Request from ..tokenizer import get_tokenizer @@ -40,6 +40,10 @@ GenerateResp, GenerateReqMeta, GenerateReqIndex, + ReleaseMemoryReq, + ReleaseMemoryResp, + ResumeMemoryReq, + ResumeMemoryResp, InitWeightsUpdateGroupReq, InitWeightsUpdateGroupRsp, DestroyWeightsUpdateGroupReq, @@ -49,13 +53,13 @@ UpdateWeightsFromTensorReq, UpdateWeightsFromTensorRsp, GeneralHttpToModelRpcReq, - GeneralModelToHttpRpcRsp - + GeneralModelToHttpRpcRsp, ) from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken +from lightllm.utils.torch_memory_saver_utils import MemoryTag from rpyc.utils.classic import obtain logger = init_logger(__name__) @@ -140,6 +144,8 @@ def __init__( # 交互式请求 event self.flush_cache_event: Optional[asyncio.Event] = None + self.release_memory_event: Optional[asyncio.Event] = None + self.resume_memory_event: Optional[asyncio.Event] = None self.async_events_per_func: Dict[str, asyncio.Event] = {} return @@ -768,8 +774,6 @@ async def handle_loop(self): try: if recv_obj is None or isinstance(recv_obj, GenerateResp): await self._handle_recv_generate_request(recv_obj) - elif isinstance(recv_obj, FlushCacheResp): - await self._handle_recv_flush_cache_request(recv_obj) elif isinstance(recv_obj, GeneralModelToHttpRpcRsp): await self._handle_recv_general_model_to_http_request(recv_obj) @@ -835,12 +839,6 @@ async def _handle_recv_generate_request(self, recv_obj: GenerateReqMeta): req_status.out_token_info_list.extend(token_list) req_status.event.set() - async def _handle_recv_flush_cache_request(self, recv_obj: FlushCacheResp): - assert self.flush_cache_event is not None - self.flush_cache_event.success = recv_obj.success - self.flush_cache_event.set() - return - async def _handle_recv_general_model_to_http_request(self, recv_obj: GeneralModelToHttpRpcRsp): assert recv_obj.func_name is not None event = await self.get_event_for_func(recv_obj.func_name) @@ -848,22 +846,11 @@ async def _handle_recv_general_model_to_http_request(self, recv_obj: GeneralMode event.set() return - async def flush_cache(self): - if self.flush_cache_event is None: - self.flush_cache_event = asyncio.Event() - await self.transfer_to_next_module(FlushCacheReq()) - try: - await asyncio.wait_for(self.flush_cache_event.wait(), timeout=30) - ret = self.flush_cache_event.success - except asyncio.TimeoutError: - # 超时直接返回失败 - ret = False - self.flush_cache_event.clear() - return ret - async def pause_generation(self): # 因为请求是从master node转发到slave node的 # 所以只要master暂停了,slave自然暂停。 + if self.is_pause: + return async with self.is_pause_cond: self.is_pause = True while True: @@ -883,7 +870,9 @@ async def get_event_for_func(self, func_name: str) -> asyncio.Event: self.async_events_per_func[func_name] = asyncio.Event() return self.async_events_per_func[func_name] - async def http_to_model_special_request(self, request: GeneralHttpToModelRpcReq, timeout: int=300) -> GeneralModelToHttpRpcRsp: + async def http_to_model_special_request( + self, request: GeneralHttpToModelRpcReq, timeout: int = 300 + ) -> GeneralModelToHttpRpcRsp: event = await self.get_event_for_func(request.func_name) await self.transfer_to_next_module(request) try: @@ -893,19 +882,41 @@ async def http_to_model_special_request(self, request: GeneralHttpToModelRpcReq, except asyncio.TimeoutError: ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response timeout", 
func_name=request.func_name) except Exception as e: - ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response error: %s" % str(e), func_name=request.func_name) + ret = GeneralModelToHttpRpcRsp( + success=False, msg="wait for response error: %s" % str(e), func_name=request.func_name + ) return ret + async def flush_cache(self, request: FlushCacheReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="flush_cache", func_args=request) + ) - async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): - return await self.http_to_model_special_request(GeneralHttpToModelRpcReq( - func_name="init_weights_update_group", func_args=request)) + async def release_memory_occupation(self, request: ReleaseMemoryReq): + assert len(self.req_id_to_out_inf) == 0, "there are still requests running, cannot release memory occupation" + # 暂停接受请求,除非resume + await self.pause_generation() + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="release_memory_occupation", func_args=request.tags) + ) + async def resume_memory_occupation(self, request: ResumeMemoryReq): + ret = await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="resume_memory_occupation", func_args=request.tags) + ) + if ret.success: + await self.continue_generation() + return ret - async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): - return await self.http_to_model_special_request(GeneralHttpToModelRpcReq( - func_name="destroy_weights_update_group", func_args=request)) + async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="init_weights_update_group", func_args=request) + ) + async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="destroy_weights_update_group", func_args=request) + ) async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): @@ -916,7 +927,8 @@ async def update_weights_from_distributed(self, request: UpdateWeightsFromDistri await self.flush_cache() return await self.http_to_model_special_request( - GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request)) + GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request) + ) async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> Tuple[bool, str]: if request.abort_all_requests: diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py index 7947c7a581..e04e8871ce 100644 --- a/lightllm/server/io_struct.py +++ b/lightllm/server/io_struct.py @@ -4,6 +4,7 @@ from lightllm.server.core.objs.sampling_params import SamplingParams from lightllm.server.multimodal_params import MultimodalParams from typing import List, Optional, Any, Union +from lightllm.utils.torch_memory_saver_utils import MemoryTag @dataclass @@ -14,11 +15,13 @@ def get_req_to_next_node(self): def get_req_to_next_module(self): return self + @dataclass class BaseRsp(ABC): success: bool msg: Optional[str] + # for next node @dataclass class GenerateReqMeta(BaseReq): @@ -85,16 +88,38 @@ class AbortReq(BaseReq): abort_all: bool = False +@dataclass +class ReleaseMemoryReq(BaseReq): + tags: Optional[List[MemoryTag]] = None + + +@dataclass +class ReleaseMemoryResp(BaseReq): + success: bool + + +@dataclass +class 
ResumeMemoryReq(BaseReq): + tags: Optional[List[MemoryTag]] = None + + +@dataclass +class ResumeMemoryResp(BaseReq): + success: bool + + @dataclass class GeneralHttpToModelRpcReq(BaseReq): func_name: str func_args: Optional[Any] = None + @dataclass class GeneralModelToHttpRpcRsp(BaseRsp): func_name: str func_rsp: Optional[Any] = None + @dataclass class InitWeightsUpdateGroupReq(BaseReq): # The master address @@ -110,18 +135,22 @@ class InitWeightsUpdateGroupReq(BaseReq): # The backend backend: str = "nccl" + @dataclass class InitWeightsUpdateGroupRsp(BaseRsp): pass + @dataclass class DestroyWeightsUpdateGroupReq(BaseReq): group_name: str = "weight_update_group" + @dataclass class DestroyWeightsUpdateGroupRsp(BaseRsp): pass + @dataclass class UpdateWeightsFromDistributedReq(BaseReq): names: List[str] @@ -136,6 +165,7 @@ class UpdateWeightsFromDistributedReq(BaseReq): # Optional: Update weight version along with weights weight_version: Optional[str] = None + @dataclass class UpdateWeightsFromDistributedRsp(BaseRsp): pass @@ -159,6 +189,7 @@ class UpdateWeightsFromTensorReq(BaseReq): # Optional: Update weight version along with weights weight_version: Optional[str] = None + @dataclass class UpdateWeightsFromTensorRsp(BaseRsp): - pass \ No newline at end of file + pass diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index ea1f17b90f..20952e2283 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -35,12 +35,17 @@ GenerateReqIndex, FlushCacheReq, FlushCacheResp, + ReleaseMemoryReq, + ReleaseMemoryResp, + ResumeMemoryReq, + ResumeMemoryResp, GeneralHttpToModelRpcReq, - GeneralModelToHttpRpcRsp + GeneralModelToHttpRpcRsp, ) from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.torch_memory_saver_utils import MemoryTag logger = init_logger(__name__) @@ -550,8 +555,10 @@ async def _recv_new_reqs_and_schedule(self): recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GenerateReqIndex): self._add_req(recv_req) - elif isinstance(recv_req, (FlushCacheReq, GeneralHttpToModelRpcReq)): + elif isinstance(recv_req, GeneralHttpToModelRpcReq): special_reqs.append(recv_req) + else: + raise ValueError(f"Unknown request type: {type(recv_req)}") # 当队列中存在较多的请求时,将一次接受的数量上调 self.recv_max_count = min(int(self.recv_max_count * 1.3), 256) @@ -573,10 +580,8 @@ def _process_special_reqs(self, special_reqs: List[BaseReq]): if self.is_multinode_tp: special_reqs = self.broadcast_reqs_to_other_nodes(special_reqs) for req in special_reqs: - if isinstance(req, FlushCacheReq): - self.flush_cache() - elif isinstance(req, (GeneralHttpToModelRpcReq)): - self.forward_to_model(req) + assert isinstance(req, GeneralHttpToModelRpcReq), "special request must be GeneralHttpToModelRpcReq" + self.forward_to_model(req) def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): req_num = len(reqs) @@ -595,30 +600,13 @@ def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) return reqs - def flush_cache(self) -> None: - # if radix cache client is not initialized, just return True - if self.radix_cache_client is None: - success = True - # only flush cache when no running batch and no waiting requests - elif self.running_batch is not None or self.req_queue.get_wait_req_num() > 0: - success = False - else: - 
success = self.model_rpc_client.flush_radix_cache() - - if self.is_multinode_tp: - # 等待其他节点的flush 结果 - dist.barrier(group=self.mulitnode_group) - if self.node_rank == 0: - self.send_to_detokenization.send_pyobj(FlushCacheResp(success=success), protocol=pickle.HIGHEST_PROTOCOL) - return - def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> None: ret = self.model_rpc_client.forward_to_model(req) if self.is_multinode_tp: output_list = [None for _ in self.nnodes] if self.node_rank == 0 else None dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group) for res in output_list: - res : GeneralModelToHttpRpcRsp + res: GeneralModelToHttpRpcRsp if not res.success: ret = res break diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 60431c8ff1..4be944584d 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -41,12 +41,14 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.utils.torch_memory_saver_utils import MemoryTag from .multi_level_kv_cache import MultiLevelKvCacheModule from lightllm.server.io_struct import ( + FlushCacheReq, InitWeightsUpdateGroupReq, DestroyWeightsUpdateGroupReq, UpdateWeightsFromDistributedReq, - UpdateWeightsFromTensorReq + UpdateWeightsFromTensorReq, ) @@ -299,36 +301,54 @@ def init_mtp_draft_model(self, main_kvargs: dict): self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return - def flush_radix_cache(self): + def flush_cache(self, request: FlushCacheReq): if self.radix_cache is not None: self.radix_cache.flush_cache() - return + return True, "Succeeded to flush cache." + + def release_memory_occupation(self, tags: List[MemoryTag]): + try: + self.model.release_memory_occupation(tags) + self.flush_cache(request=None) + self.model.req_manager.free_all() + self.model.mem_manager.free_all() + return True, "Succeeded to release memory occupation." + except Exception as e: + self.logger.error(f"release memory occupation failed: {str(e)}") + return False, f"release memory occupation failed: {str(e)}" + + def resume_memory_occupation(self, tags: List[MemoryTag]): + try: + self.model.resume_memory_occupation(tags) + return True, "Succeeded to resume memory occupation." 
+ except Exception as e: + self.logger.error(f"resume memory occupation failed: {str(e)}") + return False, f"resume memory occupation failed: {str(e)}" def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): - assert ( - torch.distributed.is_initialized() - ), "Default torch process group must be initialized" + assert torch.distributed.is_initialized(), "Default torch process group must be initialized" assert request.group_name != "", "Group name cannot be empty" - rank = request.rank_offset + self.rank_in_dp + rank_offset = request.rank_offset + rank = rank_offset + self.rank_in_dp + world_size = request.world_size + group_name = request.group_name self.logger.info( f"init custom process group: master_address={request.master_address}, master_port={request.master_port}, " - f"rank_offset={request.rank_offset}, rank={rank}, world_size={request.world_size}, group_name={request.group_name}, " + f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, " f" backend={request.backend}" ) try: - if request.group_name in self._model_update_group: - raise ValueError( - f"Process group with name {request.group_name} already exists." - ) + if group_name in self._model_update_group: + raise ValueError(f"Process group with name {group_name} already exists.") - self._model_update_group[request.group_name] = init_custom_process_group( + self._model_update_group[group_name] = init_custom_process_group( backend=request.backend, init_method=f"tcp://{request.master_address}:{request.master_port}", - world_size=request.world_size, + world_size=world_size, rank=rank, - group_name=request.group_name, + group_name=group_name, ) return True, "Succeeded to initialize custom process group." @@ -370,10 +390,8 @@ def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedR weights = [] handles = [] for name, dtype, shape in zip(request.names, request.dtypes, request.shapes): - target_dtype = ( - dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) - ) - weight = torch.empty(shape, dtype=target_dtype, device='cuda') + target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) + weight = torch.empty(shape, dtype=target_dtype, device="cuda") handles.append( torch.distributed.broadcast( weight, @@ -420,9 +438,7 @@ def _update_weights_from_flattened_bucket( converted_metadata.append(converted_meta) # Create bucket and reconstruct tensors - bucket = FlattenedTensorBucket( - flattened_tensor=flattened_tensor, metadata=converted_metadata - ) + bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=converted_metadata) reconstructed_tensors = bucket.reconstruct_tensors() # Load the reconstructed tensors using the standard method @@ -430,25 +446,18 @@ def _update_weights_from_flattened_bucket( return True, "Succeeded to update parameter online from flattened bucket tensor." 
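# A round-trip sketch of the flattened-bucket path handled above (illustrative
# only; the tensor names and shapes are made up). FlattenedTensorBucket is the
# helper added in lightllm/utils/tensor_bucket.py earlier in this series: it packs
# mixed-dtype tensors into one uint8 buffer plus metadata, and the receiver
# rebuilds views of the original tensors from that buffer.
import torch
from lightllm.utils.tensor_bucket import FlattenedTensorBucket

named = [("a.weight", torch.randn(4, 8).half()),
         ("b.weight", torch.randn(6))]
bucket = FlattenedTensorBucket(named_tensors=named)  # flatten into one buffer
flat, meta = bucket.get_flattened_tensor(), bucket.get_metadata()
# ...send `flat` and `meta` to the inference side, then rebuild by name:
rebuilt = dict(FlattenedTensorBucket(flattened_tensor=flat, metadata=meta).reconstruct_tensors())
assert all(torch.equal(t, rebuilt[n]) for n, t in named)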
- def update_weights_from_tensor( - self, - request: UpdateWeightsFromTensorReq - ): + def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq): try: monkey_patch_torch_reductions() if request.load_format == "flattened_bucket": # Handle flattened bucket format - return self._update_weights_from_flattened_bucket( - flattened_tensor_bucket_dict=request.named_tensors - ) + return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=request.named_tensors) # We need to get device after patch otherwise the device would be wrong self.device_module = torch.get_device_module("cuda") infered_device = self.device_module.current_device() - named_tensors=MultiprocessingSerializer.deserialize( - request.serialized_named_tensors[self.rank_in_dp] - ) + named_tensors = MultiprocessingSerializer.deserialize(request.serialized_named_tensors[self.rank_in_dp]) def _unwrap_tensor(tensor, tp_rank, device): if isinstance(tensor, LocalSerializedTensor): @@ -456,7 +465,7 @@ def _unwrap_tensor(tensor, tp_rank, device): return tensor.to(device) named_tensors = { - name : _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=infered_device) + name: _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=infered_device) for name, tensor in named_tensors } diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 04e7495adf..399e9d2404 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -33,6 +33,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.torch_memory_saver_utils import MemoryTag from lightllm.server.io_struct import GeneralHttpToModelRpcReq, GeneralModelToHttpRpcRsp logger = init_logger(__name__) @@ -182,13 +183,20 @@ def init_model(self, kvargs): def get_max_total_token_num(self): return self.backend.get_max_total_token_num() - def flush_radix_cache(self): + def release_memory_occupation(self, tags: List[MemoryTag]): try: - if self.backend is not None: - self.backend.flush_radix_cache() + self.backend.release_memory_occupation(tags) return True except BaseException as e: - logger.exception(f"flush radix cache failed: {str(e)}") + logger.exception(f"release memory occupation failed: {str(e)}") + return False + + def resume_memory_occupation(self, tags: List[MemoryTag]): + try: + self.backend.resume_memory_occupation(tags) + return True + except BaseException as e: + logger.exception(f"resume memory occupation failed: {str(e)}") return False def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: @@ -199,7 +207,9 @@ def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpR return GeneralModelToHttpRpcRsp(success=success, msg=str(ret), func_name=req.func_name, func_rsp=ret) except BaseException as e: logger.exception(f"forward to model backend failed: {str(e)}") - return GeneralModelToHttpRpcRsp(success=False, msg=f'forward to model backend failed: {str(e)}', func_name=req.func_name) + return GeneralModelToHttpRpcRsp( + success=False, msg=f"forward to model backend failed: {str(e)}", func_name=req.func_name + ) class ModelRpcClient: @@ -231,16 +241,6 @@ async def get_max_total_token_num(self): assert func_name == "get_max_total_token_num" return ret - def flush_radix_cache(self) -> bool: - 
self.rpc_shm_params.write_func_params("flush_radix_cache", ()) - self.rpc_event.set() - - self.rpc_finished_event.wait() - self.rpc_finished_event.clear() - func_name, ret = self.rpc_shm_results.read_func_result() - assert func_name == "flush_radix_cache" - return ret - def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: self.rpc_shm_params.write_func_params("forward_to_model", (req,)) self.rpc_event.set() @@ -251,6 +251,7 @@ def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpR assert func_name == "forward_to_model" return ret + def _init_env( args, rank, @@ -313,7 +314,11 @@ async def start_model_process( success_event, ), ) - proc.start() + from lightllm.utils.torch_memory_saver_utils import TorchMemorySaverWrapper + + torch_memory_saver = TorchMemorySaverWrapper(args.enable_torch_memory_saver) + with torch_memory_saver.configure_subprocess(): + proc.start() # Use asyncio.to_thread to make the blocking wait non-blocking await asyncio.to_thread(success_event.wait, timeout=40) diff --git a/lightllm/utils/torch_memory_saver_utils.py b/lightllm/utils/torch_memory_saver_utils.py new file mode 100644 index 0000000000..edf15fa837 --- /dev/null +++ b/lightllm/utils/torch_memory_saver_utils.py @@ -0,0 +1,92 @@ +import torch +from contextlib import contextmanager +from enum import Enum +from lightllm.utils.log_utils import init_logger + +try: + from torch_memory_saver import ( + torch_memory_saver, + configure_subprocess, + ) + + HAS_TORCH_MEMORY_SAVER = True + +except ImportError: + HAS_TORCH_MEMORY_SAVER = False + pass + +logger = init_logger(__name__) + + +class MemoryTag(Enum): + KV_CACHE = "kv_cache" + WEIGHT = "weight" + GRAPH = "graph" + + def is_kv_cache(self): + return self == MemoryTag.KV_CACHE + + def is_weight(self): + return self == MemoryTag.WEIGHT + + def is_graph(self): + return self == MemoryTag.GRAPH + + def __str__(self): + return self.value + + +class TorchMemorySaverWrapper: + def __new__(cls, enable_torch_memory_saver: bool = False): + if enable_torch_memory_saver: + assert ( + HAS_TORCH_MEMORY_SAVER + ), "torch_memory_saver is not installed, please install it via `pip install torch_memory_saver`." 
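# At this point the flag is set and the import succeeded, so the caller gets the
# real saver; the else branch below instead returns _TorchMemorySaverFake, which
# keeps the same interface but makes region()/pause()/resume() no-ops and falls
# back to a plain torch.cuda.graph() for capture.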
+ return _TorchMemorySaver() + else: + return _TorchMemorySaverFake() + + +class _TorchMemorySaver: + def configure_subprocess(self): + return configure_subprocess() + + def region(self, tag: MemoryTag, enable_cpu_backup: bool = False): + return torch_memory_saver.region(tag=tag.value, enable_cpu_backup=enable_cpu_backup) + + def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs): + return torch_memory_saver.cuda_graph(cuda_graph=graph_obj, **kwargs, tag=MemoryTag.GRAPH.value) + + def disable(self): + return torch_memory_saver.disable() + + def pause(self, tag: MemoryTag): + return torch_memory_saver.pause(tag=tag.value) + + def resume(self, tag: MemoryTag): + return torch_memory_saver.resume(tag=tag.value) + + +class _TorchMemorySaverFake: + @contextmanager + def configure_subprocess(self): + yield + + @contextmanager + def region(self, tag: MemoryTag, enable_cpu_backup: bool = False): + yield + + def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs): + return torch.cuda.graph(graph_obj, **kwargs) + + @contextmanager + def disable(self): + yield + + def pause(self, tag: MemoryTag): + logger.warning("torch_memory_saver is not enabled, pause is not supported.") + return + + def resume(self, tag: MemoryTag): + logger.warning("torch_memory_saver is not enabled, resume is not supported.") + return diff --git a/requirements.txt b/requirements.txt index 40d0b49566..20f27dc05a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -87,4 +87,5 @@ librosa==0.11.0 cuda_bindings==12.9.0 orjson==3.11.2 setproctitle==1.3.6 -xxhash==3.6.0 \ No newline at end of file +xxhash==3.6.0 +torch_memory_saver==0.0.9 From 094df8ca341f3ae3d579933321677f7c6ad6424e Mon Sep 17 00:00:00 2001 From: sufubao <47234901+sufubao@users.noreply.github.com> Date: Mon, 8 Dec 2025 19:38:00 +0800 Subject: [PATCH 009/180] use portpicker (#1142) --- lightllm/utils/net_utils.py | 52 ++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index 20b9888753..486414e88e 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -2,44 +2,72 @@ import subprocess import ipaddress import random +import portpicker from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) def alloc_can_use_network_port(num=3, used_nccl_ports=None, from_port_num=10000): + if used_nccl_ports is None: + used_nccl_ports = [] + port_list = [] - for port in range(from_port_num, 65536): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - result = s.connect_ex(("localhost", port)) - if result != 0 and port not in used_nccl_ports: + max_attempts = num * 50 # Allow more attempts to find ports in range + + for _ in range(max_attempts): + if len(port_list) >= num: + break + + try: + port = portpicker.pick_unused_port() + + if port >= from_port_num and port not in used_nccl_ports: port_list.append(port) - if len(port_list) > num * 30: - break + logger.debug(f"Allocated port: {port}") + else: + logger.debug(f"Port {port} is out of range or in used_nccl_ports, skipping") + + except Exception as e: + logger.warning(f"Failed to allocate port: {e}") + continue if len(port_list) < num: + logger.error(f"Failed to allocate {num} ports, only got {len(port_list)}") return None - random.shuffle(port_list) - return port_list[0:num] + logger.info(f"Successfully allocated {len(port_list)} ports: {port_list}") + return port_list def alloc_can_use_port(min_port, max_port): port_list = [] for port in 
range(min_port, max_port): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - result = s.connect_ex(("localhost", port)) + try: + test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result = test_socket.connect_ex(("localhost", port)) + test_socket.close() + if result != 0: port_list.append(port) + except Exception: + continue return port_list def find_available_port(start_port, end_port): for port in range(start_port, end_port + 1): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - result = sock.connect_ex(("localhost", port)) + try: + test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result = test_socket.connect_ex(("localhost", port)) + test_socket.close() + if result != 0: return port + except Exception: + continue return None From 560be020a0414ae0cb63197a2a1409e6cfd66895 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Mon, 8 Dec 2025 20:52:03 +0800 Subject: [PATCH 010/180] Rl weight (#1143) Co-authored-by: sufubao --- lightllm/common/basemodel/basemodel.py | 10 +- .../layer_weights/meta_weights/__init__.py | 5 +- .../{ => fused_moe}/fused_moe_weight_ep.py | 178 ++++- .../fused_moe_weight_ep_redundancy.py | 9 +- .../fused_moe/fused_moe_weight_tp.py | 325 +++++++++ .../gpt_oss_fused_moe_weight_tp.py | 2 +- .../meta_weights/fused_moe_weight_tp.py | 665 ------------------ .../meta_weights/mm_weight/__init__.py | 9 +- .../meta_weights/mm_weight/colmm_weight.py | 82 +-- .../meta_weights/mm_weight/mm_factory.py | 90 --- .../meta_weights/mm_weight/mm_slicer.py | 18 + .../meta_weights/mm_weight/mm_weight.py | 348 ++------- .../meta_weights/mm_weight/rowmm_weight.py | 88 +-- .../layer_weights/meta_weights/norm_weight.py | 50 +- .../layer_weights/transformer_layer_weight.py | 6 +- lightllm/common/quantization/__init__.py | 5 +- lightllm/common/quantization/awq_quant.py | 139 ++-- .../common/quantization/deepgemm_quant.py | 54 +- lightllm/common/quantization/no_quant.py | 52 ++ .../common/quantization/quantize_method.py | 66 +- lightllm/common/quantization/registry.py | 5 +- lightllm/common/quantization/torchao_quant.py | 9 +- .../fp8/fp8w8a8_block_quant_kernel.py | 2 +- .../fp8/fp8w8a8_scaled_mm_per_token_kernel.py | 471 +++++++++++++ .../quantization/triton_quant/triton_quant.py | 43 +- lightllm/common/quantization/w8a8_quant.py | 100 ++- .../pre_and_post_layer_weight.py | 25 +- .../pre_and_post_layer_weight.py | 17 +- .../pre_and_post_layer_weight.py | 6 +- .../layer_weights/transformer_layer_weight.py | 6 +- .../layer_weights/transformer_layer_weight.py | 8 +- .../pre_and_post_layer_weight.py | 20 +- .../layer_weights/transformer_layer_weight.py | 8 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 6 +- .../pre_and_post_layer_weight.py | 20 +- .../pre_and_post_layer_weight.py | 6 +- .../pre_and_post_layer_weight.py | 18 +- .../layer_weights/transformer_layer_weight.py | 53 +- .../pre_and_post_layer_weight.py | 6 +- .../layer_weights/transformer_layer_weight.py | 1 + .../pre_and_post_layer_weight.py | 18 +- .../pre_and_post_layer_weight.py | 21 +- .../layer_weights/transformer_layer_weight.py | 9 - .../pre_and_post_layer_weight.py | 32 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 5 +- 
.../pre_and_post_layer_weight.py | 23 +- .../pre_and_post_layer_weight.py | 31 +- .../pre_and_post_layer_weight.py | 25 +- .../pre_and_post_layer_weight.py | 48 +- .../layer_weights/transformer_layer_weight.py | 9 +- .../mode_backend/redundancy_expert_manager.py | 4 +- 54 files changed, 1727 insertions(+), 1541 deletions(-) rename lightllm/common/basemodel/layer_weights/meta_weights/{ => fused_moe}/fused_moe_weight_ep.py (74%) rename lightllm/common/basemodel/layer_weights/meta_weights/{ => fused_moe}/fused_moe_weight_ep_redundancy.py (96%) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py rename lightllm/common/basemodel/layer_weights/meta_weights/{ => fused_moe}/gpt_oss_fused_moe_weight_tp.py (99%) delete mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py delete mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py create mode 100644 lightllm/common/quantization/no_quant.py create mode 100644 lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index cc50d0a085..3eb5d7dbe4 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -120,6 +120,7 @@ def __init__(self, kvargs): self._check_mem_size() with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): self._init_req_manager() + self.load_weights(self.weight_dict) self._init_infer_layer() self._init_some_value() self._init_custom() @@ -181,15 +182,6 @@ def _init_weights(self): ) for i in range(self.config["n_layer"]) ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] return def load_weights(self, weight_dict: dict): diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index b3dab0614c..e4f2beebcc 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -1,11 +1,10 @@ from .base_weight import BaseWeight from .mm_weight import ( - MMWeightPack, MMWeightTpl, ROWMMWeight, COLMMWeight, ROWBMMWeight, ) from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight -from .fused_moe_weight_tp import create_tp_moe_wegiht_obj -from .fused_moe_weight_ep import FusedMoeWeightEP +from .fused_moe.fused_moe_weight_tp import create_tp_moe_wegiht_obj +from .fused_moe.fused_moe_weight_ep import FusedMoeWeightEP diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py similarity index 74% rename from lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py rename to lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py index 87a7b361e3..0923d5dea0 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py @@ -3,7 +3,7 @@ import threading from typing import Optional, Tuple, List, Dict, Any from lightllm.utils.dist_utils import get_global_world_size, 
get_global_rank, get_current_device_id -from .base_weight import BaseWeight +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeight from lightllm.common.fused_moe.grouped_fused_moe_ep import ( fused_experts_impl, masked_group_gemm, @@ -23,6 +23,7 @@ from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair from lightllm.utils.log_utils import init_logger from lightllm.common.triton_utils.autotuner import Autotuner +from lightllm.common.quantization.quantize_method import WeightPack logger = init_logger(__name__) @@ -41,6 +42,7 @@ def __init__( network_config: Dict[str, Any], layer_num: int, quant_cfg=None, + hidden_size: Optional[int] = None, ) -> None: super().__init__() @@ -62,6 +64,7 @@ def __init__( self.e_score_correction_bias_name = e_score_correction_bias_name self.n_routed_experts = n_routed_experts self.data_type_ = data_type + self.hidden_size = hidden_size global_world_size = get_global_world_size() self.global_rank_ = get_global_rank() @@ -78,6 +81,7 @@ def __init__( assert self.n_routed_experts % global_world_size == 0 self.ep_n_routed_experts = self.n_routed_experts // global_world_size ep_load_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num + self.ep_load_expert_num = ep_load_expert_num self.experts_up_projs = [None] * ep_load_expert_num self.experts_gate_projs = [None] * ep_load_expert_num self.experts_up_proj_scales = [None] * ep_load_expert_num @@ -105,6 +109,51 @@ def __init__( # auto update redundancy expert vars self.auto_update_redundancy_expert: bool = get_env_start_args().auto_update_redundancy_expert + # Pre-allocate memory if hidden_size is provided + if self.hidden_size is not None: + self._create_weight() + + def _create_weight(self): + """Pre-allocate GPU memory for fused MoE weights""" + if self.hidden_size is None: + return + + total_expert_num = self.ep_load_expert_num + # We need to determine intermediate size from network config or use a default + # This will be updated when first weight is loaded if needed + intermediate_size = getattr(self, "intermediate_size", None) + if intermediate_size is None: + # Default fallback - this will be corrected during load + intermediate_size = self.hidden_size * 4 + + device_id = get_current_device_id() + + if not self.quantized_weight and self.quant_method is not None: + # Quantized weights + w1_pack = self.quant_method.create_weight( + total_expert_num * intermediate_size * 2, self.hidden_size, dtype=self.data_type_, device_id=device_id + ) + self.w1[0] = w1_pack.weight.view(total_expert_num, intermediate_size * 2, self.hidden_size) + self.w1[1] = w1_pack.weight_scale.view(total_expert_num, intermediate_size * 2, self.hidden_size) + + w2_pack = self.quant_method.create_weight( + total_expert_num * self.hidden_size, intermediate_size, dtype=self.data_type_, device_id=device_id + ) + self.w2[0] = w2_pack.weight.view(total_expert_num, self.hidden_size, intermediate_size) + self.w2[1] = w2_pack.weight_scale.view(total_expert_num, self.hidden_size, intermediate_size) + else: + # Regular weights + self.w1[0] = torch.empty( + (total_expert_num, intermediate_size * 2, self.hidden_size), + dtype=self.data_type_, + device=f"cuda:{device_id}", + ) + self.w2[0] = torch.empty( + (total_expert_num, self.hidden_size, intermediate_size), + dtype=self.data_type_, + device=f"cuda:{device_id}", + ) + def experts( self, input_tensor, @@ -422,8 +471,12 @@ def _fuse(self): inter_shape, hidden_size = self.w2_list[0].shape[0], 
self.w2_list[0].shape[1] w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) if not self.quantized_weight and self.quant_method is not None: - self.w1 = self.quant_method.quantize(w1) - self.w2 = self.quant_method.quantize(w2) + qw1_pack = self.quant_method.quantize(w1) + qw2_pack = self.quant_method.quantize(w2) + self.w1[0] = qw1_pack.weight + self.w1[1] = qw1_pack.weight_scale + self.w2[0] = qw2_pack.weight + self.w2[1] = qw2_pack.weight_scale else: self.w1[0] = self._cuda(w1) self.w2[0] = self._cuda(w2) @@ -465,38 +518,74 @@ def _fuse_weight_scale(self): def load_hf_weights(self, weights): n_expert_ep = self.ep_n_routed_experts - # tp to ep here + + # Load bias if self.e_score_correction_bias_name in weights: self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) + # Get weight shapes from first expert to determine intermediate size + first_expert_idx = 0 + n_expert_ep * self.global_rank_ + w1_weight_name = f"{self.weight_prefix}.{first_expert_idx}.{self.w1_weight_name}.weight" + if w1_weight_name in weights: + intermediate_size = weights[w1_weight_name].shape[0] + self.intermediate_size = intermediate_size + + # Re-create weights with correct size if needed + if self.w1[0].shape[1] != intermediate_size * 2: + self._create_weight() + + # Load regular experts for i_experts_ep in range(n_expert_ep): i_experts = i_experts_ep + n_expert_ep * self.global_rank_ - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - if w1_weight in weights: - self.experts_gate_projs[i_experts_ep] = weights[w1_weight] - if w3_weight in weights: - self.experts_up_projs[i_experts_ep] = weights[w3_weight] - if w2_weight in weights: - self.w2_list[i_experts_ep] = weights[w2_weight] - - # Load weight parameters for redundant experts + self._copy_expert_weights(i_experts_ep, i_experts, weights) + + # Load redundant experts for i, redundant_expert_id in enumerate(self.redundancy_expert_ids): - i_experts = redundant_expert_id - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - if w1_weight in weights: - self.experts_gate_projs[n_expert_ep + i] = weights[w1_weight] - if w3_weight in weights: - self.experts_up_projs[n_expert_ep + i] = weights[w3_weight] - if w2_weight in weights: - self.w2_list[n_expert_ep + i] = weights[w2_weight] + self._copy_expert_weights(n_expert_ep + i, redundant_expert_id, weights) if self.quantized_weight: - self._load_weight_scale(weights) - self._fuse() + self._load_weight_scale_direct(weights) + + def _copy_expert_weights(self, target_idx, expert_id, weights): + """Copy a single expert's weights to pre-allocated GPU memory""" + w1_weight = f"{self.weight_prefix}.{expert_id}.{self.w1_weight_name}.weight" + w2_weight = f"{self.weight_prefix}.{expert_id}.{self.w2_weight_name}.weight" + w3_weight = f"{self.weight_prefix}.{expert_id}.{self.w3_weight_name}.weight" + + intermediate_size = self.intermediate_size + + if w1_weight in weights and w3_weight in weights: + # Combine gate and up projections into w1 + gate_weight = weights[w1_weight] # [intermediate_size, hidden_size] + up_weight = weights[w3_weight] # [intermediate_size, hidden_size] + + # Copy 
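# In the rewritten _fuse() above, quant_method.quantize() now returns a pack
# object instead of a bare tensor, and the caller unpacks .weight and
# .weight_scale into the [weight, scale] slots. Only those two fields are
# relied on here, so a minimal stand-in could look like this (field names are
# inferred from usage, not taken from the real WeightPack class):
from dataclasses import dataclass
from typing import Optional
import torch


@dataclass
class _Pack:
    weight: torch.Tensor
    weight_scale: Optional[torch.Tensor] = None


def _toy_per_channel_quantize(w: torch.Tensor) -> _Pack:
    # illustrative int8 per-output-channel quantization
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp((w / scale).round(), -128, 127).to(torch.int8)
    return _Pack(weight=q, weight_scale=scale)


pack = _toy_per_channel_quantize(torch.randn(4, 8))
w1_slots = [None, None]
w1_slots[0], w1_slots[1] = pack.weight, pack.weight_scale  # mirrors _fuse()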
to pre-allocated memory + if not self.quantized_weight and self.quant_method is not None: + # Quantized path + combined_cpu = torch.empty((intermediate_size * 2, self.hidden_size), dtype=gate_weight.dtype) + combined_cpu[:intermediate_size, :] = gate_weight + combined_cpu[intermediate_size:, :] = up_weight + quantized_pack = self.quant_method.quantize(combined_cpu) + self.w1[0][target_idx].copy_(quantized_pack.weight.view(intermediate_size * 2, self.hidden_size)) + if quantized_pack.weight_scale is not None: + self.w1[1][target_idx].copy_( + quantized_pack.weight_scale.view(intermediate_size * 2, self.hidden_size) + ) + else: + # Regular path + self.w1[0][target_idx, :intermediate_size, :].copy_(gate_weight) + self.w1[0][target_idx, intermediate_size:, :].copy_(up_weight) + + if w2_weight in weights: + # Copy w2 (down projection) + w2_weight_tensor = weights[w2_weight] # [hidden_size, intermediate_size] - already the correct shape + if not self.quantized_weight and self.quant_method is not None: + quantized_pack = self.quant_method.quantize(w2_weight_tensor) + self.w2[0][target_idx].copy_(quantized_pack.weight) + if quantized_pack.weight_scale is not None: + self.w2[1][target_idx].copy_(quantized_pack.weight_scale) + else: + self.w2[0][target_idx].copy_(w2_weight_tensor) def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: n_expert_ep = self.ep_n_routed_experts @@ -526,6 +615,41 @@ def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: if w2_scale in weights: self.w2_scale_list[n_expert_ep + i] = weights[w2_scale] + def _load_weight_scale_direct(self, weights: Dict[str, torch.Tensor]) -> None: + """Load weight scales directly to pre-allocated GPU memory""" + n_expert_ep = self.ep_n_routed_experts + + # Load regular expert scales + for i_experts_ep in range(n_expert_ep): + i_experts = i_experts_ep + n_expert_ep * self.global_rank_ + self._copy_expert_scales(i_experts_ep, i_experts, weights) + + # Load redundant expert scales + for i, redundant_expert_id in enumerate(self.redundancy_expert_ids): + self._copy_expert_scales(n_expert_ep + i, redundant_expert_id, weights) + + def _copy_expert_scales(self, target_idx, expert_id, weights): + """Copy a single expert's weight scales to pre-allocated GPU memory""" + w1_scale = f"{self.weight_prefix}.{expert_id}.{self.w1_weight_name}.{self.weight_scale_suffix}" + w2_scale = f"{self.weight_prefix}.{expert_id}.{self.w2_weight_name}.{self.weight_scale_suffix}" + w3_scale = f"{self.weight_prefix}.{expert_id}.{self.w3_weight_name}.{self.weight_scale_suffix}" + + intermediate_size = self.intermediate_size + + if w1_scale in weights and w3_scale in weights: + # Combine gate and up projection scales into w1 scale + gate_scale = weights[w1_scale] # [intermediate_size, hidden_size] + up_scale = weights[w3_scale] # [intermediate_size, hidden_size] + + # Copy to pre-allocated memory + self.w1[1][target_idx, :intermediate_size, :].copy_(gate_scale) + self.w1[1][target_idx, intermediate_size:, :].copy_(up_scale) + + if w2_scale in weights: + # Copy w2 scale (down projection) + w2_scale_tensor = weights[w2_scale] # [hidden_size, intermediate_size] + self.w2[1][target_idx].copy_(w2_scale_tensor) + def _cuda(self, cpu_tensor): device_id = get_current_device_id() if self.quantized_weight: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py similarity index 96% rename from 
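# _copy_expert_weights() above stacks each expert's gate and up projections,
# optionally quantizes the combined matrix, and writes the result straight
# into that expert's slot of the pre-allocated buffer. The per-expert flow in
# isolation, with a toy per-row int8 quantizer standing in for the real
# quant_method (the actual scale layout depends on the quantization scheme):
import torch

inter, hidden = 16, 8
combined = torch.empty((inter * 2, hidden))
combined[:inter] = torch.randn(inter, hidden)   # gate projection rows
combined[inter:] = torch.randn(inter, hidden)   # up projection rows
scale = combined.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
q = torch.clamp((combined / scale).round(), -128, 127).to(torch.int8)

w1_q = torch.empty((4, inter * 2, hidden), dtype=torch.int8)
w1_q[1].copy_(q.view(inter * 2, hidden))        # slot for expert index 1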
lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py rename to lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py index 5558070a2b..933a94f78c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py @@ -102,12 +102,15 @@ def _fuse(self): inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) if not self._ep_w.quantized_weight and self._ep_w.quant_method is not None: - self.w1 = self._ep_w.quant_method.quantize(w1) - self.w2 = self._ep_w.quant_method.quantize(w2) + qw1_pack = self._ep_w.quant_method.quantize(w1) + qw2_pack = self._ep_w.quant_method.quantize(w2) + self.w1[0] = qw1_pack.weight + self.w1[1] = qw1_pack.weight_scale + self.w2[0] = qw2_pack.weight + self.w2[1] = qw2_pack.weight_scale else: self.w1[0] = w1 self.w2[0] = w2 - delattr(self, "w2_list") delattr(self, "experts_up_projs") delattr(self, "experts_gate_projs") diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py new file mode 100644 index 0000000000..bf7b218b71 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py @@ -0,0 +1,325 @@ +import os +import torch +import threading +from typing import Tuple, List, Dict, Any, Union, Callable +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeight +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id, get_dp_world_size +from lightllm.common.quantization import Quantcfg +from lightllm.common.quantization.quantize_method import WeightPack +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import ( + get_row_slice_mixin, + get_col_slice_mixin, +) + + +def create_tp_moe_wegiht_obj( + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + num_fused_shared_experts: int, + split_inter_size: int, + data_type: torch.dtype, + network_config: Dict[str, Any], + layer_num: int, + quant_cfg: Quantcfg = None, +) -> Union["FusedMoeWeightTP", "FusedAWQMARLINMoeWeightTP"]: + quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") + if quant_method is not None and quant_method.method_name == "awq_marlin": + return FusedAWQMARLINMoeWeightTP( + gate_proj_name=gate_proj_name, + down_proj_name=down_proj_name, + up_proj_name=up_proj_name, + e_score_correction_bias_name=e_score_correction_bias_name, + weight_prefix=weight_prefix, + n_routed_experts=n_routed_experts, + num_fused_shared_experts=num_fused_shared_experts, + split_inter_size=split_inter_size, + data_type=data_type, + network_config=network_config, + layer_num=layer_num, + quant_cfg=quant_cfg, + ) + else: + return FusedMoeWeightTP( + gate_proj_name=gate_proj_name, + down_proj_name=down_proj_name, + up_proj_name=up_proj_name, + e_score_correction_bias_name=e_score_correction_bias_name, + weight_prefix=weight_prefix, + n_routed_experts=n_routed_experts, + num_fused_shared_experts=num_fused_shared_experts, + split_inter_size=split_inter_size, + data_type=data_type, + network_config=network_config, + layer_num=layer_num, + 
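# create_tp_moe_wegiht_obj() above is a plain factory: it inspects the layer's
# quant method name and forwards every constructor argument unchanged to the
# chosen class. The dispatch pattern in isolation (class names below are
# placeholders for illustration, not the real weight classes):
from typing import Any


class _Default:
    def __init__(self, **kw: Any) -> None:
        self.kw = kw


class _AwqMarlin(_Default):
    pass


def _create(method_name: str, **kwargs: Any) -> _Default:
    cls = _AwqMarlin if method_name == "awq_marlin" else _Default
    return cls(**kwargs)


assert isinstance(_create("awq_marlin", layer_num=0), _AwqMarlin)
assert isinstance(_create("none", layer_num=0), _Default)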
quant_cfg=quant_cfg, + ) + + +class FusedMoeWeightTP(BaseWeight): + def __init__( + self, + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + num_fused_shared_experts: int, + split_inter_size: int, + data_type: torch.dtype, + network_config: Dict[str, Any], + layer_num: int, + quant_cfg: Quantcfg = None, + ) -> None: + super().__init__() + self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") + self.quantized_weight = quant_cfg.quantized_weight + if self.quant_method.method_name != "none": + self.weight_scale_suffix = self.quant_method.weight_scale_suffix + + self.w1_weight_name = gate_proj_name + self.w2_weight_name = down_proj_name + self.w3_weight_name = up_proj_name + + self.e_score_correction_bias_name = e_score_correction_bias_name + self.weight_prefix = weight_prefix + assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." + self.n_routed_experts = n_routed_experts + num_fused_shared_experts + self.num_fused_shared_experts = num_fused_shared_experts + self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) + self.split_inter_size = split_inter_size + self.data_type_ = data_type + self.hidden_size = network_config.get("hidden_size") + self.tp_rank_ = get_current_rank_in_dp() + self.e_score_correction_bias = None + self.scoring_func = network_config.get("scoring_func", "softmax") + self.row_slicer = get_row_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=get_dp_world_size() + ) + self.col_slicer = get_col_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=get_dp_world_size() + ) + self._create_weight() + + def _create_weight(self): + total_expert_num = self.n_routed_experts + intermediate_size = self.split_inter_size + device_id = get_current_device_id() + + # Create e_score_correction_bias + if self.e_score_correction_bias is not None: + self.e_score_correction_bias = torch.empty( + (total_expert_num,), + dtype=self.data_type_, + device=f"cuda:{device_id}", + ) + + self.w13: WeightPack = self.quant_method.create_weight( + out_dim=intermediate_size * 2, + in_dim=self.hidden_size, + dtype=self.data_type_, + device_id=device_id, + num_experts=total_expert_num, + ) + self.w2: WeightPack = self.quant_method.create_weight( + out_dim=self.hidden_size, + in_dim=intermediate_size, + dtype=self.data_type_, + device_id=device_id, + num_experts=total_expert_num, + ) + + def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + from lightllm.common.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=self.e_score_correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=self.scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.num_fused_shared_experts > 0: + pad_topk_ids = ( + torch.arange( + start=self.n_routed_experts - self.num_fused_shared_experts, + end=self.n_routed_experts, + step=1, + dtype=topk_ids.dtype, + device="cuda", + ) + .view(1, self.num_fused_shared_experts) + .repeat(topk_ids.shape[0], 1) + ) + pad_topk_weights = torch.full( + (topk_weights.shape[0], self.num_fused_shared_experts), + fill_value=1.0, + device="cuda", + 
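# FusedMoeWeightTP._create_weight() above asks the quant method to allocate a
# single WeightPack covering all experts (num_experts x out_dim x in_dim), and
# load paths later address it via get_expert(idx). A toy unquantized version
# of create_weight()/get_expert() consistent with how the pack is used in this
# patch (the real signatures live in quantize_method.py, which is not shown
# in full here; device_id is omitted so the sketch runs on CPU):
from dataclasses import dataclass
from typing import Optional
import torch


@dataclass
class _ExpertPack:
    weight: torch.Tensor
    weight_scale: Optional[torch.Tensor] = None
    weight_zero_point: Optional[torch.Tensor] = None

    def get_expert(self, idx: int) -> "_ExpertPack":
        # a per-expert view into the batched allocation; copies write through
        return _ExpertPack(weight=self.weight[idx])


def _create_weight(out_dim: int, in_dim: int, dtype, num_experts: int) -> _ExpertPack:
    return _ExpertPack(weight=torch.empty((num_experts, out_dim, in_dim), dtype=dtype))


w13 = _create_weight(out_dim=32, in_dim=8, dtype=torch.float16, num_experts=4)
w13.get_expert(3).weight.copy_(torch.randn(32, 8, dtype=torch.float16))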
dtype=topk_weights.dtype, + ) + + topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + + w13, w13_scale = self.w13.weight, self.w13.weight_scale + w2, w2_scale = self.w2.weight, self.w2.weight_scale + use_fp8_w8a8 = self.quant_method.method_name != "none" + + from lightllm.common.fused_moe.grouped_fused_moe import fused_experts + + fused_experts( + hidden_states=input_tensor, + w1=w13, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w13_scale, + w2_scale=w2_scale, + ) + return + + def _cuda(self, cpu_tensor): + device_id = get_current_device_id() + if self.quantized_weight: + return cpu_tensor.cuda(device_id) + return cpu_tensor.cuda(device_id) + + def verify_load(self): + return True + + def load_hf_weights(self, weights): + # Load bias + if self.e_score_correction_bias_name in weights: + self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) + + # Load each expert with TP slicing + for i_experts in range(self.n_routed_experts): + self._load_expert(i_experts, weights, type="weight", suffix=self.quant_method.weight_suffix) + if self.w13.weight_scale is not None: + self._load_expert(i_experts, weights, type="weight_scale", suffix=self.quant_method.weight_scale_suffix) + if self.w13.weight_zero_point is not None: + self._load_expert( + i_experts, weights, type="weight_zero_point", suffix=self.quant_method.weight_zero_point_suffix + ) + + def _load_weight_func(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0): + if self.quant_method.weight_need_quanted(weight): + self.quant_method.quantize(weight, weight_pack, start_idx) + else: + self.quant_method.load_weight(weight, weight_pack, start_idx) + + def _load_expert(self, expert_idx, weights, type: str, suffix: str = "weight"): + w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{suffix}" + w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{suffix}" + w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{suffix}" + intermediate_size = self.split_inter_size + load_func, slice_func = self._get_load_and_slice_func(type, is_row=True) + if w1_weight in weights: + load_func(slice_func(weights[w1_weight]), self.w13.get_expert(expert_idx), start_idx=0) + if w3_weight in weights: + load_func(slice_func(weights[w3_weight]), self.w13.get_expert(expert_idx), start_idx=intermediate_size) + + load_func, slice_func = self._get_load_and_slice_func(type, is_row=False) + if w2_weight in weights: + load_func(slice_func(weights[w2_weight]), self.w2.get_expert(expert_idx), start_idx=0) + + def _get_load_and_slice_func(self, type: str, is_row: bool = True): + if is_row: + slicer = self.row_slicer + else: + slicer = self.col_slicer + if type == "weight": + return self._load_weight_func, slicer._slice_weight + elif type == "weight_scale": + return getattr(self.quant_method, "load_weight_scale"), slicer._slice_weight_scale + elif type == "weight_zero_point": + return getattr(self.quant_method, "load_weight_zero_point"), slicer._slice_weight_zero_point + + +class FusedAWQMARLINMoeWeightTP(FusedMoeWeightTP): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops + + assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, + ) + + 
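# When num_fused_shared_experts > 0, experts() above appends the shared expert
# ids with a fixed routing weight of 1.0 to every token's routed top-k before
# the fused kernel runs, so shared experts ride through the same grouped GEMM.
# The padding logic on its own (CPU tensors, toy sizes):
import torch

num_tokens, top_k = 3, 2
n_routed_experts, num_shared = 8 + 1, 1   # one shared expert fused at the end
topk_ids = torch.randint(0, 8, (num_tokens, top_k))
topk_weights = torch.rand(num_tokens, top_k)

pad_ids = (
    torch.arange(n_routed_experts - num_shared, n_routed_experts, dtype=topk_ids.dtype)
    .view(1, num_shared)
    .repeat(num_tokens, 1)
)
pad_weights = torch.full((num_tokens, num_shared), 1.0, dtype=topk_weights.dtype)

topk_ids = torch.cat([topk_ids, pad_ids], dim=1)           # [tokens, top_k + 1]
topk_weights = torch.cat([topk_weights, pad_weights], dim=1)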
self.workspace = marlin_make_workspace_new(self.w13.weight.device, 4) + + def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + from lightllm.common.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=self.e_score_correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=self.scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.num_fused_shared_experts > 0: + pad_topk_ids = ( + torch.arange( + start=self.n_routed_experts - self.num_fused_shared_experts, + end=self.n_routed_experts, + step=1, + dtype=topk_ids.dtype, + device="cuda", + ) + .view(1, self.num_fused_shared_experts) + .repeat(topk_ids.shape[0], 1) + ) + pad_topk_weights = torch.full( + (topk_weights.shape[0], self.num_fused_shared_experts), + fill_value=1.0, + device="cuda", + dtype=topk_weights.dtype, + ) + + topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + + w1, w1_scale, w1_zero_point = self.w13.weight, self.w13.weight_scale, self.w13.weight_zero_point + w2, w2_scale, w2_zero_point = self.w2.weight, self.w2.weight_scale, self.w2.weight_zero_point + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe + + fused_marlin_moe( + input_tensor, + w1, + w2, + None, + None, + w1_scale, + w2_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=self.quant_method.vllm_quant_type.id, + apply_router_weight_on_input=False, + global_num_experts=-1, + expert_map=None, + w1_zeros=w1_zero_point, + w2_zeros=w2_zero_point, + workspace=self.workspace, + inplace=True, + ) + + return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py similarity index 99% rename from lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py rename to lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index df72cc6208..9d79ff7c25 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -3,7 +3,7 @@ import threading from typing import Optional, Tuple, List, Dict, Any -from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_tp import FusedMoeWeightTP +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_tp import FusedMoeWeightTP from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id from lightllm.common.quantization import Quantcfg from lightllm.utils.log_utils import init_logger diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py deleted file mode 100644 index 0449db3448..0000000000 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ /dev/null @@ -1,665 +0,0 @@ -import os -import torch -import threading -from typing import Optional, Tuple, List, Dict, Any, Union -from .base_weight import BaseWeight -from lightllm.utils.dist_utils import get_current_rank_in_dp, 
get_current_device_id -from lightllm.common.quantization import Quantcfg - - -def create_tp_moe_wegiht_obj( - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, -) -> Union["FusedMoeWeightTP", "FusedAWQMARLINMoeWeightTP"]: - quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - if quant_method is not None and quant_method.method_name == "awq_marlin": - return FusedAWQMARLINMoeWeightTP( - gate_proj_name=gate_proj_name, - down_proj_name=down_proj_name, - up_proj_name=up_proj_name, - e_score_correction_bias_name=e_score_correction_bias_name, - weight_prefix=weight_prefix, - n_routed_experts=n_routed_experts, - num_fused_shared_experts=num_fused_shared_experts, - split_inter_size=split_inter_size, - data_type=data_type, - network_config=network_config, - layer_num=layer_num, - quant_cfg=quant_cfg, - ) - else: - return FusedMoeWeightTP( - gate_proj_name=gate_proj_name, - down_proj_name=down_proj_name, - up_proj_name=up_proj_name, - e_score_correction_bias_name=e_score_correction_bias_name, - weight_prefix=weight_prefix, - n_routed_experts=n_routed_experts, - num_fused_shared_experts=num_fused_shared_experts, - split_inter_size=split_inter_size, - data_type=data_type, - network_config=network_config, - layer_num=layer_num, - quant_cfg=quant_cfg, - ) - - -class FusedMoeWeightTP(BaseWeight): - def __init__( - self, - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, - ) -> None: - super().__init__() - self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - self.quantized_weight = quant_cfg.quantized_weight - if self.quant_method is not None: - self.weight_scale_suffix = self.quant_method.weight_scale_suffix - self.quant_method.is_moe = True - self.w1_weight_name = gate_proj_name - self.w2_weight_name = down_proj_name - self.w3_weight_name = up_proj_name - - self.e_score_correction_bias_name = e_score_correction_bias_name - self.weight_prefix = weight_prefix - assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." 
- self.n_routed_experts = n_routed_experts + num_fused_shared_experts - self.num_fused_shared_experts = num_fused_shared_experts - self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) - self.split_inter_size = split_inter_size - self.data_type_ = data_type - self.tp_rank_ = get_current_rank_in_dp() - self.experts_up_projs = [None] * self.n_routed_experts - self.experts_gate_projs = [None] * self.n_routed_experts - self.experts_up_proj_scales = [None] * self.n_routed_experts - self.experts_gate_proj_scales = [None] * self.n_routed_experts - self.e_score_correction_bias = None - self.w2_list = [None] * self.n_routed_experts - self.w2_scale_list = [None] * self.n_routed_experts - self.scoring_func = network_config.get("scoring_func", "softmax") - self.w1 = [None, None] # weight, weight_scale - self.w2 = [None, None] # weight, weight_scale - self.lock = threading.Lock() - - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - from lightllm.common.fused_moe.topk_select import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=input_tensor, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts - self.num_fused_shared_experts, - end=self.n_routed_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) - - w1, w1_scale = self.w1 - w2, w2_scale = self.w2 - use_fp8_w8a8 = self.quant_method is not None - - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts - - fused_experts( - hidden_states=input_tensor, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - w1_scale=w1_scale, - w2_scale=w2_scale, - ) - return - - def _fuse(self): - if self.quantized_weight: - self._fuse_weight_scale() - with self.lock: - if ( - hasattr(self, "experts_up_projs") - and None not in self.experts_up_projs - and None not in self.experts_gate_projs - and None not in self.w2_list - ): - gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape - up_out_dim, up_in_dim = self.experts_up_projs[0].shape - assert gate_in_dim == up_in_dim - dtype = self.experts_gate_projs[0].dtype - total_expert_num = self.n_routed_experts - - w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu") - - for i_experts in range(self.n_routed_experts): - w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts] - w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts] - - inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] - w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) - if not self.quantized_weight and self.quant_method is not None: - self.w1 = self.quant_method.quantize(w1) - self.w2 = 
self.quant_method.quantize(w2) - else: - self.w1[0] = self._cuda(w1) - self.w2[0] = self._cuda(w2) - delattr(self, "w2_list") - delattr(self, "experts_up_projs") - delattr(self, "experts_gate_projs") - - def _fuse_weight_scale(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_scales") - and None not in self.experts_up_proj_scales - and None not in self.experts_gate_proj_scales - and None not in self.w2_scale_list - ): - gate_out_dim, gate_in_dim = self.experts_gate_proj_scales[0].shape - up_out_dim, up_in_dim = self.experts_up_proj_scales[0].shape - assert gate_in_dim == up_in_dim - dtype = self.experts_gate_proj_scales[0].dtype - total_expert_num = self.n_routed_experts - - w1_scale = torch.empty( - (total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu" - ) - - for i_experts in range(self.n_routed_experts): - w1_scale[i_experts, 0:gate_out_dim:, :] = self.experts_gate_proj_scales[i_experts] - w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts] - inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] - w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( - len(self.w2_scale_list), inter_shape, hidden_size - ) - self.w1[1] = self._cuda(w1_scale) - self.w2[1] = self._cuda(w2_scale) - delattr(self, "w2_scale_list") - delattr(self, "experts_up_proj_scales") - delattr(self, "experts_gate_proj_scales") - - def load_hf_weights(self, weights): - if self.e_score_correction_bias_name in weights: - self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) - for i_experts in range(self.n_routed_experts): - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - - if w1_weight in weights: - self.experts_gate_projs[i_experts] = weights[w1_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - if w3_weight in weights: - self.experts_up_projs[i_experts] = weights[w3_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - - if w2_weight in weights: - self.w2_list[i_experts] = weights[w2_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - if self.quant_method is not None: - self._load_weight_scale(weights) - self._fuse() - - def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: - block_size = 1 - if hasattr(self.quant_method, "block_size"): - block_size = self.quant_method.block_size - for i_experts in range(self.n_routed_experts): - w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" - w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" - w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" - if w1_scale in weights: - self.experts_gate_proj_scales[i_experts] = weights[w1_scale][ - self.split_inter_size - // block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - :, - ] - if w3_scale in weights: - self.experts_up_proj_scales[i_experts] = weights[w3_scale][ - self.split_inter_size - // block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - :, - ] - - if w2_scale in weights: - self.w2_scale_list[i_experts] 
= weights[w2_scale][ - :, - self.split_inter_size - // block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - ] - - def _cuda(self, cpu_tensor): - device_id = get_current_device_id() - if self.quantized_weight: - return cpu_tensor.contiguous().cuda(device_id) - return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) - - def verify_load(self): - return self.w1 is not None and self.w2 is not None - - -class FusedAWQMARLINMoeWeightTP(BaseWeight): - def __init__( - self, - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, - ) -> None: - super().__init__() - self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - self.quantized_weight = quant_cfg.quantized_weight - if self.quant_method is not None: - self.weight_scale_suffix = self.quant_method.weight_scale_suffix - self.weight_zero_point_suffix = self.quant_method.weight_zero_point_suffix - self.quant_method.is_moe = True - hf_quantization_config = network_config.get("quantization_config", None) - self.num_bits = hf_quantization_config.get("bits", 4) - self.group_size = hf_quantization_config.get("group_size", 128) - self.pack_factor = 32 // self.num_bits - self.has_processed_weight = False - assert self.quant_method.method_name == "awq_marlin" - - self.w1_weight_name = gate_proj_name - self.w2_weight_name = down_proj_name - self.w3_weight_name = up_proj_name - - self.e_score_correction_bias_name = e_score_correction_bias_name - self.weight_prefix = weight_prefix - assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." 
- self.n_routed_experts = n_routed_experts + num_fused_shared_experts - self.num_fused_shared_experts = num_fused_shared_experts - self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) - self.split_inter_size = split_inter_size - self.data_type_ = data_type - self.tp_rank_ = get_current_rank_in_dp() - self.experts_up_projs = [None] * self.n_routed_experts - self.experts_gate_projs = [None] * self.n_routed_experts - self.experts_up_proj_scales = [None] * self.n_routed_experts - self.experts_up_proj_zero_points = [None] * self.n_routed_experts - self.experts_gate_proj_scales = [None] * self.n_routed_experts - self.experts_gate_proj_zero_points = [None] * self.n_routed_experts - self.e_score_correction_bias = None - self.w2_list = [None] * self.n_routed_experts - self.w2_scale_list = [None] * self.n_routed_experts - self.w2_zero_point_list = [None] * self.n_routed_experts - self.scoring_func = network_config.get("scoring_func", "softmax") - self.w1 = [None, None, None] # weight, weight_scale, zero_point - self.w2 = [None, None, None] # weight, weight_scale, zero_point - self.lock = threading.Lock() - - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - from lightllm.common.fused_moe.topk_select import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=input_tensor, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts - self.num_fused_shared_experts, - end=self.n_routed_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) - - w1, w1_scale, w1_zero_point = self.w1 - w2, w2_scale, w2_zero_point = self.w2 - - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe - - fused_marlin_moe( - input_tensor, - w1, - w2, - None, - None, - w1_scale, - w2_scale, - router_logits, - topk_weights, - topk_ids, - quant_type_id=self.quant_method.vllm_quant_type.id, - apply_router_weight_on_input=False, - global_num_experts=-1, - expert_map=None, - w1_zeros=w1_zero_point, - w2_zeros=w2_zero_point, - workspace=self.workspace, - inplace=True, - ) - - return - - def _fuse(self): - self._fuse_weight() - self._fuse_weight_scale() - self._fuse_weight_zero_point() - - def _fuse_weight(self): - with self.lock: - if ( - hasattr(self, "experts_up_projs") - and None not in self.experts_up_projs - and None not in self.experts_gate_projs - and None not in self.w2_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_projs[0].shape - up_in_dim, up_out_dim = self.experts_up_projs[0].shape - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - - w1 = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" - ) - - for i_experts in range(self.n_routed_experts): - w1[i_experts, :, 0:gate_out_dim] 
= self.experts_gate_projs[i_experts] - w1[i_experts, :, gate_out_dim:] = self.experts_up_projs[i_experts] - - inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] - w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) - self.w1[0] = self._cuda(w1) - self.w2[0] = self._cuda(w2) - delattr(self, "w2_list") - delattr(self, "experts_up_projs") - delattr(self, "experts_gate_projs") - - def _fuse_weight_scale(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_scales") - and None not in self.experts_up_proj_scales - and None not in self.experts_gate_proj_scales - and None not in self.w2_scale_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_proj_scales[0].shape - up_in_dim, up_out_dim = self.experts_up_proj_scales[0].shape - dtype = self.experts_gate_proj_scales[0].dtype - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - w1_scale = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=dtype, device="cpu" - ) - for i_experts in range(self.n_routed_experts): - w1_scale[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_scales[i_experts] - w1_scale[i_experts, :, gate_out_dim:] = self.experts_up_proj_scales[i_experts] - inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] - w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( - len(self.w2_scale_list), inter_shape, hidden_size - ) - self.w1[1] = self._cuda(w1_scale).to(self.data_type_) - self.w2[1] = self._cuda(w2_scale).to(self.data_type_) - delattr(self, "w2_scale_list") - delattr(self, "experts_up_proj_scales") - delattr(self, "experts_gate_proj_scales") - - def _fuse_weight_zero_point(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_zero_points") - and None not in self.experts_up_proj_zero_points - and None not in self.experts_gate_proj_zero_points - and None not in self.w2_zero_point_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_proj_zero_points[0].shape - up_in_dim, up_out_dim = self.experts_up_proj_zero_points[0].shape - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - w1_zero_point = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" - ) - for i_experts in range(self.n_routed_experts): - w1_zero_point[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_zero_points[i_experts] - w1_zero_point[i_experts, :, gate_out_dim:] = self.experts_up_proj_zero_points[i_experts] - inter_shape, hidden_size = self.w2_zero_point_list[0].shape[0], self.w2_zero_point_list[0].shape[1] - w2_zero_point = torch._utils._flatten_dense_tensors(self.w2_zero_point_list).view( - len(self.w2_zero_point_list), inter_shape, hidden_size - ) - self.w1[2] = self._cuda(w1_zero_point) - self.w2[2] = self._cuda(w2_zero_point) - delattr(self, "w2_zero_point_list") - delattr(self, "experts_up_proj_zero_points") - delattr(self, "experts_gate_proj_zero_points") - - def load_hf_weights(self, weights): - self._load_weight(weights) - self._load_weight_scale(weights) - self._load_weight_zero_point(weights) - self._fuse() - self._process_weight_after_loading() - - def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: - # awq quantization weight shape: in x out - if self.e_score_correction_bias_name in weights: - self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) - for i_experts in range(self.n_routed_experts): - 
w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.qweight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.qweight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.qweight" - - if w1_weight in weights: - self.experts_gate_projs[i_experts] = weights[w1_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - if w3_weight in weights: - self.experts_up_projs[i_experts] = weights[w3_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - - if w2_weight in weights: - self.w2_list[i_experts] = weights[w2_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - - def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: - for i_experts in range(self.n_routed_experts): - w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" - w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" - w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" - split_inter_size = self.split_inter_size * self.pack_factor - if w1_scale in weights: - self.experts_gate_proj_scales[i_experts] = weights[w1_scale][ - :, - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - ] - if w3_scale in weights: - self.experts_up_proj_scales[i_experts] = weights[w3_scale][ - :, - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - ] - - if w2_scale in weights: - self.w2_scale_list[i_experts] = weights[w2_scale][ - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - :, - ] - - def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: - for i_experts in range(self.n_routed_experts): - w1_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_zero_point_suffix}" - w2_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_zero_point_suffix}" - w3_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_zero_point_suffix}" - if w1_zero_point in weights: - self.experts_gate_proj_zero_points[i_experts] = weights[w1_zero_point][ - :, - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - ] - if w3_zero_point in weights: - self.experts_up_proj_zero_points[i_experts] = weights[w3_zero_point][ - :, - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - ] - if w2_zero_point in weights: - self.w2_zero_point_list[i_experts] = weights[w2_zero_point][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - :, - ] - - def _process_weight_after_loading(self): - with self.lock: - if None in self.w1 or None in self.w2 or self.has_processed_weight: - return - self.has_processed_weight = True - from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops - - assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" - - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_moe_permute_scales, - moe_awq_to_marlin_zero_points, - marlin_make_workspace_new, - ) - - num_experts = self.n_routed_experts - device = self.w1[0].device - - self.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - self.w2_g_idx_sort_indices = 
torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - self.w1[0] = vllm_ops.awq_marlin_moe_repack( - self.w1[0], - self.w13_g_idx_sort_indices, - size_k=self.w1[0].shape[1], - size_n=self.w1[0].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.w2[0] = vllm_ops.awq_marlin_moe_repack( - self.w2[0], - self.w2_g_idx_sort_indices, - size_k=self.w2[0].shape[1], - size_n=self.w2[0].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - # Why does this take the intermediate size for size_k? - self.w1[1] = marlin_moe_permute_scales( - s=self.w1[1], - size_k=self.split_inter_size * self.pack_factor, - size_n=self.w1[1].shape[2], - group_size=self.group_size, - ) - - self.w2[1] = marlin_moe_permute_scales( - s=self.w2[1], - size_k=self.split_inter_size * self.pack_factor, - size_n=self.w2[1].shape[2], - group_size=self.group_size, - ) - - self.w1[2] = moe_awq_to_marlin_zero_points( - self.w1[2], - size_k=self.w1[2].shape[1], - size_n=self.w1[2].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.w2[2] = moe_awq_to_marlin_zero_points( - self.w2[2], - size_k=self.w2[2].shape[1], - size_n=self.w2[2].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.workspace = marlin_make_workspace_new(device, 4) - - def _cuda(self, cpu_tensor): - device_id = get_current_device_id() - if self.quantized_weight: - return cpu_tensor.cuda(device_id) - return cpu_tensor.cuda(device_id) - - def verify_load(self): - return self.w1 is not None and self.w2 is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py index 63605b1774..34d989b01f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py @@ -1,10 +1,5 @@ from .mm_weight import ( - MMWeightPack, MMWeightTpl, ) -from .mm_factory import ( - MMWeight, - ROWMMWeight, - ROWBMMWeight, - COLMMWeight, -) +from .rowmm_weight import ROWMMWeight, ROWBMMWeight +from .colmm_weight import COLMMWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index 281f30f022..bf73b9ad89 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -1,19 +1,19 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( MMWeightTpl, - DeepGemmFP8W8A8B128MMWeight, - AWQMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional, Union -from .mm_slicer import ColSliceMixin, QuantizedColSliceMixin, AwqQuantizedColSliceMixin +from .mm_slicer import get_col_slice_mixin -class StandardCOLMMWeight(MMWeightTpl): +class COLMMWeight(MMWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], data_type: torch.dtype, bias_names: Optional[Union[str, List[str]]] = None, @@ -22,6 +22,8 @@ def __init__( tp_world_size: int = None, ) -> None: super().__init__( + in_dim=in_dim, + out_dims=out_dims, weight_names=weight_names, 
data_type=data_type, bias_names=bias_names, @@ -29,74 +31,6 @@ def __init__( tp_rank=tp_rank, tp_world_size=tp_world_size, ) - self.param_slicer = ColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class DeepGemmFP8W8A8B128COLMMWeight(DeepGemmFP8W8A8B128MMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQCOLMMWeight(AWQMMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + self.param_slicer = get_col_slice_mixin( + self.quant_method.method_name, tp_rank=tp_rank, tp_world_size=tp_world_size ) - # 注意这里不是错误,因为awq的weight是按inxout存的 - self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQMARLINCOLMMWeight(AWQCOLMMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - -COLMM_WEIGHT_CLS_MAP = { - "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128COLMMWeight, - "awq": AWQCOLMMWeight, - "awq_marlin": AWQMARLINCOLMMWeight, -} diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py deleted file mode 100644 index 464de84413..0000000000 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py +++ /dev/null @@ -1,90 +0,0 @@ -from lightllm.common.quantization import Quantcfg -from lightllm.common.quantization.quantize_method import QuantizationMethod -from typing import Type, Union, Dict -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( - MMWeightTpl, - BMMWeightTpl, -) -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ( - StandardROWMMWeight, - UnquantizedROWBMMWeight, - ROWMM_WEIGHT_CLS_MAP, -) -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.colmm_weight import ( - StandardCOLMMWeight, - COLMM_WEIGHT_CLS_MAP, -) - - -class MMWeight: - def __new__(cls, **kwargs): - """ - weight_names, - data_type, - bias_names, - quant_cfg, - layer_num, - name, - tp_rank, - tp_world_size, - ... 
- 该类主要是通过重载 __new__ 为对应的mm权重绑定量化方法,其他参数都是透传。 - """ - - quant_cfg = kwargs.pop("quant_cfg", None) - layer_num_ = kwargs.pop("layer_num", None) - name = kwargs.pop("name", None) - quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name) - # quantized_weight 本身是用来标识权重本身在文件中是否是以量化后的形式存储, - # 现在不再使用该参数,是否量化由后续的加载过程自动识别。 - kwargs["quant_method"] = quant_method - mmcls = cls._get_mmcls(quant_method) - return mmcls(**kwargs) - - @classmethod - def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> QuantizationMethod: - if quant_cfg is None: - return None, False - quant_method: QuantizationMethod = quant_cfg.get_quant_method(layer_num_, name) - if quant_method is None: - return None, False - quant_method.hf_quantization_config = quant_cfg.hf_quantization_config - quantized_weight = quant_cfg.quantized_weight - return quant_method, quantized_weight - - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod) -> Type[Union[MMWeightTpl, BMMWeightTpl]]: - raise NotImplementedError("Subclasses must implement _get_mmcls method") - - -class ROWMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return StandardROWMMWeight - - return ROWMM_WEIGHT_CLS_MAP.get( - quant_method.method_name, - StandardROWMMWeight, - ) - - -class ROWBMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return UnquantizedROWBMMWeight - else: - # TODO: Implement more quantization weight - raise NotImplementedError("ROWBMMWeight is not implemented") - - -class COLMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return StandardCOLMMWeight - return COLMM_WEIGHT_CLS_MAP.get( - quant_method.method_name, - StandardCOLMMWeight, - ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index e3ef5b0ea3..e2830ab611 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -132,3 +132,21 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None): def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return bias / self.tp_world_size_ + + +def get_row_slice_mixin(quant_method_name: str, tp_rank: int = None, tp_world_size: int = None) -> SliceMixinTpl: + if quant_method_name.startswith("awq"): + return AwqQuantizedRowSliceMixin(tp_rank, tp_world_size) + elif quant_method_name == "none": + return RowSliceMixin(tp_rank, tp_world_size) + else: + return QuantizedRowSliceMixin(tp_rank, tp_world_size) + + +def get_col_slice_mixin(quant_method_name: str, tp_rank: int = None, tp_world_size: int = None) -> SliceMixinTpl: + if quant_method_name.startswith("awq"): + return AwqQuantizedColSliceMixin(tp_rank, tp_world_size) + elif quant_method_name == "none": + return ColSliceMixin(tp_rank, tp_world_size) + else: + return QuantizedColSliceMixin(tp_rank, tp_world_size) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index e603032ecd..014cf2ec28 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -5,9 +5,10 @@ from 
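# get_row_slice_mixin()/get_col_slice_mixin() above centralize the choice of
# tensor-parallel slicing strategy per quant method name. The underlying
# arithmetic: a row slicer splits the output dimension across ranks, a column
# slicer splits the input dimension (which is also why the column mixin
# divides the bias by the world size: each rank produces a partial sum). Toy
# versions of the two unquantized slicers; the quantized and AWQ variants
# additionally account for packing and block sizes, which this sketch skips:
import torch


def slice_rows(w: torch.Tensor, rank: int, world: int) -> torch.Tensor:
    step = w.shape[0] // world
    return w[rank * step : (rank + 1) * step, :]


def slice_cols(w: torch.Tensor, rank: int, world: int) -> torch.Tensor:
    step = w.shape[1] // world
    return w[:, rank * step : (rank + 1) * step]


w = torch.randn(8, 6)
assert slice_rows(w, rank=1, world=2).shape == (4, 6)
assert slice_cols(w, rank=1, world=2).shape == (8, 3)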
dataclasses import dataclass from typing import Optional, Tuple, List, Dict, Union, Type from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from lightllm.common.quantization.quantize_method import QuantizationMethod +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl from lightllm.common.quantization import Quantcfg +from lightllm.common.quantization.no_quant import NoQuantization from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.log_utils import init_logger from .mm_slicer import SliceMixinTpl @@ -15,53 +16,11 @@ logger = init_logger(__name__) -@dataclass -class MMWeightPack: - weight: Optional[torch.Tensor] = None - bias: Optional[torch.Tensor] = None - weight_scale: Optional[torch.Tensor] = None - weight_zero_point: Optional[torch.Tensor] = None - - has_bias: bool = False - has_weight_scale: bool = False - has_weight_zero_point: bool = False - - def is_ready(self) -> bool: - return ( - self.weight is not None - and (not self.has_bias or (self.has_bias and self.bias is not None)) - and (not self.has_weight_scale or (self.has_weight_scale and self.weight_scale is not None)) - and (not self.has_weight_zero_point or (self.has_weight_zero_point and self.weight_zero_point is not None)) - ) - - def ready_for_fused_merge(self) -> bool: - """ - 判断权重是否满足可以和其他权重进行融合cat的条件,因为可能权重是量化和非量化后的权重,所以复杂一些。 - """ - weight_ready = self.weight is not None and self.weight.dtype in [ - torch.bfloat16, - torch.float16, - torch.float32, - torch.float64, - ] - bias_ready = (self.has_bias and self.bias is not None) or (not self.has_bias) - if weight_ready and bias_ready: - return True - else: - return self.is_ready() - - def is_load_finished(self): - return ( - (self.is_ready() and self.weight.is_cuda) - and ((self.has_bias and self.bias.is_cuda) or (not self.has_bias)) - and ((self.has_weight_scale and self.weight_scale.is_cuda) or (not self.has_weight_scale)) - and ((self.has_weight_zero_point and self.weight_zero_point.is_cuda) or (not self.has_weight_zero_point)) - ) - - class MMWeightTpl(BaseWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], bias_names: Optional[Union[str, List[str]]], data_type: torch.dtype, @@ -72,6 +31,14 @@ def __init__( super().__init__(tp_rank, tp_world_size, data_type) self.lock = threading.Lock() + self.in_dim = in_dim + if isinstance(out_dims, int): + out_dims = [out_dims] + self.out_dims = out_dims + self.cusum_out_dims = [0] + for out_dim in out_dims[:-1]: + self.cusum_out_dims.append(self.cusum_out_dims[-1] + out_dim) + if isinstance(weight_names, str): weight_names = [weight_names] if isinstance(bias_names, str): @@ -82,60 +49,29 @@ def __init__( if bias_names[0] is None: bias_names = None - if quant_method is not None: - has_weight_scale = quant_method.has_weight_scale - has_weight_zero_point = quant_method.has_weight_zero_point - else: - has_weight_scale = False - has_weight_zero_point = False - # 同时存在 weight_names 和 quanted_weight_names 是为了兼容在线和离线两种加载方案 self.weight_names = weight_names - self.bias_names = bias_names - has_bias = self.bias_names is not None - - self.gen_weight_quant_param_names(quant_method=quant_method) - self.quant_method = quant_method - self.sub_child_mm_params: List[MMWeightPack] = [ - MMWeightPack( - has_bias=has_bias, - has_weight_scale=has_weight_scale, - 
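# MMWeightTpl now takes in_dim/out_dims and records prefix offsets
# (cusum_out_dims) so that each fused sub-weight lands at a fixed row range of
# one pre-allocated buffer, instead of being torch.cat-ed together after the
# fact. The offset computation and its use, in isolation (the q/k/v split is
# just an illustrative example of a fused weight):
import torch

out_dims = [16, 4, 4]          # e.g. q, k, v output dims of a fused qkv weight
cusum = [0]
for d in out_dims[:-1]:
    cusum.append(cusum[-1] + d)
assert cusum == [0, 16, 20]

fused = torch.empty(sum(out_dims), 8)
for i, d in enumerate(out_dims):
    sub = torch.randn(d, 8)
    fused[cusum[i] : cusum[i] + d].copy_(sub)   # write each part at its offset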
has_weight_zero_point=has_weight_zero_point, - ) - for _ in range(len(weight_names)) - ] - self.mm_param: MMWeightPack = MMWeightPack( - has_bias=has_bias, - has_weight_scale=has_weight_scale, - has_weight_zero_point=has_weight_zero_point, - ) + self.quant_method: QuantizationMethod = NoQuantization() if quant_method is None else quant_method self.param_slicer: SliceMixinTpl = None + self._create_weight() + self.gen_weight_quant_param_names(quant_method=quant_method) - self.weight_fused_dim = 0 - self.bias_fused_dim = 0 - self.weight_scale_and_zero_point_fused_dim = 0 - - self.load_finished: bool = False + def _create_weight(self): + self.bias = None + if self.bias_names is not None: + self.bias = torch.empty(self.cusum_out_dims[-1], dtype=self.data_type_).cuda(get_current_device_id()) + self.mm_param: WeightPack = self.quant_method.create_weight( + in_dim=self.in_dim, out_dim=sum(self.out_dims), dtype=self.data_type_, device_id=get_current_device_id() + ) + return def mm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: - if self.quant_method is not None: - return self.quant_method.apply( - input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger - ) - if out is None: - shape = (input_tensor.shape[0], self.mm_param.weight.shape[1]) - dtype = input_tensor.dtype - device = input_tensor.device - if use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) - else: - out = torch.empty(shape, dtype=dtype, device=device) - if self.mm_param.bias is None: - return torch.mm(input_tensor, self.mm_param.weight, out=out) - return torch.addmm(self.mm_param.bias, input_tensor, self.mm_param.weight, out=out) + return self.quant_method.apply( + input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger, bias=self.bias + ) def gen_weight_quant_param_names(self, quant_method: Optional[QuantizationMethod]): if quant_method is None: @@ -176,8 +112,6 @@ def gen_weight_quant_param_names(self, quant_method: Optional[QuantizationMethod return def load_hf_weights(self, weights): - if self.mm_param.is_load_finished(): - return for sub_child_index, param_name in enumerate(self.weight_names): self._load_weight(param_name=param_name, weights=weights, sub_child_index=sub_child_index) @@ -196,51 +130,8 @@ def load_hf_weights(self, weights): for sub_child_index, param_name in enumerate(self.weight_zero_point_names): self._load_weight_zero_point(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - with self.lock: - # 如果需要fused的请求,全部ok了以后进行merge操作。, all([]) 竟然返回是True, 需要len(self.sub_child_mm_params) > 0 的额外判断。 - if len(self.sub_child_mm_params) > 0 and all(e.ready_for_fused_merge() for e in self.sub_child_mm_params): - self._fuse_weights() - self.sub_child_mm_params.clear() - - # 在线量化操作 - if ( - self.quant_method is not None - and self.mm_param.weight is not None - and self.quant_method.weight_need_quanted(self.mm_param.weight) - and self.load_finished is False - ): - logger.info(f"online quant weight names: {self.weight_names}") - quantized_weight, weight_scale, weight_zero_point = self.quant_method.quantize( - self.mm_param.weight.cuda(get_current_device_id()) - ) - self.mm_param.weight = quantized_weight - self.mm_param.weight_scale = weight_scale - self.mm_param.weight_zero_point = weight_zero_point - - # repack 操作 - if ( - self.quant_method is not None - and self.mm_param.is_ready() - and 
self.quant_method.params_need_repack() - and self.load_finished is False - ): - ( - self.mm_param.weight, - self.mm_param.weight_scale, - self.mm_param.weight_zero_point, - ) = self.quant_method.params_repack( - weight=self.mm_param.weight, - weight_scale=self.mm_param.weight_scale, - weight_zero_point=self.mm_param.weight_zero_point, - dtype_type=self.data_type_, - ) - - if self.mm_param.is_ready() and self.load_finished is False: - self._to_gpu_device() - self.load_finished = True - def verify_load(self) -> bool: - return self.mm_param.is_ready() + return True # 执行顺序 def _load_weight( @@ -248,7 +139,11 @@ def _load_weight( ) -> None: if param_name in weights: weight = self.param_slicer._slice_weight(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight = weight + start_idx = self.cusum_out_dims[sub_child_index] + if self.quant_method.weight_need_quanted(weight): + self.quant_method.quantize(weight, self.mm_param, offset=start_idx) + else: + self.quant_method.load_weight(weight, self.mm_param, start_idx) return def _load_bias( @@ -256,7 +151,9 @@ def _load_bias( ) -> None: if param_name in weights: bias = self.param_slicer._slice_bias(weights[param_name]) - self.sub_child_mm_params[sub_child_index].bias = bias + start_idx = self.cusum_out_dims[sub_child_index] + end_idx = start_idx + bias.shape[0] + self.mm_param.bias[start_idx:end_idx].copy_(bias) return def _load_weight_scale( @@ -264,7 +161,8 @@ def _load_weight_scale( ) -> None: if param_name in weights: weight_scale = self.param_slicer._slice_weight_scale(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight_scale = weight_scale + start_idx = self.cusum_out_dims[sub_child_index] + self.quant_method.load_weight_scale(weight_scale, self.mm_param, start_idx) return def _load_weight_zero_point( @@ -272,88 +170,8 @@ def _load_weight_zero_point( ) -> None: if param_name in weights: weight_zero_point = self.param_slicer._slice_weight_zero_point(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight_zero_point = weight_zero_point - return - - # weight merge - def _fuse_weights(self) -> None: - need_merge = len(self.sub_child_mm_params) > 1 - if self.mm_param.weight is None and all(p.weight is not None for p in self.sub_child_mm_params): - if need_merge: - weight = torch.cat([p.weight for p in self.sub_child_mm_params], dim=self.weight_fused_dim) - else: - weight = self.sub_child_mm_params[0].weight - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight = None - - self.mm_param.weight = weight - - if ( - self.mm_param.has_bias - and self.mm_param.bias is None - and all(p.bias is not None for p in self.sub_child_mm_params) - ): - if need_merge: - bias = torch.cat([p.bias for p in self.sub_child_mm_params], dim=self.bias_fused_dim) - else: - bias = self.sub_child_mm_params[0].bias - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.bias = None - - self.mm_param.bias = bias - - if self.mm_param.weight_scale is None and all(p.weight_scale is not None for p in self.sub_child_mm_params): - if need_merge: - weight_scale = torch.cat( - [p.weight_scale for p in self.sub_child_mm_params], dim=self.weight_scale_and_zero_point_fused_dim - ) - else: - weight_scale = self.sub_child_mm_params[0].weight_scale - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight_scale = None - - self.mm_param.weight_scale = weight_scale - - if self.mm_param.weight_zero_point is None and all( - p.weight_zero_point is not None for p in self.sub_child_mm_params - ): - if need_merge: - 
weight_zero_point = torch.cat( - [p.weight_zero_point for p in self.sub_child_mm_params], - dim=self.weight_scale_and_zero_point_fused_dim, - ) - else: - weight_zero_point = self.sub_child_mm_params[0].weight_zero_point - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight_zero_point = None - - self.mm_param.weight_zero_point = weight_zero_point - return - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - if self.quant_method is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()) - else: - # 让 k dim 更连续,大多数split k 算法的算子可能能更快 - self.mm_param.weight = ( - self.mm_param.weight.to(self.data_type_).cuda(get_current_device_id()).transpose(0, 1) - ) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) + start_idx = self.cusum_out_dims[sub_child_index] + self.quant_method.load_weight_zero_point(weight_zero_point, self.mm_param, start_idx) return @@ -376,90 +194,6 @@ def bmm( out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) else: out = torch.empty(shape, dtype=dtype, device=device) - if self.mm_param.bias is None: + if self.bias is None: return torch.bmm(input_tensor, fpweight, out=out) - return torch.addbmm(self.mm_param.bias, input_tensor, fpweight, out=out) - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - if self.quant_method is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()) - else: - # bmm 不需要 transpose 操作 - self.mm_param.weight = self.mm_param.weight.to(self.data_type_).cuda(get_current_device_id()) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return - - -class DeepGemmFP8W8A8B128MMWeight(MMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - bias_names=bias_names, - data_type=data_type, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()).transpose(0, 1) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()).transpose(0, 1) - - assert self.mm_param.has_weight_zero_point is False - - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return - - -class AWQMMWeightTpl(MMWeightTpl): - def 
__init__( - self, - weight_names: Union[str, List[str]], - bias_names: Optional[Union[str, List[str]]] = None, - data_type: torch.dtype = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - bias_names=bias_names, - data_type=data_type, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - self.weight_fused_dim = 1 - self.bias_fused_dim = 0 - self.weight_scale_and_zero_point_fused_dim = 1 - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.to(self.data_type_).cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return + return torch.addbmm(self.bias, input_tensor, fpweight, out=out) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index 0eebdc74d2..e53d643cec 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -1,20 +1,20 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( MMWeightTpl, - DeepGemmFP8W8A8B128MMWeight, - AWQMMWeightTpl, BMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional, Union -from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, AwqQuantizedRowSliceMixin +from .mm_slicer import get_row_slice_mixin -class StandardROWMMWeight(MMWeightTpl): +class ROWMMWeight(MMWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], data_type: torch.dtype, bias_names: Optional[Union[str, List[str]]] = None, @@ -23,6 +23,8 @@ def __init__( tp_world_size: int = None, ) -> None: super().__init__( + in_dim=in_dim, + out_dims=out_dims, weight_names=weight_names, bias_names=bias_names, data_type=data_type, @@ -30,32 +32,12 @@ def __init__( tp_rank=tp_rank, tp_world_size=tp_world_size, ) - self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class DeepGemmFP8W8A8B128ROWMMWeight(DeepGemmFP8W8A8B128MMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + self.param_slicer = get_row_slice_mixin( + self.quant_method.method_name, tp_rank=tp_rank, tp_world_size=tp_world_size ) - self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - return -class 
UnquantizedROWBMMWeight(BMMWeightTpl): +class ROWBMMWeight(BMMWeightTpl): def __init__( self, weight_names: Union[str, List[str]], @@ -73,53 +55,5 @@ def __init__( tp_rank=tp_rank, tp_world_size=tp_world_size, ) - self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQROWMMWeight(AWQMMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQMARLINROWMMWeight(AWQROWMMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - -ROWMM_WEIGHT_CLS_MAP = { - "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128ROWMMWeight, - "awq": AWQROWMMWeight, - "awq_marlin": AWQMARLINROWMMWeight, -} + # bmm 不支持量化运算操作 + self.param_slicer = get_row_slice_mixin(quant_method_name="none", tp_rank=tp_rank, tp_world_size=tp_world_size) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 7ec672ab88..b92ec24cb9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -4,49 +4,59 @@ class NormWeight(BaseWeightTpl): - def __init__(self, weight_name, data_type, bias_name=None): + def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None): super().__init__() + self.norm_dim = norm_dim self.weight_name = weight_name self.bias_name = bias_name self.data_type_ = data_type self.weight = None self.bias = None + self.is_weight_ready = False + self.is_bias_ready = False + self._create_weight() + + def _create_weight(self): + device = f"cuda:{get_current_device_id()}" + self.weight = torch.empty(self.norm_dim, dtype=self.data_type_, device=device) + self.bias = ( + torch.empty(self.norm_dim, dtype=self.data_type_, device=device) if self.bias_name is not None else None + ) def load_hf_weights(self, weights): if self.weight_name in weights: - self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) + self.weight.copy_(weights[self.weight_name]) + self.is_weight_ready = True if self.bias_name in weights: - self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) + self.bias.copy_(weights[self.bias_name]) + self.is_bias_ready = True def verify_load(self): - load_ok = True - # Verify weight. The weight must be not None. - load_ok = load_ok and self.weight is not None - # Verify bias. If bias_name is set, it must be not None. 
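The new NormWeight above allocates its CUDA buffer once in `_create_weight` and then only `copy_`s checkpoint shards into it, tracking readiness with explicit flags instead of None checks. A minimal standalone sketch of that preallocate-then-copy pattern (hypothetical class for illustration; only `norm_dim`, the in-place `copy_` load, and the readiness flag mirror the diff, everything else is assumed):

import torch

class PreallocatedNormWeight:
    """Allocate the target buffer up front; checkpoint shards are copied in-place."""

    def __init__(self, norm_dim: int, weight_name: str, data_type: torch.dtype):
        self.weight_name = weight_name
        # One allocation at construction time; loading never reallocates.
        self.weight = torch.empty(norm_dim, dtype=data_type, device="cuda")
        self.is_weight_ready = False

    def load_hf_weights(self, weights: dict) -> None:
        if self.weight_name in weights:
            # copy_ casts to the preallocated dtype and uploads in one step.
            self.weight.copy_(weights[self.weight_name])
            self.is_weight_ready = True

    def verify_load(self) -> bool:
        # Readiness is an explicit flag, not a None check on the tensor.
        return self.is_weight_ready

Because shard files can arrive in any order, the fixed buffer lets each `load_hf_weights` call write its slice without coordination, and `verify_load` reduces to checking flags.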
-        if self.bias_name is not None:
-            load_ok = load_ok and self.bias is not None
-        return load_ok
+        return self.is_weight_ready and (self.bias_name is None or self.is_bias_ready)
 
 
 class GEMMANormWeight(NormWeight):
-    def __init__(self, weight_name, data_type, bias_name=None):
-        super().__init__(weight_name, data_type, bias_name)
+    def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None):
+        super().__init__(norm_dim, weight_name, data_type, bias_name)
 
     def load_hf_weights(self, weights):
+        # TODO: does adding 1 directly here cause precision issues? The computation expects (1.0 + weight.float())?
         if self.weight_name in weights:
-            self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(get_current_device_id())
+            self.weight.copy_((weights[self.weight_name] + 1).to(self.data_type_))
+            self.is_weight_ready = True
 
 
 class TpNormWeight(NormWeight):
-    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
-        super().__init__(weight_name, data_type, bias_name)
-        self.split_n_embed = split_n_embed
+    def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None):
+        super().__init__(norm_dim, weight_name, data_type, bias_name)
 
     def load_hf_weights(self, weights):
-        start = self.split_n_embed * self.tp_rank_
-        end = self.split_n_embed * (self.tp_rank_ + 1)
+        start = self.norm_dim * self.tp_rank_
+        end = self.norm_dim * (self.tp_rank_ + 1)
         if self.weight_name in weights:
-            self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id())
+            self.weight.copy_(weights[self.weight_name][start:end].to(self.data_type_))
+            self.is_weight_ready = True
         if self.bias_name in weights:
-            self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id())
+            self.bias.copy_(weights[self.bias_name][start:end].to(self.data_type_))
+            self.is_bias_ready = True
diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py
index 97bc762370..1889ceb391 100644
--- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py
+++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py
@@ -4,6 +4,7 @@
 from .base_layer_weight import BaseLayerWeight
 from .meta_weights import BaseWeight, MMWeightTpl
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.quantization import Quantcfg
 
 logger = init_logger(__name__)
 
@@ -15,7 +16,7 @@ def __init__(self, layer_num, data_type, network_config, mode, quant_cfg):
         self.data_type_ = data_type
         self.network_config_ = network_config
         self.mode = mode
-        self.quant_cfg = quant_cfg
+        self.quant_cfg: Quantcfg = quant_cfg
         self._parse_config()
         self._init_weight_names()
         self._init_weight()
@@ -41,3 +42,6 @@ def load_hf_weights(self, weights):
             attr.load_hf_weights(weights)
         elif isinstance(attr, BaseWeight):
             attr.load_hf_weights(weights)
+
+    def get_quant_method(self, name):
+        return self.quant_cfg.get_quant_method(self.layer_num_, name)
diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py
index 26f59258cd..ecf2e6d42f 100644
--- a/lightllm/common/quantization/__init__.py
+++ b/lightllm/common/quantization/__init__.py
@@ -6,6 +6,7 @@
 from .triton_quant.triton_quant import *
 from .deepgemm_quant import *
 from .awq_quant import *
+from .no_quant import *
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
@@ -78,4 +79,6 @@ def get_quant_type(self, layer_num, name):
 
     def get_quant_method(self, layer_num, name):
        quant_type = 
self.get_quant_type(layer_num, name) - return QUANTMETHODS.get(quant_type) + quant_method = QUANTMETHODS.get(quant_type) + quant_method.hf_quantization_config = self.hf_quantization_config + return quant_method diff --git a/lightllm/common/quantization/awq_quant.py b/lightllm/common/quantization/awq_quant.py index 8c04cdcea9..d523cce757 100644 --- a/lightllm/common/quantization/awq_quant.py +++ b/lightllm/common/quantization/awq_quant.py @@ -9,8 +9,7 @@ from typing import TYPE_CHECKING, Optional, Tuple from lightllm.utils.dist_utils import get_current_device_id -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from .quantize_method import WeightPack if HAS_VLLM: awq_dequantize = vllm_ops.awq_dequantize @@ -39,16 +38,17 @@ def __init__(self): self.cache_manager = g_cache_manager - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): raise NotImplementedError("AWQ online quantization is not supported yet.") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("AWQ online quantization is not supported yet.") @@ -72,21 +72,21 @@ def __init__(self): def method_name(self): return "awq" - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): raise NotImplementedError("AWQ online quantization is not supported yet.") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale qzeros = weight_pack.weight_zero_point - bias = weight_pack.bias NEED_DEQUANT_WEIGHT = input_tensor.shape[:-1].numel() >= 256 if NEED_DEQUANT_WEIGHT: @@ -99,6 +99,33 @@ def apply( out.add_(bias) return out + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + group_size = self.hf_quantization_config["group_size"] + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (in_dim, out_dim // self.pack_factor), dtype=torch.int32).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + weight_zero_point = torch.empty( + expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + start_idx = start_idx // self.pack_factor + weight_pack.weight[:, start_idx : start_idx + weight.shape[1]].copy_(weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[:, start_idx : start_idx + weight_scale.shape[1]].copy_(weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + start_idx = start_idx // 
self.pack_factor + end_idx = start_idx + weight_zero_point.shape[1] + weight_pack.weight_zero_point[:, start_idx:end_idx].copy_(weight_zero_point) + return + @QUANTMETHODS.register("awq_marlin") class AWQMARLINW4A16QuantizationMethod(AWQBaseQuantizationMethod): @@ -115,20 +142,15 @@ def __init__(self): self.vllm_quant_type = TYPE_MAP[self.nbits] self.has_weight_scale = True self.has_weight_zero_point = True + self.tile_size = 16 @property def method_name(self): return "awq_marlin" - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, offset: int = 0) -> WeightPack: raise NotImplementedError("AWQ online quantization is not supported yet.") - def params_need_repack(self) -> bool: - """ - 用于说明是否需要对量化后的权重进行repack操作,目前只有awq支持 - """ - return True - def params_repack( self, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zero_point: torch.Tensor, dtype_type: torch.dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -144,47 +166,18 @@ def params_repack( ) return weight, weight_scale, weight_zero_point - def _process_weight_after_loading(self, weight: torch.Tensor) -> torch.Tensor: - assert self.hf_quantization_config is not None, "hf_quantization_config is not set" - self.k = weight.shape[0] - self.n = weight.shape[1] * self.pack_factor - return vllm_ops.awq_marlin_repack( - weight, - size_k=weight.shape[0], - size_n=weight.shape[1] * self.pack_factor, - num_bits=self.hf_quantization_config["bits"], - ) - - def _process_weight_scale_after_loading(self, weight_scale: torch.Tensor) -> torch.Tensor: - assert self.hf_quantization_config is not None, "hf_quantization_config is not set" - group_size = self.hf_quantization_config["group_size"] - return marlin_permute_scales( - weight_scale, - size_k=weight_scale.shape[0] * group_size, - size_n=weight_scale.shape[1], - group_size=self.hf_quantization_config["group_size"], - ) - - def _process_weight_zero_point_after_loading(self, weight_zero_point: torch.Tensor) -> torch.Tensor: - return awq_to_marlin_zero_points( - weight_zero_point, - size_k=weight_zero_point.shape[0], - size_n=weight_zero_point.shape[1] * self.pack_factor, - num_bits=self.hf_quantization_config["bits"], - ) - def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale qzeros = weight_pack.weight_zero_point - bias = weight_pack.bias reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1]) use_atomic_add = should_use_atomic_add_reduce( @@ -219,6 +212,62 @@ def apply( out.add_(bias) return out + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + self.n = out_dim + self.k = in_dim + group_size = self.hf_quantization_config["group_size"] + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty( + expert_prefix + (in_dim // self.tile_size, out_dim * self.tile_size // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + weight_zero_point = torch.empty( + expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale, 
weight_zero_point=weight_zero_point) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + device_id = get_current_device_id() + repack_weight = vllm_ops.awq_marlin_repack( + weight.cuda(device_id), + size_k=weight.shape[0], + size_n=weight.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + start_idx = start_idx // self.pack_factor * self.tile_size + weight_pack.weight[:, start_idx : start_idx + repack_weight.shape[1]].copy_(repack_weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + group_size = self.hf_quantization_config["group_size"] + device_id = get_current_device_id() + repack_weight_scale = marlin_permute_scales( + weight_scale.cuda(device_id), + size_k=weight_scale.shape[0] * group_size, + size_n=weight_scale.shape[1], + group_size=self.hf_quantization_config["group_size"], + ) + weight_pack.weight_scale[:, start_idx : start_idx + repack_weight_scale.shape[1]].copy_(repack_weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + device_id = get_current_device_id() + repack_weight_zero_point = awq_to_marlin_zero_points( + weight_zero_point.cuda(device_id), + size_k=weight_zero_point.shape[0], + size_n=weight_zero_point.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + start_idx = start_idx // self.pack_factor + weight_pack.weight_zero_point[:, start_idx : start_idx + repack_weight_zero_point.shape[1]].copy_( + repack_weight_zero_point + ) + return + # adapted from # https://github.com/vllm-project/vllm/blob/aef368aa08572505b820db01da82e2fbb3d43a72/vllm/model_executor/layers/quantization/awq_marlin.py#L211-L212 diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py index f566307808..86dd9b5729 100644 --- a/lightllm/common/quantization/deepgemm_quant.py +++ b/lightllm/common/quantization/deepgemm_quant.py @@ -1,5 +1,6 @@ import os import torch +from torch.types import Device from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS import torch.nn.functional as F @@ -9,8 +10,8 @@ ) from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from .quantize_method import WeightPack + try: HAS_DEEPGEMM = True import deep_gemm @@ -26,17 +27,17 @@ def __init__(self): self.cache_manager = g_cache_manager assert HAS_DEEPGEMM, "deepgemm is not installed, you can't use quant api of it" - def quantize(self, weight: torch.Tensor): - """ """ - pass + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -60,25 +61,30 @@ def __init__(self): def method_name(self): return "deepgemm-fp8w8a8-b128" - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: 
WeightPack, offset: int = 0):
         from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant
 
-        return weight_quant(weight, self.block_size)
+        device = output.weight.device
+        weight, scale = weight_quant(weight.cuda(device), self.block_size)
+        output.weight[offset : offset + weight.shape[0], :].copy_(weight)
+        output.weight_scale[offset // self.block_size : (offset + weight.shape[0]) // self.block_size].copy_(scale)
+        return
 
     def apply(
         self,
         input_tensor: torch.Tensor,
-        weight_pack: "MMWeightPack",
+        weight_pack: "WeightPack",
         out: Optional[torch.Tensor] = None,
         workspace: Optional[torch.Tensor] = None,
         use_custom_tensor_mananger: bool = True,
+        bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         qweight = weight_pack.weight
         weight_scale = weight_pack.weight_scale
         input_scale = None
         alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty
         m, k = input_tensor.shape
-        n = qweight.shape[1]
+        n = qweight.shape[0]
         if input_scale is None:
             qinput_tensor, input_scale = per_token_group_quant_fp8(
                 input_tensor,
@@ -91,9 +97,35 @@ def apply(
         if out is None:
             out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-        _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight.t(), weight_scale.t()), out)
+        _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight, weight_scale), out)
         return out
 
+    def create_weight(
+        self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1
+    ) -> WeightPack:
+        expert_prefix = (num_experts,) if num_experts > 1 else ()
+        weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id)
+        weight_scale = torch.empty(
+            expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32
+        ).cuda(device_id)
+        return WeightPack(weight=weight, weight_scale=weight_scale)
+
+    def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None:
+        weight_pack.weight[start_idx : start_idx + weight.shape[0]].copy_(weight)
+        return
+
+    def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None:
+        weight_pack.weight_scale[
+            start_idx // self.block_size : start_idx // self.block_size + weight_scale.shape[0]
+        ].copy_(weight_scale)
+        return
+
+    def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None:
+        weight_pack.weight_zero_point[
+            start_idx // self.block_size : start_idx // self.block_size + weight_zero_point.shape[0]
+        ].copy_(weight_zero_point)
+        return
+
 
 def _deepgemm_fp8_nt(a_tuple, b_tuple, out):
     if HAS_DEEPGEMM:
diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py
new file mode 100644
index 0000000000..f342607c10
--- /dev/null
+++ b/lightllm/common/quantization/no_quant.py
@@ -0,0 +1,52 @@
+from .quantize_method import QuantizationMethod, WeightPack
+from .registry import QUANTMETHODS
+import torch
+from typing import Optional
+
+
+@QUANTMETHODS.register("none")
+class NoQuantization(QuantizationMethod):
+    def apply(
+        self,
+        input_tensor: torch.Tensor,
+        weight_pack: WeightPack,
+        out: Optional[torch.Tensor] = None,
+        workspace: Optional[torch.Tensor] = None,
+        use_custom_tensor_mananger: bool = True,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+
+        weight = weight_pack.weight.t()
+        if out is None:
+            shape = 
(input_tensor.shape[0], weight.shape[1]) + dtype = input_tensor.dtype + device = input_tensor.device + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) + else: + out = torch.empty(shape, dtype=dtype, device=device) + if bias is None: + return torch.mm(input_tensor, weight, out=out) + return torch.addmm(bias, input_tensor, weight, out=out) + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=dtype).cuda(device_id) + return WeightPack(weight=weight, weight_scale=None, weight_zero_point=None) + + def weight_need_quanted(self, weight: torch.Tensor) -> bool: + return False + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + return + + @property + def method_name(self): + return "none" + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0) -> None: + weight_pack.weight[start_idx : start_idx + weight.shape[0], :].copy_(weight) + return diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 9b629bcaf1..77e59465ee 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -1,38 +1,58 @@ import torch from abc import ABC, abstractmethod +from dataclasses import dataclass from lightllm.utils.dist_utils import get_current_device_id -from typing import TYPE_CHECKING, Optional, Tuple +from typing import Optional, Tuple -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack + +@dataclass +class WeightPack: + weight: Optional[torch.Tensor] = None + weight_scale: Optional[torch.Tensor] = None + weight_zero_point: Optional[torch.Tensor] = None + + def get_expert(self, expert_idx: int): + assert self.weight.ndim == 3, f"weight must be a 3D tensor, but got {self.weight.ndim}" + weight = self.weight[expert_idx] + weight_scale = self.weight_scale[expert_idx] if self.weight_scale is not None else None + weight_zero_point = self.weight_zero_point[expert_idx] if self.weight_zero_point is not None else None + return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) class QuantizationMethod(ABC): def __init__(self): super().__init__() self.device_id_ = get_current_device_id() - self.weight_suffix = None + self.weight_suffix = "weight" self.weight_scale_suffix = None self.weight_zero_point_suffix = None self.act_scale_suffix = None self.has_weight_scale: bool = None self.has_weight_zero_point: bool = None + self.group_size: int = -1 # -1表示不分组即per-channel量化,其他表示分组大小 + self.pack_factor: int = 1 + # 一些量化模式需要用到的额外量化参数,如awq量化 self.hf_quantization_config = None @abstractmethod - def quantize(self, weights: torch.Tensor): + def quantize( + self, + weight: torch.Tensor, + output: WeightPack, + offset: int = 0, + ) -> None: pass @abstractmethod def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", - bias: Optional[torch.Tensor] = None, + weight_pack: "WeightPack", out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass @@ -41,20 +61,26 @@ def apply( def method_name(self): pass + def create_weight( + self, out_dim: 
int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + pass + def weight_need_quanted(self, weight: torch.Tensor) -> bool: # 判断一个 weight 是否需要进行量化操作。 return weight.dtype in [torch.bfloat16, torch.float16, torch.float32, torch.float64] - def params_need_repack(self) -> bool: - """ - 用于说明是否需要对量化后的权重进行repack操作,目前只有awq支持 - """ - return False - - def params_repack( - self, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zero_point: torch.Tensor, dtype_type: torch.dtype - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - 一些量化方法在将参数完成量化后,为了加速性能,还需要将参数进行重拍,使算子性能达到最优,如awq方法。 - """ - return weight, weight_scale, weight_zero_point + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + raise NotImplementedError( + f"quantization method {self.method_name} is not supported to load offline quantized weight" + ) + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + raise NotImplementedError( + f"quantization method {self.method_name} is not supported to load offline quantized weight scale" + ) + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + raise NotImplementedError( + f"quantization method {self.method_name} is not supported to load offline quantized weight zero point" + ) diff --git a/lightllm/common/quantization/registry.py b/lightllm/common/quantization/registry.py index 674a22b60f..e9b4073987 100644 --- a/lightllm/common/quantization/registry.py +++ b/lightllm/common/quantization/registry.py @@ -1,5 +1,4 @@ from .quantize_method import QuantizationMethod -from typing import Type class QuantMethodFactory: @@ -17,9 +16,7 @@ def decorator(cls): return decorator - def get(self, key, *args, **kwargs) -> Type[QuantizationMethod]: - if key == "none": - return None + def get(self, key, *args, **kwargs) -> "QuantizationMethod": quant_method_class = self._quant_methods.get(key) if not quant_method_class: raise ValueError(f"QuantMethod '{key}' not supported.") diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py index ba4115b1d9..d1db65b35a 100644 --- a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -5,8 +5,7 @@ import torch.nn.functional as F from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from .quantize_method import WeightPack try: HAS_TORCH_AO = True @@ -34,17 +33,17 @@ def __init__(self): assert TORCH_VERSION_AT_LEAST_2_4, "torchao requires torch >=2.4" self.quant_func = None - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, offset: int = 0) -> WeightPack: """ """ dummy_linear = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) dummy_linear.weight = torch.nn.Parameter(weight.cuda(self.device_id_)) quantize_(dummy_linear, self.quant_func) - return dummy_linear.weight, None, None + return WeightPack(weight=dummy_linear.weight, weight_scale=None, weight_zero_point=None) def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py 
b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py index 11c1897d76..3881cfe4b8 100644 --- a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py +++ b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py @@ -55,4 +55,4 @@ def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, return y_quant, s_scales else: y_quant, s_scales = mm_weight_quant(x, block_size) - return y_quant.t(), s_scales.t() + return y_quant, s_scales diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py new file mode 100644 index 0000000000..7c76e82c9e --- /dev/null +++ b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py @@ -0,0 +1,471 @@ +import torch +import triton +import triton.language as tl + +from lightllm.common.kernel_config import KernelConfigs +from frozendict import frozendict +from functools import lru_cache +from typing import Any, Dict, List, Optional, Tuple +from triton import Config +from lightllm.common.triton_utils.autotuner import autotune +from lightllm.utils.device_utils import triton_support_tensor_descriptor, is_5090_gpu + + +class Fp8ScaledMMKernelConfig(KernelConfigs): + kernel_name: str = "fp8_scaled_mm_per_token" + + @classmethod + @lru_cache(maxsize=200) + def try_to_get_best_config( + cls, + M: int, + N: int, + K: int, + out_dtype: str, + ) -> dict: + key_params = { + "N": N, + "K": K, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + finded_config = cls.get_the_config(key_params) + + if finded_config: + # find by M + config: dict = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))] + return config + else: + config = { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_K": 64, + "GROUP_M": 8, + "num_warps": 4, + "num_stages": 3, + } + return config + + @classmethod + def save_config(cls, N: int, K: int, out_dtype: str, config_json: Dict[int, Dict[int, Dict]]): + + key_params = { + "N": N, + "K": K, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + return cls.store_config(key_params, config_json) + + +@triton.jit +def grouped_launch(pid, m_block_num, n_block_num, group_m: tl.constexpr): + + num_pid_in_group = group_m * n_block_num + group_id = pid // num_pid_in_group + first_pid_m = group_id * group_m + group_size_m = tl.minimum(m_block_num - first_pid_m, group_m) + in_group_index = pid % num_pid_in_group + + # Swizzle pattern: zigzag traversal + back_mark = (in_group_index // group_size_m) % 2 + back_mark1 = -1 * (2 * back_mark - 1) + pid_m = first_pid_m + back_mark * (group_size_m - 1) + back_mark1 * (in_group_index % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + return pid_m, pid_n + + +@triton.jit +def _scaled_mm_per_token( + A, + A_desc: "tl.core.tensor_descriptor", + B, + B_desc: "tl.core.tensor_descriptor", + out, + out_desc: "tl.core.tensor_descriptor", + Ascale, + Bscale, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + USE_TMA: tl.constexpr, + B_IS_TRANS: tl.constexpr, + NEED_N_MASK: tl.constexpr, + NEED_K_MASK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) + m_block_num = tl.cdiv(M, BLOCK_M) + n_block_num = tl.cdiv(N, BLOCK_N) + pid_m, pid_n = grouped_launch(pid, m_block_num, n_block_num, GROUP_M) 
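+    # Tile-swizzle note: grouped_launch walks the (pid_m, pid_n) grid in groups of
+    # GROUP_M block-rows and zigzags within each group, so consecutive program ids
+    # land on neighbouring tiles and reuse B tiles that are still hot in L2.
+    # For example, with m_block_num=4, n_block_num=2, GROUP_M=2, the flat pids
+    # 0..7 map to (0,0) (1,0) (1,1) (0,1) (2,0) (3,0) (3,1) (2,1).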
+ + start_m = pid_m * BLOCK_M + start_n = pid_n * BLOCK_N + + offs_am = start_m + tl.arange(0, BLOCK_M) + offs_bn = start_n + tl.arange(0, BLOCK_N) + + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N) + + offs_k = tl.arange(0, BLOCK_K) + + if not USE_TMA: + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + Ascale_ptrs = Ascale + offs_am + Bscale_ptrs = Bscale + offs_bn + a_s = tl.load(Ascale_ptrs) + b_s = tl.load(Bscale_ptrs) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_K)): + if USE_TMA: + a = A_desc.load([start_m, k * BLOCK_K]) + if not B_IS_TRANS: + b = B_desc.load([k * BLOCK_K, start_n]) + else: + b = B_desc.load([start_n, k * BLOCK_K]).T + elif NEED_K_MASK: + k_remaining = K - k * BLOCK_K + a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) + else: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + acc = tl.dot(a, b, acc) + if not USE_TMA: + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + acc = acc * a_s[:, None] * b_s[None, :] + + acc = acc.to(out.dtype.element_ty) + + if not USE_TMA: + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = out + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + if NEED_N_MASK: + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + else: + mask = offs_cm[:, None] < M + tl.store(c_ptrs, acc, mask=mask) + else: + out_desc.store([start_m, start_n], acc) + + +def get_test_configs(): + fp8_gemm_configs = [] + + for BLOCK_M in [8, 16, 32, 64]: + for BLOCK_N in [64, 128, 256]: + for BLOCK_K in [32, 64, 128, 256]: + if BLOCK_K * BLOCK_M * BLOCK_N >= 256 * 256 * 128: + continue + for num_warps in [2, 4, 8]: + for num_stages in [2, 3, 4, 5, 6]: + config = { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "BLOCK_K": BLOCK_K, + "GROUP_M": 8, + "num_stages": num_stages, + "num_warps": num_warps, + } + fp8_gemm_configs.append(config) + + return fp8_gemm_configs + + +def _get_static_key(A, B, out_dtype): + M, K = A.shape + _, N = B.shape + return { + "N": N, + "K": K, + "out_dtype": str(out_dtype), + } + + +@autotune( + kernel_name="fp8_scaled_mm_per_token:v3", + configs_gen_func=get_test_configs, + static_key_func=_get_static_key, + run_key_func=lambda A: A.shape[0], + mutates_args=["out"], +) +def fp8_scaled_mm_per_token( + A: torch.Tensor, + B: torch.Tensor, + Ascale: torch.Tensor, + Bscale: torch.Tensor, + out_dtype: torch.dtype, + out: torch.Tensor, + run_config=None, +) -> torch.Tensor: + """w8a8fp8 per-token quantization mm. + + Args: + A: Matrix A with shape of [M, K]. + B: Matrix B with shape of [K, N]. + Ascale: per-token Quantization scale for A: [M] or [M, 1]. + Bscale: per-channel Quantization scale for B: [N] or [1, N]. + out_dtype: The data type of out. + out: The output matrix with the shape of [M, N]. + Returns: + torch.Tensor: out. 
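+
+    Example (illustrative only; assumes a CUDA device with float8_e4m3fn support,
+    hypothetical shapes M=16, K=512, N=1024, and B passed as a [K, N] transposed
+    view as in the benchmark below):
+        >>> A = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+        >>> B = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn).t()
+        >>> Ascale = torch.ones(16, 1, device="cuda")
+        >>> Bscale = torch.ones(1, 1024, device="cuda")
+        >>> out = torch.empty(16, 1024, device="cuda", dtype=torch.bfloat16)
+        >>> _ = fp8_scaled_mm_per_token(A, B, Ascale, Bscale, torch.bfloat16, out)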
+ """ + assert A.is_contiguous() + B_is_trans = not B.is_contiguous() and B.stride(0) == 1 + + M, K = A.shape + _, N = B.shape + if not run_config: + run_config = Fp8ScaledMMKernelConfig.try_to_get_best_config(M=M, N=N, K=K, out_dtype=out_dtype) + NEED_N_MASK = N % run_config["BLOCK_N"] != 0 + NEED_K_MASK = K % run_config["BLOCK_K"] != 0 + grid = (triton.cdiv(M, run_config["BLOCK_M"]) * triton.cdiv(N, run_config["BLOCK_N"]),) + + BLOCK_M = run_config["BLOCK_M"] + BLOCK_K = run_config["BLOCK_K"] + BLOCK_N = run_config["BLOCK_N"] + + # use tma + support_tma = triton_support_tensor_descriptor() + # 5090 上,小shape开启tma性能不是很好。 + support_tma = support_tma and (not is_5090_gpu()) + if support_tma: + stride = A.stride(-2) + if (stride * A.dtype.itemsize) % 16 != 0: + support_tma = False + _B = B if not B_is_trans else B.transpose(0, 1) + stride = _B.stride(-2) + if (stride * _B.dtype.itemsize) % 16 != 0: + support_tma = False + + if support_tma: + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + from triton.tools.tensor_descriptor import TensorDescriptor + + A_desc = TensorDescriptor(A, A.shape, A.stride(), [BLOCK_M, BLOCK_K]) + if B_is_trans: + _B = B.transpose(0, 1) + assert _B.is_contiguous() + B_desc = TensorDescriptor(_B, _B.shape, _B.stride(), [BLOCK_N, BLOCK_K]) + else: + B_desc = TensorDescriptor(B, B.shape, B.stride(), [BLOCK_K, BLOCK_N]) + out_desc = TensorDescriptor(out, out.shape, out.stride(), [BLOCK_M, BLOCK_N]) + else: + A_desc = None + B_desc = None + out_desc = None + + _scaled_mm_per_token[grid]( + A=A, + A_desc=A_desc, + B=B, + B_desc=B_desc, + out=out, + out_desc=out_desc, + Ascale=Ascale, + Bscale=Bscale, + M=M, + N=N, + K=K, + stride_am=A.stride(0), + stride_ak=A.stride(1), + stride_bk=B.stride(0), + stride_bn=B.stride(1), + stride_cm=out.stride(0), + stride_cn=out.stride(1), + USE_TMA=support_tma, + B_IS_TRANS=B_is_trans, + NEED_N_MASK=NEED_N_MASK, + NEED_K_MASK=NEED_K_MASK, + **run_config, + ) + + return out + + +if __name__ == "__main__": + import time + import os + from lightllm.common.triton_utils.autotuner import Autotuner + import torch.nn.functional as F + + output_dtype = torch.bfloat16 + N, K = 4096, 5120 + + # 测试多个不同的 M 值 + M_list = [1, 2, 4, 8, 16, 32, 48] + + print(f"{'='*80}") + print(f"Starting Autotune for FP8 Scaled MM (N={N}, K={K})") + print(f"M values to test: {M_list}") + print(f"Total configs per M: {len(get_test_configs())}") + print(f"{'='*80}\n") + + # 准备权重矩阵 B(所有测试共享) + B = torch.randn((N, K), dtype=output_dtype).cuda().to(torch.float8_e4m3fn).transpose(0, 1) # [K, N] + Bscale = torch.ones((1, N)).cuda() + + # 准备所有测试数据 + test_data = {} + for M in M_list: + A = torch.randn((M, K), dtype=output_dtype).cuda().to(torch.float8_e4m3fn) + Ascale = torch.randn((M, 1)).cuda() + out = torch.zeros((M, N), dtype=output_dtype).cuda() + test_data[M] = {"A": A, "Ascale": Ascale, "out": out} + + # ============ Phase 0: Correctness Check ============ + print("\n" + "=" * 80) + print("PHASE 0: Verifying Correctness Before Autotune") + print("=" * 80) + + # 选择一个中等大小的 M 进行正确性验证 + M_verify = 16 if 16 in M_list else M_list[len(M_list) // 2] + A_verify = test_data[M_verify]["A"] + Ascale_verify = test_data[M_verify]["Ascale"] + out_verify = test_data[M_verify]["out"] + + print(f"\n[Verification] Testing with M={M_verify}") + + # 计算ground truth + d_A = A_verify.to(output_dtype) * 
Ascale_verify.to(output_dtype) + d_B = B.to(output_dtype) * Bscale.to(output_dtype) + gt_C = d_A.mm(d_B) + + # 运行kernel验证正确性 + fp8_scaled_mm_per_token(A_verify, B, Ascale_verify, Bscale, output_dtype, out_verify) + + # 计算cosine similarity + cosine_sim = F.cosine_similarity(out_verify.flatten().unsqueeze(0), gt_C.flatten().unsqueeze(0), dim=1) + print(f"[Verification] Cosine Similarity: {cosine_sim.item():.6f}") + + # 计算max absolute error + max_abs_error = torch.max(torch.abs(out_verify - gt_C)).item() + mean_abs_error = torch.mean(torch.abs(out_verify - gt_C)).item() + print(f"[Verification] Max Absolute Error: {max_abs_error:.6e}") + print(f"[Verification] Mean Absolute Error: {mean_abs_error:.6e}") + + # 验证阈值 + if cosine_sim.item() < 0.99: + raise RuntimeError(f"Correctness check failed! Cosine similarity {cosine_sim.item():.6f} < 0.99") + + print("[Verification] ✅ Correctness check passed!") + print("=" * 80) + + # ============ Phase 1: Autotune ============ + print("\n" + "=" * 80) + print("PHASE 1: Running Autotune") + print("=" * 80) + Autotuner.start_autotune_warmup() + + for M in M_list: + print(f"\n[M={M}] Running autotune...") + A = test_data[M]["A"] + Ascale = test_data[M]["Ascale"] + out = test_data[M]["out"] + fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + print(f"[M={M}] Autotune completed!") + + Autotuner.end_autotune_warmup() + print("\n" + "=" * 80) + print("All autotune completed! Now starting benchmarks...") + print("=" * 80) + + # ============ Phase 2: Benchmark ============ + results = [] + from sgl_kernel import fp8_scaled_mm + + for M in M_list: + print(f"\n{'='*80}") + print(f"Benchmarking M={M}") + print(f"{'='*80}") + + A = test_data[M]["A"] + Ascale = test_data[M]["Ascale"] + out = test_data[M]["out"] + + # 验证正确性 + print(f"[M={M}] Verifying correctness...") + d_A = A.to(output_dtype) * Ascale.to(output_dtype) + d_B = B.to(output_dtype) * Bscale.to(output_dtype) + gt_C = d_A.mm(d_B) + + # 运行一次确保结果正确 + fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + sgl_res = fp8_scaled_mm(A, B, Ascale, Bscale, output_dtype) + + cosine_sim = F.cosine_similarity(out.flatten().unsqueeze(0), gt_C.flatten().unsqueeze(0), dim=1) + sgl_cosine_sim = F.cosine_similarity(sgl_res.flatten().unsqueeze(0), gt_C.flatten().unsqueeze(0), dim=1) + print(f"[M={M}] Cosine Similarity - Our: {cosine_sim.item():.6f}, SGL: {sgl_cosine_sim.item():.6f}") + + # Benchmark 性能 + print(f"[M={M}] Benchmarking performance...") + + # BF16 baseline + fn_bf16 = lambda: torch.mm(d_A, d_B) + ms_bf16 = triton.testing.do_bench(fn_bf16, warmup=25, rep=100) + + # SGL kernel + fn_sgl = lambda: fp8_scaled_mm(A, B, Ascale, Bscale, output_dtype) + ms_sgl = triton.testing.do_bench(fn_sgl, warmup=25, rep=100) + + # Our kernel + fn_ours = lambda: fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + ms_ours = triton.testing.do_bench_cudagraph(fn_ours, rep=100) + + print(f"[M={M}] BF16: {ms_bf16:.3f} ms") + print(f"[M={M}] SGL FP8: {ms_sgl:.3f} ms ({ms_bf16/ms_sgl:.2f}x)") + print(f"[M={M}] Our FP8: {ms_ours:.3f} ms ({ms_bf16/ms_ours:.2f}x)") + + results.append( + { + "M": M, + "bf16_ms": ms_bf16, + "sgl_ms": ms_sgl, + "ours_ms": ms_ours, + "cosine_sim": cosine_sim.item(), + } + ) + + # 打印汇总结果 + print(f"\n{'='*80}") + print("SUMMARY - Performance Comparison") + print(f"{'='*80}") + print(f"{'M':<8} {'BF16(ms)':<12} {'SGL(ms)':<12} {'Our(ms)':<12} {'vs BF16':<10} {'vs SGL':<10}") + print(f"{'-'*80}") + for r in results: + vs_bf16 = f"{r['bf16_ms']/r['ours_ms']:.2f}x" + vs_sgl = 
f"{r['sgl_ms']/r['ours_ms']:.2f}x" + emoji = "🎉" if r["ours_ms"] < r["sgl_ms"] else "" + print( + f"{r['M']:<8} {r['bf16_ms']:<12.3f} {r['sgl_ms']:<12.3f}" + f"{r['ours_ms']:<12.3f} {vs_bf16:<10} {vs_sgl:<10} {emoji}" + ) + print(f"{'='*80}") diff --git a/lightllm/common/quantization/triton_quant/triton_quant.py b/lightllm/common/quantization/triton_quant/triton_quant.py index 410f925a5e..9f6a7bee27 100644 --- a/lightllm/common/quantization/triton_quant/triton_quant.py +++ b/lightllm/common/quantization/triton_quant/triton_quant.py @@ -7,8 +7,7 @@ from .fp8.fp8act_quant_kernel import per_token_group_quant_fp8 from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from lightllm.common.quantization.quantize_method import WeightPack class TritonBaseQuantizationMethod(QuantizationMethod): @@ -18,16 +17,17 @@ def __init__(self): self.cache_manager = g_cache_manager - def quantize(self, weight: torch.Tensor): - pass + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> WeightPack: + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -44,17 +44,18 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: # TODO block-wise quant kernel - pass + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale @@ -83,3 +84,29 @@ def apply( dtype=input_tensor.dtype, ) return out + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight[start_idx : start_idx + weight.shape[0]].copy_(weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[ + start_idx // self.block_size : start_idx + weight_scale.shape[0] // self.block_size + ].copy_(weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_zero_point[ + start_idx // self.block_size : start_idx + weight_zero_point.shape[0] // self.block_size + ].copy_(weight_zero_point) + return diff --git a/lightllm/common/quantization/w8a8_quant.py 
diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py
index cec6d17789..8c5d1cc1e1 100644
--- a/lightllm/common/quantization/w8a8_quant.py
+++ b/lightllm/common/quantization/w8a8_quant.py
@@ -9,8 +9,8 @@
 from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm
 from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
 
-if TYPE_CHECKING:
-    from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack
+
+from .quantize_method import WeightPack
 
 if HAS_LIGHTLLM_KERNEL:
@@ -30,16 +30,17 @@ def __init__(self):
         self.cache_manager = g_cache_manager
 
-    def quantize(self, weight: torch.Tensor):
-        pass
+    def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None:
+        raise NotImplementedError("Not implemented")
 
     def apply(
         self,
         input_tensor: torch.Tensor,
-        weight_pack: "MMWeightPack",
+        weight_pack: WeightPack,
         out: Optional[torch.Tensor] = None,
         workspace: Optional[torch.Tensor] = None,
         use_custom_tensor_mananger: bool = True,
+        bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError("Not implemented")
 
@@ -47,6 +48,11 @@ def apply(
     def method_name(self):
         return "w8a8-base"
 
+    def create_weight(
+        self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1
+    ) -> WeightPack:
+        raise NotImplementedError("Not implemented")
+
 
 @QUANTMETHODS.register(["vllm-w8a8", "w8a8"])
 class w8a8QuantizationMethod(BaseQuantizationMethod):
@@ -55,27 +61,27 @@ def __init__(self):
         self.has_weight_scale = True
         self.has_weight_zero_point = False
 
-    def quantize(self, weight: torch.Tensor):
-        if isinstance(weight, tuple):
-            return (weight[0].transpose(0, 1).cuda(self.device_id_),) + weight[1:]
-        weight = weight.float()
+    def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None:
+        weight = weight.float().cuda(self.device_id_)
         scale = weight.abs().max(dim=-1)[0] / 127
-        weight = weight.transpose(0, 1) / scale.reshape(1, -1)
+        weight = weight / scale.reshape(-1, 1)
         weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8)
-        return weight.cuda(self.device_id_), scale.cuda(self.device_id_), None
+        output.weight[offset : offset + weight.shape[0]].copy_(weight)
+        output.weight_scale[offset : offset + weight.shape[0]].copy_(scale)
+        return
 
     def apply(
         self,
         input_tensor: torch.Tensor,
-        weight_pack: "MMWeightPack",
+        weight_pack: WeightPack,
         out: Optional[torch.Tensor] = None,
         workspace: Optional[torch.Tensor] = None,
         use_custom_tensor_mananger: bool = True,
+        bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         input_scale = None
-        qweight = weight_pack.weight
+        qweight = weight_pack.weight.t()
         weight_scale = weight_pack.weight_scale
-        bias = weight_pack.bias
         input_scale = None  # dynamic quantization for input tensor
         x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
         m = input_tensor.shape[0]
@@ -94,6 +100,14 @@ def apply(
     def method_name(self):
         return "vllm-w8a8"
 
+    def create_weight(
+        self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1
+    ) -> WeightPack:
+        expert_prefix = (num_experts,) if num_experts > 1 else ()
+        weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).cuda(device_id)
+        weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id)
+        return WeightPack(weight=weight, weight_scale=weight_scale)
+
 
 @QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"])
 class FP8w8a8QuantizationMethod(BaseQuantizationMethod):
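For reference, the per-output-channel symmetric int8 scheme that the rewritten quantize implements round-trips as in this sketch (toy CPU tensors; the real method writes the results into a preallocated WeightPack at a row offset instead of returning them):

import torch

w = torch.randn(8, 16)                # (out_dim, in_dim)
scale = w.abs().max(dim=-1)[0] / 127  # one fp32 scale per output channel
q = torch.round((w / scale.reshape(-1, 1)).clamp(min=-128, max=127)).to(torch.int8)
w_hat = q.to(torch.float32) * scale.reshape(-1, 1)  # dequantize
# per-channel error is bounded by half a quantization step
assert (w - w_hat).abs().max().item() <= 0.5 * scale.max().item() + 1e-6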
@@ -103,19 +117,20 @@ def __init__(self):
         self.has_weight_scale = True
         self.has_weight_zero_point = False
 
-    def quantize(self, weight: torch.Tensor):
+    def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None:
         if self.is_moe:
-            return self.quantize_moe(weight)
+            return self.quantize_moe(weight, output, offset)
         qweight, weight_scale = scaled_fp8_quant(
-            weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
+            weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
         )
-        return qweight.transpose(0, 1), weight_scale, None
+        output.weight[offset : offset + qweight.shape[0], :].copy_(qweight)
+        output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1))
+        return
 
-    def quantize_moe(self, weight: torch.Tensor):
+    def quantize_moe(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> WeightPack:
         num_experts = weight.shape[0]
-        qweights = []
-        weight_scales = []
         qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_)
+        weight_scales = []
         for i in range(num_experts):
             qweight, weight_scale = scaled_fp8_quant(
                 weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
@@ -123,19 +138,19 @@ def quantize_moe(self, weight: torch.Tensor):
             qweights[i] = qweight
             weight_scales.append(weight_scale)
         weight_scale = torch.stack(weight_scales, dim=0).contiguous()
-        return qweights, weight_scale
+        return WeightPack(weight=qweights, weight_scale=weight_scale)
 
     def apply(
         self,
         input_tensor: torch.Tensor,
-        weight_pack: "MMWeightPack",
+        weight_pack: WeightPack,
         out: Optional[torch.Tensor] = None,
         workspace: Optional[torch.Tensor] = None,
         use_custom_tensor_mananger: bool = True,
+        bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        qweight = weight_pack.weight
+        qweight = weight_pack.weight.t()
         weight_scale = weight_pack.weight_scale
-        bias = weight_pack.bias
         x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
         m = input_tensor.shape[0]
         n = qweight.shape[1]
@@ -153,6 +168,14 @@ def apply(
     def method_name(self):
         return "vllm-fp8w8a8"
 
+    def create_weight(
+        self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1
+    ) -> WeightPack:
+        expert_prefix = (num_experts,) if num_experts > 1 else ()
+        weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id)
+        weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id)
+        return WeightPack(weight=weight, weight_scale=weight_scale)
+
 
 @QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"])
 class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod):
@@ -163,21 +186,26 @@ def __init__(self):
         self.has_weight_scale = True
         self.has_weight_zero_point = False
 
-    def quantize(self, weight: torch.Tensor):
+    def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None:
+        from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant
 
-        raise Exception("Not implemented")
+        device = output.weight.device
+        weight, scale = weight_quant(weight.cuda(device), self.block_size)
+        output.weight[offset : offset + weight.shape[0], :].copy_(weight)
+        output.weight_scale[offset // self.block_size : (offset + weight.shape[0]) // self.block_size].copy_(scale)
+        return
 
     def apply(
         self,
         input_tensor: torch.Tensor,
-        weight_pack: "MMWeightPack",
+        weight_pack: WeightPack,
         out: Optional[torch.Tensor] = 
None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weight_pack.weight - weight_scale = weight_pack.weight_scale - bias = weight_pack.bias + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale.t() input_scale = None # dynamic quantization for input tensor m, k = input_tensor.shape n = qweight.shape[1] @@ -206,3 +234,13 @@ def apply( @property def method_name(self): return "vllm-fp8w8a8-b128" + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) diff --git a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py index b740bb62f9..b03ed061d0 100644 --- a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py @@ -6,21 +6,36 @@ class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_vob_size = vob_size // self.tp_world_size_ + + # Pre-allocate memory for weights + self.pre_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.pre_norm_bias_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.final_norm_bias_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + return def load_hf_weights(self, weights): if "word_embeddings_layernorm.weight" in weights: - self.pre_norm_weight_ = self._cuda(weights["word_embeddings_layernorm.weight"]) + self.pre_norm_weight_.copy_(weights["word_embeddings_layernorm.weight"]) if "word_embeddings_layernorm.bias" in weights: - self.pre_norm_bias_ = self._cuda(weights["word_embeddings_layernorm.bias"]) + self.pre_norm_bias_.copy_(weights["word_embeddings_layernorm.bias"]) if "ln_f.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["ln_f.weight"]) + self.final_norm_weight_.copy_(weights["ln_f.weight"]) if "ln_f.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["ln_f.bias"]) + self.final_norm_bias_.copy_(weights["ln_f.bias"]) if "word_embeddings.weight" in weights: vob_size = self.network_config_["vocab_size"] split_vob_size = vob_size // self.tp_world_size_ - self.wte_weight_ = self._cuda( + self.wte_weight_.copy_( weights["word_embeddings.weight"][ split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : ] diff --git a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py index 0b125bea35..a93b30f94b 100644 --- 
a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py @@ -13,19 +13,16 @@ def load_hf_weights(self, weights): vob_size = self.network_config_["padded_vocab_size"] split_vob_size = vob_size // self.tp_world_size_ if "transformer.embedding.word_embeddings.weight" in weights: - self.wte_weight_ = weights["transformer.embedding.word_embeddings.weight"] - self.wte_weight_ = self.wte_weight_[ + wte_weight = weights["transformer.embedding.word_embeddings.weight"] + self.wte_weight_.copy_(wte_weight[ split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - self.wte_weight_ = self._cuda(self.wte_weight_) + ]) if "transformer.output_layer.weight" in weights: - self.lm_head_weight_ = weights["transformer.output_layer.weight"] - self.lm_head_weight_ = self.lm_head_weight_[ + lm_head_weight = weights["transformer.output_layer.weight"] + self.lm_head_weight_.copy_(lm_head_weight[ split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - self.lm_head_weight_ = self._cuda(self.lm_head_weight_) + ]) if "transformer.encoder.final_layernorm.weight" in weights: - self.final_norm_weight_ = weights["transformer.encoder.final_layernorm.weight"] - self.final_norm_weight_ = self._cuda(self.final_norm_weight_) + self.final_norm_weight_.copy_(weights["transformer.encoder.final_layernorm.weight"]) return diff --git a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py index 993acd64d7..ed550fecf8 100644 --- a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py @@ -13,13 +13,13 @@ def load_hf_weights(self, weights): split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) if tie_weight: self.lm_head_weight_ = self.wte_weight_ if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) if "model.lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["model.lm_head.weight"]) + self.lm_head_weight_.copy_(weights["model.lm_head.weight"]) return def verify_load(self): diff --git a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py index fff92abf55..1dce8b51f1 100644 --- a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py @@ -18,14 +18,14 @@ def _init_norm(self, weights): q_split_head = self.network_config_["num_attention_heads"] // self.tp_world_size_ k_split_head = self.network_config_["num_key_value_heads"] // self.tp_world_size_ - self.att_norm_weight_ = NormWeight(self._att_norm_weight_name, self.data_type_) + self.att_norm_weight_ = NormWeight(self.n_embed, self._att_norm_weight_name, self.data_type_) if self.use_qk_norm: self.q_norm_weight_ = TpNormWeight( - f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_, q_split_head + q_split_head, f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_ ) self.k_norm_weight_ = TpNormWeight( - 
f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_, k_split_head + k_split_head, f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_ ) return diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index c899751eb7..390a26aa83 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -299,14 +299,14 @@ def _init_ffn(self): self._load_mlp(f"model.layers.{self.layer_num_}.mlp") def _init_norm(self): - self.att_norm_weight_ = NormWeight(f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_) + self.att_norm_weight_ = NormWeight(self.network_config_["n_embed"], f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_) self.ffn_norm_weight_ = NormWeight( - f"model.layers.{self.layer_num_}.post_attention_layernorm.weight", self.data_type_ + self.network_config_["n_embed"], f"model.layers.{self.layer_num_}.post_attention_layernorm.weight", self.data_type_ ) self.kv_a_layernorm_ = NormWeight( - f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight", self.data_type_ + self.network_config_["n_embed"], f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight", self.data_type_ ) if self.q_lora_rank is not None: self.q_a_layernorm_ = NormWeight( - f"model.layers.{self.layer_num_}.self_attn.q_a_layernorm.weight", self.data_type_ + self.network_config_["n_embed"], f"model.layers.{self.layer_num_}.self_attn.q_a_layernorm.weight", self.data_type_ ) diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py index f5b805647b..66131a858d 100644 --- a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -1,3 +1,4 @@ +import torch import numpy as np from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight @@ -10,15 +11,26 @@ def __init__(self, data_type, network_config, mode): self.lm_head_weight_ = None return + def _create_weight(self): + hidden_size = self.network_config_["hidden_size"] + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + + # Pre-allocate memory for weights + self.eh_proj_weight_ = torch.empty((moe_intermediate_size, hidden_size), dtype=self.data_type_).cuda() + self.enorm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.hnorm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + return + def load_hf_weights(self, weights): if "model.layers.0.eh_proj.weight" in weights: - self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t() + self.eh_proj_weight_.copy_(weights["model.layers.0.eh_proj.weight"].t()) if "model.layers.0.enorm.weight" in weights: - self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"]) + self.enorm_weight_.copy_(weights["model.layers.0.enorm.weight"]) if "model.layers.0.hnorm.weight" in weights: - self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"]) + self.hnorm_weight_.copy_(weights["model.layers.0.hnorm.weight"]) if "model.layers.0.shared_head.norm.weight" in weights: - self.final_norm_weight_ = 
self._cuda(weights["model.layers.0.shared_head.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.layers.0.shared_head.norm.weight"]) return def verify_load(self): diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index 6f55304619..ec05a98edf 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -66,13 +66,13 @@ def _init_qkv(self): def _init_norm(self): super()._init_norm() - self.k_norm_weight_ = NormWeight(self._k_norm_weight_name, self.data_type_, bias_name=None) - self.q_norm_weight_ = NormWeight(self._q_norm_weight_name, self.data_type_, bias_name=None) + self.k_norm_weight_ = NormWeight(self.head_dim, self._k_norm_weight_name, self.data_type_, bias_name=None) + self.q_norm_weight_ = NormWeight(self.head_dim, self._q_norm_weight_name, self.data_type_, bias_name=None) self.pre_feedforward_layernorm_weight_ = NormWeight( - self._pre_feedforward_layernorm_name, self.data_type_, bias_name=None + self.n_embed, self._pre_feedforward_layernorm_name, self.data_type_, bias_name=None ) self.post_feedforward_layernorm_weight_ = NormWeight( - self._post_feedforward_layernorm_name, self.data_type_, bias_name=None + self.n_embed, self._post_feedforward_layernorm_name, self.data_type_, bias_name=None ) def load_hf_weights(self, weights): diff --git a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py index c119960c52..fe388d5323 100644 --- a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py @@ -15,11 +15,11 @@ def load_hf_weights(self, weights): split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) self.lm_head_weight_ = self.wte_weight_ if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) self.final_norm_weight_ = self.final_norm_weight_ + 1 return diff --git a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py index 32248e6dd9..49bc6150c9 100644 --- a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py @@ -29,5 +29,5 @@ def _init_qkv(self): ) def _init_norm(self): - self.att_norm_weight_ = GEMMANormWeight(self._att_norm_weight_name, self.data_type_) - self.ffn_norm_weight_ = GEMMANormWeight(self._ffn_norm_weight_name, self.data_type_) + self.att_norm_weight_ = GEMMANormWeight(self.n_embed, self._att_norm_weight_name, self.data_type_) + self.ffn_norm_weight_ = GEMMANormWeight(self.n_embed, self._ffn_norm_weight_name, self.data_type_) diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index 7e6035dc5c..55fcf4f33b 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -2,7 +2,9 @@ 
import torch import numpy as np -from lightllm.common.basemodel.layer_weights.meta_weights.gpt_oss_fused_moe_weight_tp import GPTOSSFusedMoeWeightTP +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.gpt_oss_fused_moe_weight_tp import ( + GPTOSSFusedMoeWeightTP, +) from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight, TpNormWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight @@ -71,7 +73,7 @@ def _init_weight(self): super()._init_weight() n_split_head = self.network_config_["num_attention_heads"] // self.tp_world_size_ - self.attn_sinks = TpNormWeight(self._attn_sink_name, torch.bfloat16, n_split_head) + self.attn_sinks = TpNormWeight(n_split_head, self._attn_sink_name, torch.bfloat16) def _init_ffn(self): self._init_moe() diff --git a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py index dd8c64915e..c77269db81 100644 --- a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py @@ -8,16 +8,30 @@ def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_vob_size = split_end - split_start + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + return + def load_hf_weights(self, weights): vob_size = self.network_config_["vocab_size"] split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.tok_embeddings.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.tok_embeddings.weight"][split_start:split_end, :]) if "output.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["output.weight"][split_start:split_end, :]) + self.lm_head_weight_.copy_(weights["output.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) return diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py index 78fb0c5d73..7735a3f303 100644 --- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py @@ -14,10 +14,10 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.tok_embeddings.weight" in weights: - self.wte_weight_ = 
self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.tok_embeddings.weight"][split_start:split_end, :]) if "v_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["v_head.weight"]).transpose(0, 1) + self.lm_head_weight_.copy_(weights["v_head.weight"].transpose(0, 1)) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) return diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 711406e3f1..98cf0d51c1 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -7,6 +7,18 @@ class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + return + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_vob_size = vob_size // self.tp_world_size_ + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() return def load_hf_weights(self, weights): @@ -15,14 +27,14 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) if tie_word_embeddings: self.lm_head_weight_ = self.wte_weight_ if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + self.lm_head_weight_.copy_(weights["lm_head.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) return diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 6247170077..8ca0fe15d8 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -24,11 +24,16 @@ def _init_weight(self): self._init_norm() def _parse_config(self): + self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.tp_world_size_ + self.tp_k_head_num_ = max(self.network_config_["num_key_value_heads"] // self.tp_world_size_, 1) + self.tp_v_head_num_ = self.tp_k_head_num_ + self.tp_o_head_num_ = self.tp_q_head_num_ + head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] + self.head_dim = self.network_config_.get("head_dim", head_dim) + assert (self.tp_k_head_num_ * self.tp_world_size_) % self.network_config_["num_key_value_heads"] == 0 self.n_embed = self.network_config_["hidden_size"] - self.n_head = self.network_config_["num_attention_heads"] self.n_inter = 
self.network_config_["intermediate_size"] - self.n_kv_head = self.network_config_["num_key_value_heads"] - self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head) + self.n_head = self.network_config_["num_attention_heads"] def _init_weight_names(self): self._q_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" @@ -57,55 +62,63 @@ def _init_weight_names(self): self._ffn_norm_bias_name = None def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.tp_q_head_num_ * self.head_dim + k_out_dim = self.tp_k_head_num_ * self.head_dim + v_out_dim = self.tp_v_head_num_ * self.head_dim self.q_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], weight_names=self._q_weight_name, data_type=self.data_type_, bias_names=self._q_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_proj", + quant_method=self.get_quant_method("q_proj"), ) self.kv_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[k_out_dim, v_out_dim], weight_names=[self._k_weight_name, self._v_weight_name], data_type=self.data_type_, bias_names=[self._k_bias_name, self._v_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="kv_proj", + quant_method=self.get_quant_method("kv_proj"), ) def _init_o(self): + in_dim = self.tp_o_head_num_ * self.head_dim + out_dim = self.n_embed self.o_proj = COLMMWeight( + in_dim=in_dim, + out_dims=[out_dim], weight_names=self._o_weight_name, data_type=self.data_type_, bias_names=self._o_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="o_proj", + quant_method=self.get_quant_method("o_proj"), ) def _init_ffn(self): + in_dim = self.n_embed + out_dim = self.n_inter // self.tp_world_size_ self.gate_up_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[out_dim, out_dim], weight_names=[self._gate_weight_name, self._up_weight_name], data_type=self.data_type_, bias_names=[self._gate_bias_name, self._up_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=out_dim, + out_dims=[in_dim], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) def _init_norm(self): self.att_norm_weight_ = NormWeight( - self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + self.n_embed, self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name ) self.ffn_norm_weight_ = NormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + self.n_embed, self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name ) diff --git a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py index 94e1a27e0c..0a9230b5b3 100644 --- a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py @@ -18,11 +18,11 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) if 
"lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) / self.lm_head_scale + self.lm_head_weight_.copy_(weights["lm_head.weight"][split_start:split_end, :] / self.lm_head_scale) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) return diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index f425ad08ba..4967687103 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -62,6 +62,7 @@ def _init_moe(self): layer_num=self.layer_num_, quant_cfg=self.quant_cfg, num_fused_shared_experts=0, + hidden_size=self.network_config_.get("hidden_size"), ) else: raise ValueError(f"Unsupported moe mode: {moe_mode}") diff --git a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py index 95af6ecd34..a07a55e8c9 100644 --- a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py @@ -6,6 +6,18 @@ class QwenPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + return + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_vob_size = vob_size // self.tp_world_size_ + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() return def load_hf_weights(self, weights): @@ -14,17 +26,17 @@ def load_hf_weights(self, weights): split_vob_size = vob_size // self.tp_world_size_ if "transformer.wte.weight" in weights: - self.wte_weight_ = self._cuda( + self.wte_weight_.copy_( weights["transformer.wte.weight"][ split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : ] ) if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda( + self.lm_head_weight_.copy_( weights["lm_head.weight"][split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :] ) if "transformer.ln_f.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["transformer.ln_f.weight"]) + self.final_norm_weight_.copy_(weights["transformer.ln_f.weight"]) return diff --git a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py index 5735b03399..772400a1e2 100644 --- a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py @@ -6,6 +6,21 @@ class Qwen2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + return + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + 
split_end = split_indexes[self.tp_rank_ + 1] + split_vob_size = split_end - split_start + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() return def load_hf_weights(self, weights): @@ -14,14 +29,14 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) if tie_word_embeddings: self.lm_head_weight_ = self.wte_weight_ if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + self.lm_head_weight_.copy_(weights["lm_head.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) return diff --git a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py index 2e2c0d3bb2..d65353108d 100644 --- a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py @@ -14,15 +14,6 @@ def _init_weight_names(self): self._k_bias_name = f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" self._v_bias_name = f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" - def _parse_config(self): - self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.tp_world_size_ - self.tp_k_head_num_ = max(self.network_config_["num_key_value_heads"] // self.tp_world_size_, 1) - self.tp_v_head_num_ = self.tp_k_head_num_ - self.tp_o_head_num_ = self.tp_q_head_num_ - head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] - self.head_dim = self.network_config_.get("head_dim", head_dim) - assert (self.tp_k_head_num_ * self.tp_world_size_) % self.network_config_["num_key_value_heads"] == 0 - def _repeat_weight(self, name, weights): # for tp_world_size_ > num_key_value_heads if name not in weights: diff --git a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py index a56c5d6cbb..9babad35a8 100644 --- a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py @@ -9,29 +9,49 @@ def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_vob_size = split_end - split_start + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, 
hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + + # Reward model specific weights + self.score_up_weight = torch.empty((hidden_size, 1), dtype=self.data_type_).cuda() + self.score_up_bias = torch.empty(1, dtype=self.data_type_).cuda() + self.score_down_weight = torch.empty((hidden_size, 1), dtype=self.data_type_).cuda() + self.score_down_bias = torch.empty(1, dtype=self.data_type_).cuda() + return + def load_hf_weights(self, weights): vob_size = self.network_config_["vocab_size"] split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) if tie_word_embeddings: self.lm_head_weight_ = self.wte_weight_ if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) if "score.0.weight" in weights: - self.score_up_weight = self._cuda(weights["score.0.weight"]).transpose(0, 1) + self.score_up_weight.copy_(weights["score.0.weight"].transpose(0, 1)) if "score.0.bias" in weights: - self.score_up_bias = self._cuda(weights["score.0.bias"]) + self.score_up_bias.copy_(weights["score.0.bias"]) if "score.2.weight" in weights: - self.score_down_weight = self._cuda(weights["score.2.weight"]).transpose(0, 1) + self.score_down_weight.copy_(weights["score.2.weight"].transpose(0, 1)) if "score.2.bias" in weights: - self.score_down_bias = self._cuda(weights["score.2.bias"]) + self.score_down_bias.copy_(weights["score.2.bias"]) return diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index 4c0ef586f0..dcee72a1c0 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -20,5 +20,5 @@ def _init_weight_names(self): def _init_norm(self): super()._init_norm() - self.q_norm_weight_ = NormWeight(weight_name=self._q_norm_name, data_type=self.data_type_) - self.k_norm_weight_ = NormWeight(weight_name=self._k_norm_name, data_type=self.data_type_) + self.q_norm_weight_ = NormWeight(self.head_dim, weight_name=self._q_norm_name, data_type=self.data_type_) + self.k_norm_weight_ = NormWeight(self.head_dim, weight_name=self._k_norm_name, data_type=self.data_type_) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 72721f9d6f..bc4b548192 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -53,10 +53,11 @@ def _init_weight(self): def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", data_type=self.data_type_, - layer_num=self.layer_num_, - name="moe_gate", + quant_method=None, tp_rank=0, tp_world_size=1, ) diff --git 
a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py index 80966c7b49..9002c463d5 100755 --- a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py @@ -8,6 +8,21 @@ def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_vob_size = split_end - split_start + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.final_norm_bias_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + return + def load_hf_weights(self, weights): vob_size = self.network_config_["vocab_size"] split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) @@ -15,14 +30,14 @@ def load_hf_weights(self, weights): split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) if "lm_head.weight" in weights: # print(weights['lm_head.weight'].shape) - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + self.lm_head_weight_.copy_(weights["lm_head.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) if "model.norm.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["model.norm.bias"]) + self.final_norm_bias_.copy_(weights["model.norm.bias"]) return diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index 8d87c11632..b54fc068b7 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -6,6 +6,21 @@ class StarcoderPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_vob_size = vob_size // self.tp_world_size_ + max_position_embeddings = self.network_config_["max_position_embeddings"] + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.wpe_weight_ = torch.empty((max_position_embeddings, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + 
self.final_norm_bias_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + return def load_hf_weights(self, weights): @@ -13,28 +28,22 @@ def load_hf_weights(self, weights): split_vob_size = vob_size // self.tp_world_size_ if "transformer.wte.weight" in weights: # print(weights['transformer.wte.weight'].shape) - self.wte_weight_ = ( + self.wte_weight_.copy_( weights["transformer.wte.weight"][ split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : ] - .contiguous() - .to(self.data_type_) - .cuda() ) if "transformer.wpe.weight" in weights: # print(weights['transformer.wpe.weight'].shape) - self.wpe_weight_ = weights["transformer.wpe.weight"].to(self.data_type_).cuda() + self.wpe_weight_.copy_(weights["transformer.wpe.weight"]) if "lm_head.weight" in weights: - self.lm_head_weight_ = ( + self.lm_head_weight_.copy_( weights["lm_head.weight"][split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :] - .contiguous() - .to(self.data_type_) - .cuda() ) if "transformer.ln_f.weight" in weights: - self.final_norm_weight_ = weights["transformer.ln_f.weight"].contiguous().to(self.data_type_).cuda() + self.final_norm_weight_.copy_(weights["transformer.ln_f.weight"]) if "transformer.ln_f.bias" in weights: - self.final_norm_bias_ = weights["transformer.ln_f.bias"].contiguous().to(self.data_type_).cuda() + self.final_norm_bias_.copy_(weights["transformer.ln_f.bias"]) return def verify_load(self): diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index fd2d47575a..cfe1969c0b 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -1,3 +1,4 @@ +import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight @@ -5,6 +6,22 @@ class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + return + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_vob_size = split_end - split_start + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.final_norm_bias_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() return def load_hf_weights(self, weights): @@ -13,19 +30,19 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) # for starcoder2-3b and 7b which didn't use lm_head.weight (tie_word_embeddings) self.lm_head_weight_ = self.wte_weight_ if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + 
self.lm_head_weight_.copy_(weights["lm_head.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) if "model.norm.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["model.norm.bias"]) + self.final_norm_bias_.copy_(weights["model.norm.bias"]) return diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index 276d4e5d0b..1a6f76fde7 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -13,6 +13,34 @@ def __init__(self, data_type, network_config, mode): self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] self.llm_hidden_size = self.network_config_["llm_hidden_size"] + self._create_weight() + return + + def _create_weight(self): + split_indexes = np.linspace(0, self.embed_dim, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_embed_dim = split_end - split_start + + # Pre-allocate memory for vision model weights + self.class_embedding = torch.empty((1, 1, split_embed_dim), dtype=self.data_type_).cuda() + self.position_embedding = torch.empty((1, 197, split_embed_dim), dtype=self.data_type_).cuda() # 197 = (224//16)^2 + 1 + self.patch_embedding_weight_ = torch.empty((split_embed_dim, 3, self.patch_size, self.patch_size), dtype=self.data_type_).cuda() + self.patch_embedding_bias_ = torch.empty(split_embed_dim, dtype=self.data_type_).cuda() + + # Pre-allocate memory for adapter weights + self.layernorm_weight_ = torch.empty(self.embed_dim, dtype=self.data_type_).cuda() + self.layernorm_bias_ = torch.empty(self.embed_dim, dtype=self.data_type_).cuda() + + split_indexes_llm = np.linspace(0, self.llm_hidden_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start_llm = split_indexes_llm[self.tp_rank_] + split_end_llm = split_indexes_llm[self.tp_rank_ + 1] + split_llm_hidden_size = split_end_llm - split_start_llm + + self.mlp1_1_weight_ = torch.empty((self.llm_hidden_size, split_llm_hidden_size), dtype=self.data_type_).cuda() + self.mlp1_1_bias_ = torch.empty(split_llm_hidden_size, dtype=self.data_type_).cuda() + self.mlp1_3_weight_ = torch.empty((split_llm_hidden_size, self.llm_hidden_size), dtype=self.data_type_).cuda() + self.mlp1_3_bias_ = torch.empty(self.llm_hidden_size, dtype=self.data_type_).cuda() return def _cuda(self, cpu_tensor): @@ -40,40 +68,40 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "vision_model.embeddings.class_embedding" in weights: - self.class_embedding = self._cuda( + self.class_embedding.copy_( weights["vision_model.embeddings.class_embedding"][:, :, split_start:split_end] ) if "vision_model.embeddings.position_embedding" in weights: - self.position_embedding = self._cuda( + self.position_embedding.copy_( weights["vision_model.embeddings.position_embedding"][:, :, split_start:split_end] ) if "vision_model.embeddings.patch_embedding.weight" in weights: - self.patch_embedding_weight_ = self._cuda( + self.patch_embedding_weight_.copy_( weights["vision_model.embeddings.patch_embedding.weight"][split_start:split_end, :, :, :] ) if "vision_model.embeddings.patch_embedding.bias" in weights: - 
self.patch_embedding_bias_ = self._cuda( + self.patch_embedding_bias_.copy_( weights["vision_model.embeddings.patch_embedding.bias"][split_start:split_end] ) if "mlp1.0.weight" in weights: - self.layernorm_weight_ = self._cuda(weights["mlp1.0.weight"]) + self.layernorm_weight_.copy_(weights["mlp1.0.weight"]) if "mlp1.0.bias" in weights: - self.layernorm_bias_ = self._cuda(weights["mlp1.0.bias"]) + self.layernorm_bias_.copy_(weights["mlp1.0.bias"]) split_indexes = np.linspace(0, self.llm_hidden_size, self.tp_world_size_ + 1, dtype=np.int64) split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "mlp1.1.weight" in weights: - self.mlp1_1_weight_ = self._cuda(weights["mlp1.1.weight"][split_start:split_end, :]).t() + self.mlp1_1_weight_.copy_(weights["mlp1.1.weight"][split_start:split_end, :].t()) if "mlp1.1.bias" in weights: - self.mlp1_1_bias_ = self._cuda(weights["mlp1.1.bias"][split_start:split_end]) + self.mlp1_1_bias_.copy_(weights["mlp1.1.bias"][split_start:split_end]) if "mlp1.3.weight" in weights: - self.mlp1_3_weight_ = self._cuda(weights["mlp1.3.weight"][:, split_start:split_end]).t() + self.mlp1_3_weight_.copy_(weights["mlp1.3.weight"][:, split_start:split_end].t()) if "mlp1.3.bias" in weights: - self.mlp1_3_bias_ = self._cuda(weights["mlp1.3.bias"]) + self.mlp1_3_bias_.copy_(weights["mlp1.3.bias"]) return diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index f1de0bdc16..05d8edbad1 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -119,17 +119,18 @@ def _init_ffn(self): ) def _init_norm(self): + n_embed = self.network_config_["hidden_size"] self.att_norm_weight_ = NormWeight( - self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + n_embed, self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name ) self.ffn_norm_weight_ = NormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + n_embed, self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name ) if self.qk_norm: n_embed = self.network_config_["hidden_size"] split_n_embed = (n_embed + self.padding_hidden_size) // self.tp_world_size_ - self.q_norm_weight_ = TpNormWeight(self._q_norm_weight_name, self.data_type_, split_n_embed) - self.k_norm_weight_ = TpNormWeight(self._k_norm_weight_name, self.data_type_, split_n_embed) + self.q_norm_weight_ = TpNormWeight(split_n_embed, self._q_norm_weight_name, self.data_type_) + self.k_norm_weight_ = TpNormWeight(split_n_embed, self._k_norm_weight_name, self.data_type_) def load_hf_weights(self, weights): if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights: diff --git a/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py b/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py index 811d39a729..e3a71379d2 100644 --- a/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py @@ -8,10 +8,10 @@ import json from typing import List from lightllm.common.basemodel.basemodel import TpPartBaseModel -from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_ep_redundancy import ( +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_ep_redundancy import ( 
    FusedMoeWeightEPAutoRedundancy,
 )
-from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_ep import FusedMoeWeightEP
+from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_ep import FusedMoeWeightEP
 from lightllm.utils.envs_utils import get_env_start_args, get_redundancy_expert_update_interval
 from lightllm.utils.envs_utils import get_redundancy_expert_update_max_load_count
 from lightllm.utils.envs_utils import get_redundancy_expert_num

From 3d225d70ca7d22846355a47dd6a30eee3e482f3a Mon Sep 17 00:00:00 2001
From: sufubao
Date: Tue, 25 Nov 2025 04:07:03 +0000
Subject: [PATCH 011/180] add_cli

---
 lightllm/server/api_cli.py                   |  3 +--
 lightllm/server/api_server.py                | 12 ++++++++----
 lightllm/server/core/objs/start_args_type.py |  2 ++
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index ee3f184e41..fa01dd0689 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -1,8 +1,7 @@
 import argparse
 
 
-def make_argument_parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser()
+def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
 
     parser.add_argument(
         "--run_mode",
diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index dd531f58d4..1eb5ff24c0 100755
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -1,16 +1,17 @@
 import torch
-from .api_cli import make_argument_parser
+from .api_cli import add_cli_args
 from lightllm.server.core.objs.start_args_type import StartArgs
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
 
+
 def launch_server(args: StartArgs):
     from .api_start import pd_master_start, normal_or_p_d_start, config_server_start
-
+
     try:
         # this code will not be ok for settings to fork to subprocess
-        torch.multiprocessing.set_start_method("spawn")
+        torch.multiprocessing.set_start_method("spawn")
     except RuntimeError as e:
         logger.warning(f"Failed to set start method: {e}")
     except Exception as e:
@@ -26,7 +27,10 @@
 
 
 if __name__ == "__main__":
-    parser = make_argument_parser()
+    from argparse import ArgumentParser
+
+    parser = ArgumentParser()
+    add_cli_args(parser)
     args = parser.parse_args()
 
     launch_server(StartArgs(**vars(args)))
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index eff4dfab55..0af795a096 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -150,3 +150,5 @@ class StartArgs:
 
     enable_weight_cpu_backup: bool = field(default=False)
     weight_version: str = "default"
+
+    enable_torch_memory_saver: bool = field(default=False)
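This patch turns the CLI module from a parser factory (make_argument_parser) into a composable add_cli_args entry point, which lets an embedding application attach LightLLM's flags to its own parser. A hedged usage sketch (assumes lightllm is importable; the extra flag is illustrative, not part of the patch):

from argparse import ArgumentParser

from lightllm.server.api_cli import add_cli_args

parser = ArgumentParser(description="custom launcher")
parser.add_argument("--my_flag", action="store_true")  # application-specific flag
add_cli_args(parser)  # attaches LightLLM's server arguments to the existing parser
args = parser.parse_args([])  # defaults only here; pass real argv in practice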
lightllm/models/llama/model.py | 20 ---- lightllm/server/api_cli.py | 2 +- lightllm/server/api_start.py | 14 +-- 12 files changed, 378 insertions(+), 128 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 3eb5d7dbe4..a0cafa25d2 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -125,7 +125,7 @@ def __init__(self, kvargs): self._init_some_value() self._init_custom() self._init_inferstate_cls() - # self._autotune_warmup() + self._autotune_warmup() self._init_padded_req() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 2a9006efd6..ec0e282844 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -30,7 +30,7 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay gc.collect() -def load_hf_weights_old(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): +def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): if isinstance(data_type, str): data_type = torch.float16 if data_type == "fp16" else torch.float32 if pre_post_layer is not None: @@ -67,74 +67,7 @@ def load_hf_weights_old(data_type, weight_dir, pre_post_layer=None, transformer_ iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" iterator = tqdm(iterator, total=len(candidate_files), desc=desc_str) - for _ in iterator: pass return - - -def _read_file(file_, use_safetensors, weight_dir): - if use_safetensors: - weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") - weights = {k: weights.get_tensor(k) for k in weights.keys()} - else: - weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") - - return weights - - -def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): - if isinstance(data_type, str): - data_type = torch.float16 if data_type == "fp16" else torch.float32 - if pre_post_layer is not None: - assert pre_post_layer.data_type_ == data_type, "type is not right" - if transformer_layer_list is not None: - assert transformer_layer_list[0].data_type_ == data_type, "type is not right" - if weight_dict: - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weight_dict) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - 
layer.load_hf_weights(weight_dict) - del weight_dict - return - use_safetensors = True - files = utils.PetrelHelper.list(weight_dir, extension="all") - candidate_files = list(filter(lambda x: x.endswith(".safetensors"), files)) - if len(candidate_files) == 0: - use_safetensors = False - candidate_files = list(filter(lambda x: x.endswith(".bin"), files)) - assert len(candidate_files) != 0, "can only support pytorch tensor and safetensors format for weights." - - weight_queue = Queue(maxsize=5) # 控制内存使用 - - def producer(chunk): - for file_ in chunk: - weights = _read_file(file_, use_safetensors, weight_dir) - weight_queue.put(weights) - - LOADWORKER = int(os.environ.get("LOADWORKER", 1)) - - num_producers = min(LOADWORKER, len(candidate_files)) # 生产者数量 - chunk_size = (len(candidate_files) + num_producers - 1) // num_producers - file_chunks = [candidate_files[i : i + chunk_size] for i in range(0, len(candidate_files), chunk_size)] - - producer_threads = [] - for i, chunk in enumerate(file_chunks): - thread = Thread(target=producer, args=(chunk,), name=f"Producer-{i}") - thread.start() - producer_threads.append(thread) - - for _ in tqdm(range(len(candidate_files)), desc="Loading weights"): - weights = weight_queue.get() - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weights) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weights) - del weights - gc.collect() - - for thread in producer_threads: - thread.join() diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py index 8c5d1cc1e1..e4f7b552aa 100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8_quant.py @@ -127,19 +127,6 @@ def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1)) return - def quantize_moe(self, weight: torch.Tensor) -> WeightPack: - num_experts = weight.shape[0] - qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_) - weight_scales = [] - for i in range(num_experts): - qweight, weight_scale = scaled_fp8_quant( - weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True - ) - qweights[i] = qweight - weight_scales.append(weight_scale) - weight_scale = torch.stack(weight_scales, dim=0).contiguous() - return WeightPack(weight=qweights, weight_scale=weight_scale) - def apply( self, input_tensor: torch.Tensor, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..c75c871c72 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 
128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..14026090e6 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json new file mode 100644 index 0000000000..939c939523 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 2, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..13ba4ba8e5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + 
}, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/bloom/model.py b/lightllm/models/bloom/model.py index 2c341a7906..4a07a7ff5f 100644 --- a/lightllm/models/bloom/model.py +++ b/lightllm/models/bloom/model.py @@ -56,13 +56,4 @@ def _init_weights(self): ) for i in range(self.config["n_layer"]) ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] return diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index a081477698..0ac24cf8b6 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -129,15 +129,6 @@ def _init_weights(self): ) for i in range(self.config["n_layer"]) ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] return def _init_infer_layer(self): diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index abc258e8bd..4ff802d81e 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -146,26 +146,6 @@ def _init_weights(self): ) for i in range(self.config["n_layer"]) ] - if self.load_way == "HF": - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - else: - load_ds_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - prefix="model.layers.", - num_layer=self.config["n_layer"], - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] return def _init_to_get_rotary(self, default_base=10000): diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index fa01dd0689..03f751d363 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -363,7 +363,7 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: "--visual_nccl_ports", nargs="+", type=int, - default=[29500], + default=[], help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502", ) parser.add_argument( diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 6a02dda17b..ffd794b2d6 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -181,13 +181,13 @@ def normal_or_p_d_start(args: StartArgs): 
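Editorial note on the autotune JSON files added above: each file is keyed by a token-count bucket ("8", "256", "8448", ...) mapping to a Triton launch config for one kernel/shape/GPU combination. The real lookup lives in lightllm's autotuner; the sketch below only illustrates the key scheme, with a shortened hypothetical file name:

```python
import json

def pick_config(cfg: dict, m: int) -> dict:
    # Keys are stringified token-count buckets; pick the closest one.
    best = min((int(k) for k in cfg), key=lambda b: abs(b - m))
    return cfg[str(best)]

# File name shortened here; the real names encode K, N, expert_num, dtype, GPU.
with open("grouped_matmul_NVIDIA_H200.json") as f:  # hypothetical path
    cfg = json.load(f)
print(pick_config(cfg, 300))  # lands in the "256" bucket
```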
args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus] # 检查visual_nccl_port数量是否足够 - if len(args.visual_nccl_ports) < args.visual_dp: - raise ValueError( - f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " - f"but got ({len(args.visual_nccl_ports)})." - ) - else: - args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] + # if len(args.visual_nccl_ports) < args.visual_dp: + # raise ValueError( + # f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " + # f"but got ({len(args.visual_nccl_ports)})." + # ) + # else: + # args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] if args.visual_dp <= 0: raise ValueError("visual_dp must be a positive integer.") From f73758582d30e67495c2eb718129b1cf7e6f79b5 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 9 Dec 2025 09:10:59 +0000 Subject: [PATCH 013/180] update requirement --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 20f27dc05a..f062c9d2fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -89,3 +89,4 @@ orjson==3.11.2 setproctitle==1.3.6 xxhash==3.6.0 torch_memory_saver==0.0.9 +portpicker==1.6.0 From 8a67a4751063f0ec89566f18c1091440a7dab2aa Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 26 Dec 2025 10:21:19 +0000 Subject: [PATCH 014/180] add-neo-chat --- lightllm/models/__init__.py | 1 + lightllm/models/neo_chat/__init__.py | 0 .../models/neo_chat/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 7 + .../models/neo_chat/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 23 ++ .../layer_weights/transformer_layer_weight.py | 7 + lightllm/models/neo_chat/model.py | 138 +++++++++ lightllm/models/neo_chat/neo_visual.py | 273 ++++++++++++++++++ lightllm/models/neo_chat/vision_process.py | 141 +++++++++ lightllm/server/tokenizer.py | 3 + .../visualserver/model_infer/model_rpc.py | 3 + 12 files changed, 596 insertions(+) create mode 100644 lightllm/models/neo_chat/__init__.py create mode 100644 lightllm/models/neo_chat/layer_infer/__init__.py create mode 100644 lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/neo_chat/layer_weights/__init__.py create mode 100644 lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/neo_chat/model.py create mode 100644 lightllm/models/neo_chat/neo_visual.py create mode 100644 lightllm/models/neo_chat/vision_process.py diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 4ee02f003b..5618dfd0cd 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -38,4 +38,5 @@ Tarsier2LlamaTpPartModel, ) from lightllm.models.gpt_oss.model import GptOssTpPartModel +from lightllm.models.neo_chat.model import NeoTpPartModel from .registry import get_model, get_model_class diff --git a/lightllm/models/neo_chat/__init__.py b/lightllm/models/neo_chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat/layer_infer/__init__.py b/lightllm/models/neo_chat/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..c9297ee84d --- /dev/null +++ 
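Editorial note: the validation above is commented out and PATCH 012 changed the `--visual_nccl_ports` default to `[]`; together with `portpicker` entering requirements.txt in PATCH 013, the ports are presumably allocated at runtime rather than declared up front. A hedged sketch of what such allocation could look like (the helper is hypothetical, not in these patches):

```python
import portpicker

def fill_visual_nccl_ports(args):  # hypothetical helper
    # Assumption: one NCCL port per visual data-parallel group when the
    # user passes none explicitly.
    if not args.visual_nccl_ports:
        args.visual_nccl_ports = [portpicker.pick_unused_port() for _ in range(args.visual_dp)]
    return args
```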
b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -0,0 +1,7 @@ +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer + + +class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return diff --git a/lightllm/models/neo_chat/layer_weights/__init__.py b/lightllm/models/neo_chat/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..7766a5d29f --- /dev/null +++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,23 @@ +import torch +import numpy as np +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight + +# add key: language_model.xxx -> xxx +# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now +def rename_weight_keys(weights): + prefix = "language_model." + keys = list(weights.keys()) + for k in keys: + if prefix in k: + weights[k.replace(prefix, "")] = weights.pop(k) + + +class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..2dc87f3ca7 --- /dev/null +++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py @@ -0,0 +1,7 @@ +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight + + +class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py new file mode 100644 index 0000000000..61fd98b980 --- /dev/null +++ b/lightllm/models/neo_chat/model.py @@ -0,0 +1,138 @@ +import os +import json +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen2_vl.model import QWen2VLTokenizer +from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.server.core.objs import SamplingParams +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem +from lightllm.models.neo_chat.vision_process import smart_resize +from lightllm.models.internvl.model import InternvlTokenizer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from 
lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
+from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight
+from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
+
+IMG_START_TOKEN = "<img>"
+IMG_END_TOKEN = "</img>"
+IMG_TOKEN = "<image>"
+AUDIO_START_TOKEN = "<audio>"
+
+
+class NeoChatTokenizer(BaseMultiModalTokenizer):
+    def __init__(self, tokenizer, model_cfg, **kwargs):
+        super().__init__(tokenizer)
+        self.tokenizer = tokenizer
+        self.min_pixel = model_cfg.get("vision_config").get("min_pixels")
+        self.max_pixel = model_cfg.get("vision_config").get("max_pixels")
+        self.patch_size = model_cfg.get("vision_config").get("patch_size")
+        self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio")
+
+        self.image_token_id = model_cfg.get("image_token_id")
+        self.image_start_tag = IMG_START_TOKEN
+        self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag)
+        self.image_end_tag = IMG_END_TOKEN
+        self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
+
+    def init_imageitem_extral_params(
+        self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+    ):
+        return
+
+    def init_audioitem_extral_params(
+        self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+    ):
+        raise NotImplementedError
+
+    def get_audio_token_length(self, audio: AudioItem):
+        raise NotImplementedError
+
+    def get_image_token_length(self, img: ImageItem):
+        width, height = img.image_w, img.image_h
+        resized_height, resized_width = smart_resize(
+            height=height,
+            width=width,
+            factor=int(self.patch_size // self.downsample_ratio),
+            min_pixels=self.min_pixel,
+            max_pixels=self.max_pixel,
+        )
+        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+        token_num = (grid_h * grid_w) * (self.downsample_ratio ** 2)
+        # grid_thwd is prepared for mrope; it is not needed here
+        img.grid_thwd = (1, grid_h, grid_w, 0)
+        return int(token_num)
+
+    # only change the impl of the encode func:
+    def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
+        # <image>TEXT<image>TEXT<image>TEXT --> <img></img>TEXT<img></img>TEXT<img></img>TEXT
+        image_tokens = IMG_START_TOKEN + IMG_END_TOKEN
+        if multimodal_params is None:
+            add_special_tokens = kwargs.get("add_special_tokens", True)
+            return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        image_count = len(multimodal_params.images)
+        prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count)
+
+        origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"])
+        # <img></img> --> <img>id,id+1...id+num</img>
+        input_ids = []
+        image_id = 0
+        start_idx = 0
+        while True:
+            try:
+                start_idx = origin_ids.index(self.image_start_id, start_idx)
+                if start_idx + 1 >= len(origin_ids):
+                    break
+                if origin_ids[start_idx + 1] == self.image_end_id:
+                    input_ids.extend(origin_ids[: start_idx + 1])
+                    token_id = multimodal_params.images[image_id].token_id
+                    token_num = multimodal_params.images[image_id].token_num
+                    input_ids.extend(range(token_id, token_id + token_num))
+                    input_ids.append(self.image_end_id)
+                    origin_ids = origin_ids[start_idx + 2 :]
+                    start_idx = 0
+                    image_id += 1
+                else:
+                    raise ValueError("image token error")
+            except ValueError:
+                break
+        input_ids.extend(origin_ids[start_idx:])
+        return input_ids
+
+
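Editorial note: to make the splice in encode() above concrete, here is a worked example with made-up ids (the real start/end ids come from the tokenizer's vocabulary):

```python
# Suppose image_start_id = 151857 and image_end_id = 151858 (hypothetical),
# and one image was allocated token_id = 9000 with token_num = 4.
origin_ids = [1, 2, 151857, 151858, 3]   # the "<img></img>" pair from the prompt
# encode() keeps everything up to <img>, splices in the reserved id range,
# then re-appends </img> and the tail:
input_ids = [1, 2, 151857, 9000, 9001, 9002, 9003, 151858, 3]
```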
+@ModelRegistry(["neo_chat"], is_multimodal=True)
+class NeoTpPartModel(Qwen3MOEModel):
+
+    pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+    transformer_layer_infer_class = NeoChatMOETransformerLayerInfer
+
+    pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight
+    transformer_weight_class = NeoChatMOETransformerLayerWeight
+
+    infer_state_class = LlamaInferStateInfo
+
+    def __init__(self, kvargs):
+        super().__init__(kvargs)
+        return
+
+    def _init_inferstate_cls(self):
+        pass
+
+    def _init_config(self):
+        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+            all_config = json.load(json_file)
+        self.config = all_config["llm_config"]
+        # rename keys
+        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+        if self.finetune_config:
+            self.config["vocab_size"] = self.finetune_config.vocab_size
+        return
diff --git a/lightllm/models/neo_chat/neo_visual.py b/lightllm/models/neo_chat/neo_visual.py
new file mode 100644
index 0000000000..c9d4b81617
--- /dev/null
+++ b/lightllm/models/neo_chat/neo_visual.py
@@ -0,0 +1,273 @@
+import os
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from typing import List
+from io import BytesIO
+import torch.nn as nn
+from transformers.activations import ACT2FN
+from safetensors import safe_open
+from lightllm.server.multimodal_params import ImageItem
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from lightllm.models.neo_chat.vision_process import load_image_native
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+
+
+def apply_rotary_emb_1d(
+    x: torch.Tensor,
+    cos_cached: torch.Tensor,
+    sin_cached: torch.Tensor,
+    positions: torch.Tensor,
+):
+    """Apply 1D RoPE to one part of the input tensor."""
+    # x: (..., seq_len, dim_part)
+    # positions: (..., seq_len)
+    # cos_cached: (max_pos, dim_part / 2)
+    cos_cached = cos_cached.to(device=positions.device)
+    sin_cached = sin_cached.to(device=positions.device)
+
+    cos = cos_cached[positions]  # Shape: (positions.shape, dim_part / 2)
+    sin = sin_cached[positions]  # Shape: (positions.shape, dim_part / 2)
+
+    x1 = x[..., 0::2]
+    x2 = x[..., 1::2]
+
+    rotated_x1 = x1 * cos - x2 * sin
+    rotated_x2 = x1 * sin + x2 * cos
+
+    x_rotated = torch.empty_like(x)
+    x_rotated[..., 0::2] = rotated_x1
+    x_rotated[..., 1::2] = rotated_x2
+    return x_rotated
+
+
+def apply_2d_rotary_pos_emb(
+    x: torch.Tensor,
+    cos_cached_x: torch.Tensor,
+    sin_cached_x: torch.Tensor,
+    cos_cached_y: torch.Tensor,
+    sin_cached_y: torch.Tensor,
+    abs_positions_x: torch.Tensor,
+    abs_positions_y: torch.Tensor,
+):
+    """Apply 2D RoPE to the input tensor x."""
+    dim = x.shape[-1]
+    dim_half = dim // 2
+
+    # Split the embedding so one half gets RoPE along one axis and the other
+    # half along the other axis, e.g. the first half for the X coordinate and
+    # the second half for Y (the reverse also works, as long as it is consistent).
+    x_part_1 = x[..., :dim_half]
+    x_part_2 = x[..., dim_half:]
+
+    # Apply the rotation tied to abs_positions_x to x_part_1
+    rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x)
+    # Apply the rotation tied to abs_positions_y to x_part_2
+    rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y)
+
+    # Stitch the halves back together; the order must match the split above.
+    return torch.cat((rotated_part_1, rotated_part_2), dim=-1)
+
+
+def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
+    """
+    Compute patch coordinates (x, y)
+
+    Args:
+        grid_hw: (B, 2) tensor representing (H, W)
per image + """ + device = grid_hw.device + B = grid_hw.shape[0] + + # Get the number of patches per image + H = grid_hw[:, 0] + W = grid_hw[:, 1] + N = H * W + N_total = N.sum() + + # Create the batch index for each patch (B x patch count) + patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,) + + # Generate intra-image patch index (row-major order) + patch_id_within_image = torch.arange(N_total, device=device) + patch_id_within_image = ( + patch_id_within_image + - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample] + ) + + # Get H/W for each patch according to its image + W_per_patch = W[patch_to_sample] + abs_x = patch_id_within_image % W_per_patch + abs_y = patch_id_within_image // W_per_patch + + return abs_x, abs_y + + +class NeoVisionTransformerPretrainedModel(nn.Module): + def __init__( + self, + kvargs, + hidden_size: int = 1024, + llm_hidden_size: int = 2048, + downsample_ratio: float = 0.5, + patch_size: int = 16, + num_channels: int = 3, + max_position_embeddings_vision: int = 10000, + rope_theta_vision: float = 10000.0, + min_pixels: int = 65536, + max_pixels: int = 2408448, + **kwargs, + ): + super().__init__() + self.weight_dir = kvargs["weight_dir"] + self.data_type = kvargs.get("data_type", "bfloat16") + self.embed_dim = hidden_size + self.llm_hidden_size = llm_hidden_size + self.patch_size = patch_size + self.num_channels = num_channels + self.downsample_ratio = downsample_ratio + self.downsample_factor = int(1 / downsample_ratio) + self.max_position_embeddings_vision = max_position_embeddings_vision + self.rope_theta_vision = rope_theta_vision + self.rope_dim_part = self.embed_dim // 2 + self.min_pixels = min_pixels + self.max_pixels = max_pixels + + self.patch_embedding = nn.Conv2d( + in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size + ) + + self.dense_embedding = nn.Conv2d( + in_channels=self.embed_dim, + out_channels=self.llm_hidden_size, + kernel_size=self.downsample_factor, + stride=self.downsample_factor, + ) + self.gelu = nn.GELU() + + self.repe_dim_part = self.embed_dim // 2 + self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos() + self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos() + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return + + def load_model(self, weight_dir): + bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k) + self.load_state_dict(weight_dict) + + def precompute_rope_freqs_sincos(self): + inv_freq = 1.0 / ( + self.rope_theta_vision ** 
(torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part)
+        )
+        t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq)
+        freqs = torch.outer(t, inv_freq)
+        return torch.cos(freqs), torch.sin(freqs)
+
+    def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw):
+        """
+        Apply 2D Rotary Position Embedding to the patch embeddings.
+        """
+        abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device)
+        embeddings = apply_2d_rotary_pos_emb(
+            patch_embeds.to(torch.float32),  # RoPE calculations are often more stable in float32
+            self.cos_x,
+            self.sin_x,
+            self.cos_y,
+            self.sin_y,
+            abs_pos_x,
+            abs_pos_y,
+        ).to(self.patch_embedding.weight.dtype)
+        return embeddings
+
+    def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
+        pixel_values = pixel_values.view(
+            -1,
+            3,
+            self.patch_size,
+            self.patch_size,
+        )
+        patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim)
+        patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw)
+        assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[
+            0
+        ], "Grid size and patch embeds size mismatch."
+
+        patches_list = []
+        cur_position = 0
+        for i in range(grid_hw.shape[0]):
+            h, w = grid_hw[i]
+            patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0)
+            patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2))
+            patches_per_img = patches_per_img.permute(0, 2, 3, 1)
+            patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1]))
+            cur_position += h * w
+
+        embeddings = torch.cat(patches_list, dim=0)  # (N_total // downsample_factor**2, C)
+        assert cur_position == patch_embeds.shape[0]
+        assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2)
+
+        return embeddings
+
+    def encode(self, images: List[ImageItem]):
+        img_tensors = []
+        valid_ids = []
+        valid_id = 0
+        img_grids = []
+        uuids = []
+
+        for i, img in enumerate(images):
+            if isinstance(img, ImageItem):
+                uuids.append(img.uuid)
+                image_data = read_shm(get_shm_name_data(img.uuid))
+                image_data = Image.open(BytesIO(image_data))
+                pixel_values, image_grid_hw = load_image_native(image_data)
+                img_tensors.append(pixel_values)
+                img_grids.append(image_grid_hw)
+            else:
+                raise Exception("Unsupported input types: {} for {}".format(type(img), img))
+
+            # must divide merge_length
+            cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2))
+            print(f"cur_num is {cur_num}")
+            valid_ids.append([valid_id, valid_id + cur_num])
+            valid_id += cur_num
+
+        if len(img_tensors) <= 0:
+            return None
+
+        imgs = torch.cat(img_tensors, dim=0)
+        grid_hw = torch.cat(img_grids, dim=0)
+
+        pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
+        image_grid_hw = grid_hw.to("cuda", non_blocking=True)
+
+        all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw)
+
+        return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/models/neo_chat/vision_process.py b/lightllm/models/neo_chat/vision_process.py
new file mode 100644
index 0000000000..aa008e18fb
--- /dev/null
+++ b/lightllm/models/neo_chat/vision_process.py
@@ -0,0 +1,141 @@
+import re
+import math
+import torch
+import string
+import numpy as np
+import pandas as pd
+from PIL import Image
+import torch.distributed as dist
+import torchvision.transforms as T
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns
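Editorial note: a tiny worked check of the row-major coordinate scheme implemented by build_abs_positions_from_grid_hw above, for a single 2x3 patch grid:

```python
import torch
from lightllm.models.neo_chat.neo_visual import build_abs_positions_from_grid_hw

grid_hw = torch.tensor([[2, 3]])  # one image, 2 rows x 3 cols of patches
abs_x, abs_y = build_abs_positions_from_grid_hw(grid_hw)
# abs_x -> tensor([0, 1, 2, 0, 1, 2])   column (w) index per patch
# abs_y -> tensor([0, 0, 0, 1, 1, 1])   row (h) index per patch
# These two coordinate vectors feed the halves in apply_2d_rotary_pos_emb.
```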
the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60 +def smart_resize( + height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304 +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, floor_by_factor(height / beta, factor)) + w_bar = max(factor, floor_by_factor(width / beta, factor)) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs): + width, height = image.size + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def preprocess_pixel_values(pixel_values, patch_size=16): + c, h, w = pixel_values.shape + grid_h = h // patch_size + grid_w = w // patch_size + + flatten_pixel_values = ( + pixel_values.view(c, grid_h, patch_size, grid_w, patch_size) + .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size] + .reshape(grid_h * grid_w, c * patch_size ** 2) + ) + + grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device) + + return flatten_pixel_values, grid_hw + + +def get_contrasting_background(image): + """ + Calculate the color (white or black) that is different from the average foreground color + to use as the background color + """ + image_np = np.array(image) + if (image_np[:, :, 3] == 0).any(): + non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0] + if non_transparent_pixels.size == 0: + return None + pixel_mean = non_transparent_pixels.mean() + contrasting_color = (0, 0, 0) if pixel_mean > 382.5 else (255, 255, 255) + return contrasting_color + else: + return None + + +def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False): + """ + Load and preprocess an image file, converting it to RGB mode, + resizing, normalizing, and optionally adding a thumbnail version. 
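Editorial note: a concrete data point for the rounding rules in smart_resize above (factor 32, the default pixel bounds):

```python
from lightllm.models.neo_chat.vision_process import smart_resize

# 1000 -> round(1000/32)*32 = 992 and 750 -> round(750/32)*32 = 736;
# 992*736 = 730112 already lies inside [65536, 4194304], so no beta rescale.
h_bar, w_bar = smart_resize(height=1000, width=750, factor=32)
assert (h_bar, w_bar) == (992, 736)
```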
+ """ + if image.mode == "RGBA": + bg_color = get_contrasting_background(image) + if bg_color: + background = Image.new("RGB", image.size, bg_color) + background.paste(image, mask=image.split()[3]) + image = background.convert("RGB") + else: + image = image.convert("RGB") + else: + image = image.convert("RGB") + + if upscale: + image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) + + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), + T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ] + ) + + new_image = dynamic_preprocess_native_resolution( + image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels + ) + pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size) + + print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") + + return pixel_values, grid_hw diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index e0b2bd425e..17f5a741ac 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -30,6 +30,7 @@ from ..models.qwen2_vl.model import QWen2VLTokenizer from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer +from ..models.neo_chat.model import NeoChatTokenizer from ..models.gemma3.model import Gemma3Tokenizer # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. @@ -104,5 +105,7 @@ def get_tokenizer( tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) + elif model_type == "neo_chat": + tokenizer = NeoChatTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) return tokenizer diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d3d1610f33..d77271af86 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -19,6 +19,7 @@ from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel +from lightllm.models.neo_chat.neo_visual import NeoVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry @@ -78,6 +79,8 @@ def exposed_init_model(self, kvargs): # self.model = InternVLVisionModel() elif self.model_type == "gemma3": self.model = Gemma3VisionModel() + elif self.model_type == "neo_chat": + self.model = NeoVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() else: raise Exception(f"can not support {self.model_type} now") From fdc1369e315487a59f99c5c52d769bf3328fb43f Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 30 Dec 2025 14:50:41 +0000 Subject: [PATCH 015/180] add-neo-chat --- .../kv_cache_mem_manager/mem_manager.py | 4 +- lightllm/models/llama/model.py | 47 +- lightllm/models/neo_chat/infer_state.py | 95 ++++ .../layer_infer/transformer_layer_infer.py | 166 +++++++ .../layer_weights/transformer_layer_weight.py | 44 ++ lightllm/models/neo_chat/model.py | 17 +- 
 .../models/neo_chat/triton_kernel/__init__.py |   0
 .../context_attention_fwd_neo.py              | 467 ++++++++++++++++++
 .../triton_kernel/get_neo_position.py         | 174 +++++++
 9 files changed, 1003 insertions(+), 11 deletions(-)
 create mode 100644 lightllm/models/neo_chat/infer_state.py
 create mode 100644 lightllm/models/neo_chat/triton_kernel/__init__.py
 create mode 100644 lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py
 create mode 100644 lightllm/models/neo_chat/triton_kernel/get_neo_position.py

diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py
index d8fd93009f..b599bedfc0 100755
--- a/lightllm/common/kv_cache_mem_manager/mem_manager.py
+++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py
@@ -28,7 +28,7 @@ class MemoryManager:
     def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
         self.size = size
         self.head_num = head_num
-        self.head_dim = head_dim
+        self.head_dim = head_dim * 2  # neo kv packs [k, k_h, k_w] into one buffer
         self.layer_num = layer_num
         self.always_copy = always_copy
         self.dtype = dtype
@@ -60,7 +60,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
             self.size,
             dtype,
             head_num,
-            head_dim,
+            self.head_dim,
             layer_num,
         )
         self.HOLD_TOKEN_MEMINDEX = self.size
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index a228e00253..36b5d79b54 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -110,6 +110,8 @@ def _init_custom(self):
         rope_scaling = self.config.get("rope_scaling", None)
         if rope_scaling is None:
             self._init_to_get_rotary()
+            if "rope_theta_hw" in self.config:
+                self._init_to_get_hw_rotary()
             return
 
         if "rope_type" in rope_scaling:
@@ -132,6 +134,8 @@
             self._init_to_get_rotary()
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        if "rope_theta_hw" in self.config:
+            self._init_to_get_hw_rotary()
         return
 
     def _init_weights(self):
@@ -178,7 +182,7 @@ def _init_to_get_rotary(self, default_base=10000):
             rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
 
         base = self.config.get("rope_theta", float(default_base))
-
+        print(f"base is {base}")
         if "max_sequence_length" in self.config:
             max_seq_len = self.config["max_sequence_length"]
         else:
@@ -211,6 +215,47 @@
         self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
         return
 
+    def _init_to_get_hw_rotary(self, default_base=10000):
+        partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2)
+        if self.config.get("rope_scaling", {}) is None:
+            rope_scaling_factor = 1.0
+        else:
+            rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+
+        base = self.config.get("rope_theta_hw", float(default_base))
+        print(f"hw_base is {base}")
+        if "max_sequence_length" in self.config:
+            max_seq_len = self.config["max_sequence_length"]
+        else:
+            max_position_embeddings = self.config.get(
+                "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384
+            )
+            max_seq_len = max_position_embeddings * rope_scaling_factor
+
+        # NTK
+        try:
+            ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
+            assert ntk_alpha >= 1
+            if ntk_alpha > 1:
+                logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+                max_seq_len *= ntk_alpha
+                base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
+        except:
+            pass
+
+        inv_freq = 1.0 / (
+            base **
(torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) + ) + t = ( + torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) + / rope_scaling_factor + ) + freqs = torch.outer(t, inv_freq) + + self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda() + return + def _init_to_get_dynamic_ntk_rotary(self): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) max_position_embeddings = self.config.get("max_position_embeddings", 2048) diff --git a/lightllm/models/neo_chat/infer_state.py b/lightllm/models/neo_chat/infer_state.py new file mode 100644 index 0000000000..9a71c3ddb9 --- /dev/null +++ b/lightllm/models/neo_chat/infer_state.py @@ -0,0 +1,95 @@ +from typing import Optional, List +import torch +import numpy as np +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.common.req_manager import ReqManager +from lightllm.models.neo_chat.triton_kernel.get_neo_position import get_neo_position_triton +from lightllm.models.llama.model import LlamaTpPartModel + + +class NeoChatInferStateInfo(LlamaInferStateInfo): + def __init__(self): + super().__init__() + self.position_cos = None + self.position_sin = None + self.position_cos_h = None + self.position_sin_h = None + self.position_cos_w = None + self.position_sin_w = None + + def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor): + LlamaInferStateInfo.init_some_extra_state(self, model, input_ids) + if self.is_prefill: + self.position_ids = self.get_neo_position(self.multimodal_params) + else: + b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] + for batch_idx, p in enumerate(self.multimodal_params): + position_delta = 0 + for image in p["images"]: + position_delta += image["grid_thwd"][3] + b_position_delta[batch_idx] = position_delta + position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) + self.position_ids = position_ids.unsqueeze(0).expand(3, -1) + self.position_ids[1:].zero_() + + self.position_ids = self.position_ids.contiguous() + self.position_cos = model._cos_cached[self.position_ids[0]] + self.position_sin = model._sin_cached[self.position_ids[0]] + self.position_cos_h = model._hw_cos_cached[self.position_ids[1]] + self.position_sin_h = model._hw_sin_cached[self.position_ids[1]] + self.position_cos_w = model._hw_cos_cached[self.position_ids[2]] + self.position_sin_w = model._hw_sin_cached[self.position_ids[2]] + return + + def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: + if len(multimodal_params) == 0: + return self.position_ids.unsqueeze(0).expand(3, -1) + b_image_start_idx = [] + b_image_nums = [] + b_image_start_num = [] + b_image_len = [] + image_start_num = 0 + b_image_thwd = [] + + # pad multimodal_params to batch size. 
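Editorial aside on the decode-branch arithmetic above: each image contributes grid_thwd[3] = 1 - token_num to the position delta, so a whole image collapses to a single time step once generation continues. A small example grounded in the code shown:

```python
# One image that occupied token_num = 4 slots in the prompt carries
# grid_thwd[3] = 1 - 4 = -3, so every later 1D position is pulled back by 3
# (the image "costs" one time step instead of four).
token_num = 4
grid_thwd = (1, 2, 2, 1 - token_num)   # as set in get_image_token_length
raw_next_position = 10                  # position value before the shift
decode_position = raw_next_position + grid_thwd[3]  # -> 7
```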
+        batch_size = self.b_q_seq_len.shape[0]
+        multimodal_params = multimodal_params + [
+            {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params))
+        ]
+
+        for _, p in enumerate(multimodal_params):
+            images = p.get("images", [])
+            for img in images:
+                b_image_start_idx.append(img["start_idx"])
+                a = img["start_idx"]
+                print(f"img start_idx: {a}")
+                b_image_len.append(img["token_num"])
+                b_image_thwd.append(img["grid_thwd"])
+            b_image_nums.append(len(images))
+            b_image_start_num.append(image_start_num)
+            image_start_num += len(images)
+
+        # no images at all
+        if image_start_num == 0:
+            return self.position_ids.unsqueeze(0).expand(3, -1).contiguous()
+        b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True)
+        b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True)  # image_num x 4
+        b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True)
+        b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True)
+        b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True)
+
+        position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+        position_ids[0].copy_(self.position_ids)
+
+        get_neo_position_triton(
+            b_image_start_idx=b_image_start_idx,
+            b_image_thwd=b_image_thwd,
+            b_image_nums=b_image_nums,
+            b_image_start_num=b_image_start_num,
+            b_image_len=b_image_len,
+            position_ids=position_ids,
+            b_ready_cache_len=self.b_ready_cache_len,
+            b_q_seq_len=self.b_q_seq_len,
+            b_start_loc=self.b_start_loc,
+        )
+        return position_ids
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
index c9297ee84d..b0ee428563 100644
--- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -1,7 +1,173 @@
+import torch
+from functools import partial
+from typing import Tuple
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.neo_chat.infer_state import NeoChatInferStateInfo
+from lightllm.models.neo_chat.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo
+from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd
+from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd
 from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
+from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
+from lightllm.distributed import all_reduce
+import torch.distributed as dist
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward
 
 
 class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
     def __init__(self, data_type, network_config, mode):
         super().__init__(data_type, network_config, mode)
         return
+
+    def _bind_attention(self):
+        self._context_attention_kernel = self._context_attention_kernel
+        self._token_attention_kernel = self._token_decode_attention_normal
+        self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
+        return
+
+    def _get_qkv(
+        self,
+        input: torch.Tensor,
+        infer_state: NeoChatInferStateInfo,
+        layer_weight: NeoChatMOETransformerLayerWeight,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        input = input.view(-1, self.embed_dim_)
+        q = layer_weight.q_proj.mm(input)
+
+        q_h, q_w = layer_weight.q_hw_proj.mm(input).chunk(2, dim=-1)
+        k_h, k_w = layer_weight.k_hw_proj.mm(input).chunk(2, dim=-1)
+
+        cache_kv = layer_weight.kv_proj.mm(input)
+        qk_rmsnorm_forward(
+            q,
+            weight=layer_weight.q_norm_weight_.weight,
+            eps=self.eps_,
+        )
+
+        qk_rmsnorm_forward(
+            q_h,
+            weight=layer_weight.q_norm_h_weight_.weight,
+            eps=self.eps_,
+        )
+
+        qk_rmsnorm_forward(
+            q_w,
+            weight=layer_weight.q_norm_w_weight_.weight,
+            eps=self.eps_,
+        )
+
+        qk_rmsnorm_forward(
+            cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
+            weight=layer_weight.k_norm_weight_.weight,
+            eps=self.eps_,
+        )
+
+        qk_rmsnorm_forward(
+            k_h,
+            weight=layer_weight.k_norm_h_weight_.weight,
+            eps=self.eps_,
+        )
+        qk_rmsnorm_forward(
+            k_w,
+            weight=layer_weight.k_norm_w_weight_.weight,
+            eps=self.eps_,
+        )
+        cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+        rotary_emb_fwd(
+            q.view(-1, self.tp_q_head_num_, self.head_dim_),
+            cache_kv[:, : self.tp_k_head_num_, :],
+            infer_state.position_cos,
+            infer_state.position_sin,
+        )
+        rotary_emb_fwd(
+            q_h.view(-1, self.tp_q_head_num_, self.head_dim_ // 2),
+            k_h.view(-1, self.tp_k_head_num_, self.head_dim_ // 2),
+            infer_state.position_cos_h,
+            infer_state.position_sin_h,
+        )
+        rotary_emb_fwd(
+            q_w.view(-1, self.tp_q_head_num_, self.head_dim_ // 2),
+            k_w.view(-1, self.tp_k_head_num_, self.head_dim_ // 2),
+            infer_state.position_cos_w,
+            infer_state.position_sin_w,
+        )
+        # concatenate q, q_h, q_w
+        q = torch.cat([q, q_h, q_w], dim=-1)
+        # concatenate k, k_h, k_w
+        seq_len = cache_kv.shape[0]
+        k_h = k_h.reshape(seq_len, self.tp_k_head_num_, self.head_dim_ // 2)
+        k_w = k_w.reshape(seq_len, self.tp_k_head_num_, self.head_dim_ // 2)
+        k = cache_kv[:, : self.tp_k_head_num_, :]
+        k = torch.cat([k, k_h, k_w], dim=-1)
+        # zero-pad V so its shape matches the widened K
+        v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :]
+        v_pad = torch.zeros(
+            (seq_len, self.tp_v_head_num_, self.head_dim_),
+            device=v.device,
+            dtype=v.dtype,
+        )
+        v = torch.cat([v, v_pad], dim=-1)
+        cache_kv = torch.cat([k, v], dim=1)
+        return q, cache_kv
+
+    def _context_attention_kernel(
+        self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
+    ) -> torch.Tensor:
+        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        context_attention_fwd_neo(
+            q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
+            kv[:, 0 : self.tp_k_head_num_, :],
+            kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+            o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
+            infer_state.position_ids[0],
+            infer_state.b_req_idx,
+            infer_state.b_start_loc,
+            infer_state.b_seq_len,
+            infer_state.b_ready_cache_len,
+            infer_state.max_len_in_batch,
+            infer_state.req_manager.req_to_token_indexs,
+        )
+        o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)
+        o3 = o3[:, :, : self.head_dim_].contiguous()
+        return o3.view(o3.shape[0], -1)
+
+    def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None):
+        total_token_num = infer_state.total_token_num
+        batch_size = infer_state.batch_size
+        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_ * 2)
+
+        att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32)
+
+        token_att_fwd(
+            q.view(calcu_shape1),
+
infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + att_m_tensor, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + o_tensor = o_tensor[:, :, : self.head_dim_].contiguous() + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + calcu_shape2 = (batch_size, self.tp_q_head_num_, self.head_dim_) + token_softmax_reducev_fwd( + att_m_tensor, + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + o_tensor.view(calcu_shape2), + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o_tensor diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py index 2dc87f3ca7..bc38f1adcb 100644 --- a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py @@ -1,7 +1,51 @@ from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + NormWeight, + ROWMMWeight, +) class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): super().__init__(layer_num, data_type, network_config, mode, quant_cfg) return + + def _init_weight_names(self): + super()._init_weight_names() + self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight" + self._q_bias_hw_name = None + self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" + self._k_bias_hw_name = None + + self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" + self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" + + self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" + self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" + + def _init_qkv(self): + super()._init_qkv() + self.q_hw_proj = ROWMMWeight( + weight_names=self._q_weight_hw_name, + data_type=self.data_type_, + bias_names=self._q_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="q_hw_proj", + ) + self.k_hw_proj = ROWMMWeight( + weight_names=self._k_weight_hw_name, + data_type=self.data_type_, + bias_names=self._k_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="k_hw_proj", + ) + + def _init_norm(self): + super()._init_norm() + + self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) + self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) + self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) + self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py index 61fd98b980..edc7349864 100644 --- a/lightllm/models/neo_chat/model.py +++ 
b/lightllm/models/neo_chat/model.py
@@ -19,6 +19,7 @@ from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
 from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight
 from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
+from lightllm.models.neo_chat.infer_state import NeoChatInferStateInfo
 
 IMG_START_TOKEN = ""
 IMG_END_TOKEN = ""
@@ -65,10 +66,10 @@ def get_image_token_length(self, img: ImageItem):
             max_pixels=self.max_pixel,
         )
         grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
-        token_num = (grid_h * grid_w) * (self.downsample_ratio ** 2)
-        # grid_thwd is prepared for mrope; it is not needed here
-        img.grid_thwd = (1, grid_h, grid_w, 0)
-        return int(token_num)
+        token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2))
+        # TODO: do grid_h and grid_w here need to be multiplied by self.downsample_ratio? Double-check the code.
+        img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num)
+        return token_num
 
     # only change the impl of the encode func:
     def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
@@ -87,23 +88,23 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
         start_idx = 0
         while True:
             try:
-                start_idx = origin_ids.index(self.image_start_id, start_idx)
+                start_idx = origin_ids.index(self.image_start_id)
                 if start_idx + 1 >= len(origin_ids):
                     break
                 if origin_ids[start_idx + 1] == self.image_end_id:
                     input_ids.extend(origin_ids[: start_idx + 1])
                     token_id = multimodal_params.images[image_id].token_id
                     token_num = multimodal_params.images[image_id].token_num
+                    multimodal_params.images[image_id].start_idx = len(input_ids)
                     input_ids.extend(range(token_id, token_id + token_num))
                     input_ids.append(self.image_end_id)
                     origin_ids = origin_ids[start_idx + 2 :]
-                    start_idx = 0
                     image_id += 1
                 else:
                     raise ValueError("image token error")
             except ValueError:
                 break
-        input_ids.extend(origin_ids[start_idx:])
+        input_ids.extend(origin_ids)
         return input_ids
 
@@ -116,7 +117,7 @@ class NeoTpPartModel(Qwen3MOEModel):
     pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight
     transformer_weight_class = NeoChatMOETransformerLayerWeight
 
-    infer_state_class = LlamaInferStateInfo
+    infer_state_class = NeoChatInferStateInfo
 
     def __init__(self, kvargs):
         super().__init__(kvargs)
diff --git a/lightllm/models/neo_chat/triton_kernel/__init__.py b/lightllm/models/neo_chat/triton_kernel/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py
new file mode 100644
index 0000000000..46376502f1
--- /dev/null
+++ b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py
@@ -0,0 +1,467 @@
+# context_attention_fwd_neo_pos1d.py
+# From: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
+
+import math
+import torch
+import triton
+import triton.language as tl
+
+from lightllm.utils.device_utils import is_tesla
+
+
+@triton.jit
+def _fwd_kernel(
+    Q,
+    K,
+    V,
+    sm_scale,
+    Out,
+    position_ids,  # 1D, concatenated by batch order, length = sum(B_Seqlen)
+    B_Pos_Start,  # [batch], prefix sum of B_Seqlen (int32)
+    B_Start_Loc,
+    B_Seqlen,
+    Req_to_tokens,
+    B_req_idx,
+    stride_qbs,
+    stride_qh,
+    stride_qd,
+    stride_kbs,
+    stride_kh,
+    stride_kd,
+    stride_vbs,
+    stride_vh,
+    stride_vd,
+    stride_obs,
+    stride_oh,
+    stride_od,
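+    # remaining kernel args: the Req_to_tokens strides map logical token positions
+    # to kv-cache slots, and kv_group_num (= Hq // Hk in the wrapper below) lets a
+    # group of query heads share one kv head (GQA)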
stride_req_to_tokens_b, + stride_req_to_tokens_s, + kv_group_num, + b_prompt_cache_len, + H: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + cur_bh = tl.program_id(1) + cur_batch = cur_bh // H + cur_head = cur_bh % H + + cur_kv_head = cur_head // kv_group_num + + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + total_len = tl.load(B_Seqlen + cur_batch) + cur_batch_seq_len = total_len - prompt_cache_len + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + # where this request starts inside 1D position_ids + pos_base = tl.load(B_Pos_Start + cur_batch) + + block_start_loc = BLOCK_M * start_m + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = block_start_loc + tl.arange(0, BLOCK_M) + + # load Q for current block + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + q_valid = offs_m < cur_batch_seq_len + q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) + + # init online softmax + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + block_end_loc = total_len + + # query absolute pos inside request: [prompt_cache_len .. total_len-1] + q_pos = prompt_cache_len + offs_m + + # gid by pos (NOT by mem_index) + q_gid = tl.load( + position_ids + pos_base + q_pos, + mask=q_valid, + other=-2147483648, + ).to(tl.int32) + + # main loop over keys by logical pos + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k_pos = start_n + offs_n + k_valid = k_pos < block_end_loc + + # gid by pos (NOT by mem_index) + k_gid = tl.load( + position_ids + pos_base + k_pos, + mask=k_valid, + other=-2147483647, + ).to(tl.int32) + + # map logical k_pos -> kv cache mem_index + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_valid, + other=0, + ).to(tl.int64) + + # load K using mem_index + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) + + # qk + qk = tl.dot(q, k) + + # mask: causal OR same gid (image block full-attn) + mask = (q_pos[:, None] >= k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :]) + mask = mask & q_valid[:, None] & k_valid[None, :] + + qk = tl.where(mask, qk * sm_scale, -1.0e8) + + # online softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # load V using mem_index + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) + + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + + m_i = m_ij + + acc = acc / l_i[:, None] + + # store + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=q_valid[:, None]) + + +@torch.no_grad() +def context_attention_fwd_neo( + q, + k, + v, + o, + position_ids, # 1D concatenated for this batch + 
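+    # assumed call convention (mirrored by the NOTE below): q and o are packed by
+    # b_start_loc, while k and v live in the kv cache and are addressed through
+    # req_to_token_indexs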
b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, +): + # position_ids must cover sum of b_seq_len + # b_pos_start: prefix sum over b_seq_len, defines each request's start inside position_ids + # NOTE: assumes position_ids is concatenated in the SAME order as cur_batch = 0..batch-1 + batch = b_seq_len.shape[0] + device = b_seq_len.device + b_pos_start = torch.zeros((batch,), device=device, dtype=torch.int32) + if batch > 1: + b_pos_start[1:] = torch.cumsum(b_seq_len[:-1].to(torch.int32), dim=0) + + needed = int((b_pos_start[-1] + b_seq_len[-1]).item()) + assert position_ids.numel() >= needed, (position_ids.numel(), needed) + + BLOCK_M = 128 if not is_tesla() else 64 + + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128, 256} + + # same trick as your original: exp2 + 1/log(2) + sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634 + + head = q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + + BLOCK_N = BLOCK_M + num_warps = 4 if Lk <= 64 else 8 + num_stages = 1 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + o, + position_ids, + b_pos_start, + b_start_loc, + b_seq_len, + req_to_token_indexs, + b_req_idx, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + kv_group_num=kv_group_num, + b_prompt_cache_len=b_prompt_cache_len, + H=head, + BLOCK_DMODEL=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def reference_attention( + q, + k, + v, + position_ids, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, +): + """ + q: [sum_q, Hq, D] packed by b_start_loc + k/v: [KV_SIZE, Hk, D] by mem_index + position_ids: 1D concatenated by batch order, length = sum(b_seq_len) + """ + device = q.device + dtype = q.dtype + sum_q, Hq, D = q.shape + Hk = k.shape[1] + kv_group_num = Hq // Hk + + batch = b_seq_len.shape[0] + b_pos_start = torch.zeros((batch,), device=device, dtype=torch.int64) + if batch > 1: + b_pos_start[1:] = torch.cumsum(b_seq_len[:-1].to(torch.int64), dim=0) + + out = torch.empty_like(q) + + scale = 1.0 / math.sqrt(D) + + for b in range(batch): + req = int(b_req_idx[b].item()) + total_len = int(b_seq_len[b].item()) + prompt_len = int(b_prompt_cache_len[b].item()) + q_len = total_len - prompt_len + + q_start = int(b_start_loc[b].item()) + q_blk = q[q_start : q_start + q_len] # [M, Hq, D] + + pos_base = int(b_pos_start[b].item()) + gid = position_ids[pos_base : pos_base + total_len].to(torch.int64) # [L] + + # gather K/V for this request by logical pos -> mem_index + token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] + k_blk = k[token_locs] # [L, Hk, D] + v_blk = v[token_locs] # [L, Hk, D] + + # expand kv heads to q heads (GQA) + k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + + # build mask by pos + q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] + k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L] + allow = (k_pos[None, :] <= q_pos[:, None]) | (gid[q_pos][:, None] == gid[k_pos][None, :]) # [M, L] + + # scores: [Hq, M, L] + q_t = q_blk.permute(1, 0, 
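+            # heads to the front so each head's scores reduce to a plain matmul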
2).to(torch.float32) # [Hq, M, D] + k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] + scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] + + # mask + neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) + scores = torch.where(allow[None, :, :], scores, neg) + + # softmax + reduce + p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] + v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] + out_hq = torch.matmul(p, v_t) # [Hq, M, D] + out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] + + out[q_start : q_start + q_len] = out_blk + + return out + + +def make_test_case( + device="cuda", + dtype=torch.float16, + batch=3, + Hq=8, + Hk=4, + D=64, + seed=0, + base_index=5000, +): + torch.manual_seed(seed) + + prompt_lens = torch.randint(low=1, high=5, size=(batch,), device=device) + q_lens = torch.randint(low=2, high=8, size=(batch,), device=device) + total_lens = (prompt_lens + q_lens).to(torch.int32) + + max_total_len = int(total_lens.max().item()) + + # b_start_loc for packed q (q_len per batch) + b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) + cur = 0 + for b in range(batch): + b_start_loc[b] = cur + cur += int(q_lens[b].item()) + sum_q = cur + + b_seq_len = total_lens + b_prompt_cache_len = prompt_lens.to(torch.int32) + + # one req per batch for test + num_req = batch + b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) + + # build a global KV "mem_index" space with offset, to simulate large indices + sum_kv = int(total_lens.sum().item()) + kv_size = base_index + sum_kv + 16 + + # allocate unique mem indices + pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index + + # Req_to_tokens: [num_req, max_total_len] + req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) + p = 0 + for r in range(num_req): + L = int(total_lens[r].item()) + req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) + p += L + + # position_ids: 1D concatenated by batch order (length = sum_kv) + position_ids = torch.empty((sum_kv,), device=device, dtype=torch.int32) + off = 0 + for r in range(num_req): + L = int(total_lens[r].item()) + gid = torch.arange(L, device=device, dtype=torch.int32) + + # make 0-2 repeated blocks (simulate image tokens) + if L >= 4: + # repeat a short block + s = int(torch.randint(0, max(1, L - 2), (1,), device=device).item()) + e = min(L, s + int(torch.randint(2, min(4, L - s) + 1, (1,), device=device).item())) + gid[s:e] = gid[s] + if L >= 8 and torch.rand((), device=device).item() > 0.5: + s = 4 + e = min(L, 7) + gid[s:e] = gid[s] + + position_ids[off : off + L] = gid + off += L + + q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) + k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) + + return ( + q, + k, + v, + o, + position_ids, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_total_len, + req_to_token_indexs, + ) + + +def check_once(device="cuda", dtype=torch.float16, seed=0): + ( + q, + k, + v, + o, + position_ids, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_total_len, + req_to_token_indexs, + ) = make_test_case(device=device, dtype=dtype, seed=seed) + + # triton + context_attention_fwd_neo( + q, + k, + v, + o, + position_ids, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_total_len, + req_to_token_indexs, 
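+        # o is filled in place by the Triton kernel and compared against the
+        # pure-PyTorch reference below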
+    )
+
+    # reference
+    ref = reference_attention(
+        q,
+        k,
+        v,
+        position_ids,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        b_prompt_cache_len,
+        req_to_token_indexs,
+    )
+
+    diff = (o - ref).abs()
+    max_abs = diff.max().item()
+    denom = ref.abs().max().item() + 1e-6
+    max_rel = max_abs / denom
+
+    print(f"seed={seed}, dtype={dtype}")
+    print(f"max_abs_error = {max_abs:.6e}")
+    print(f"max_rel_error = {max_rel:.6e}")
+    print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2))
+
+
+# if __name__ == "__main__":
+#     if not torch.cuda.is_available():
+#         print("No CUDA, skip Triton check.")
+#     else:
+#         torch.cuda.synchronize()
+#         check_once(dtype=torch.float16, seed=0)
+#         check_once(dtype=torch.float16, seed=1)
+#         check_once(dtype=torch.float16, seed=2)
diff --git a/lightllm/models/neo_chat/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat/triton_kernel/get_neo_position.py
new file mode 100644
index 0000000000..5cf270a127
--- /dev/null
+++ b/lightllm/models/neo_chat/triton_kernel/get_neo_position.py
@@ -0,0 +1,174 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _get_neo_position_triton(
+    b_image_start_idx: torch.Tensor,
+    b_image_thwd: torch.Tensor,
+    b_image_thwd_stride0: torch.Tensor,
+    b_image_nums: torch.Tensor,
+    b_image_start_num: torch.Tensor,
+    b_image_len: torch.Tensor,
+    position_ids: torch.Tensor,
+    position_ids_stride0: torch.Tensor,
+    b_ready_cache_len: torch.Tensor,
+    b_q_seq_len: torch.Tensor,
+    b_start_loc: torch.Tensor,
+    BLOCK_SIZE: tl.constexpr,
+) -> torch.Tensor:
+    cur_batch = tl.program_id(0)
+    cache_len = tl.load(b_ready_cache_len + cur_batch)
+    q_seq_len = tl.load(b_q_seq_len + cur_batch)
+    image_num = tl.load(b_image_nums + cur_batch)
+    image_start_num = tl.load(b_image_start_num + cur_batch)
+    start_loc = tl.load(b_start_loc + cur_batch)
+    for i in range(image_num):
+        local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
+        image_start_idx = start_loc + local_image_start_idx - cache_len
+        image_len = tl.load(b_image_len + image_start_num + i)
+        image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
+        image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
+        for j in range(0, image_len, BLOCK_SIZE):
+            off = j + tl.arange(0, BLOCK_SIZE)
+            # video is not supported yet, so t is always 0
+            t_pos = local_image_start_idx + off * 0
+            h_pos = off // image_h
+            w_pos = off % image_w
+            tl.store(
+                position_ids + off + image_start_idx,
+                t_pos,
+                mask=(off < image_len)
+                & (off + local_image_start_idx - cache_len < q_seq_len)
+                & (local_image_start_idx - cache_len + off >= 0),
+            )
+            tl.store(
+                position_ids + position_ids_stride0 + off + image_start_idx,
+                h_pos,
+                mask=(off < image_len)
+                & (off + local_image_start_idx - cache_len < q_seq_len)
+                & (local_image_start_idx - cache_len + off >= 0),
+            )
+            tl.store(
+                position_ids + position_ids_stride0 * 2 + off + image_start_idx,
+                w_pos,
+                mask=(off < image_len)
+                & (off + local_image_start_idx - cache_len < q_seq_len)
+                & (local_image_start_idx - cache_len + off >= 0),
+            )
+
+    for i in range(image_num):
+        local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
+        image_len = tl.load(b_image_len + image_start_num + i)
+        image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3)
+        image_end = local_image_start_idx + image_len - cache_len
+        text_start = tl.maximum(0, image_end)
+        for j in range(text_start, q_seq_len, BLOCK_SIZE):
+            off = j + tl.arange(0, BLOCK_SIZE)
+            t_pos = 
tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta + h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0) + w_pos = tl.load( + position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0 + ) + tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len)) + return + + +def get_neo_position_triton( + b_image_start_idx: torch.Tensor, + b_image_thwd: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_image_len: torch.Tensor, + position_ids: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, +) -> torch.Tensor: + + batch_size = b_q_seq_len.shape[0] + assert batch_size == b_image_nums.shape[0] + grid = (batch_size,) + BLOCK_SIZE = 64 + _get_neo_position_triton[grid]( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_thwd_stride0=b_image_thwd.stride(0), + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + position_ids_stride0=position_ids.stride(0), + b_ready_cache_len=b_ready_cache_len, + b_q_seq_len=b_q_seq_len, + b_start_loc=b_start_loc, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +def test(): + b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda") + b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda") + b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda") + b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda") + position_ids = ( + torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + .unsqueeze(0) + .expand(3, -1) + .contiguous() + ) + position_ids[1:].zero_() + b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") + b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") + b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") + get_neo_position_triton( + b_image_start_idx, + b_image_thwd, + b_image_nums, + b_image_start_num, + b_image_len, + position_ids, + b_ready_cache_len, + b_q_seq_len, + b_start_loc, + ) + + print(position_ids) + # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) + + # position_ids = ( + # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + # .unsqueeze(0) + # .expand(3, -1) + # .contiguous() + # ) + # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda") + # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda") + # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda") + + # get_neo_position_triton( + # b_image_start_idx, + # b_image_thwd, + # b_image_nums, + # b_image_start_num, + # b_image_len, + # position_ids, + # b_ready_cache_len, + # b_q_seq_len, + # b_start_loc, + # ) + + # print(f"old_value:\n{old_value}") + # print(f"position_ids:\n{position_ids}") + # assert torch.equal(old_value, position_ids) + + """ + tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8], + [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 
8],
+            [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]],
+           device='cuda:0', dtype=torch.int32)
+    """

From e8e74168c24b2c34a117561c5bec245e930aaa52 Mon Sep 17 00:00:00 2001
From: sangchengmeng
Date: Wed, 31 Dec 2025 05:12:28 +0000
Subject: [PATCH 016/180] add-neo-chat

---
 .../neo_chat/{infer_state.py => infer_struct.py}     | 10 +++++++---
 .../neo_chat/layer_infer/transformer_layer_infer.py  |  2 +-
 lightllm/models/neo_chat/model.py                    |  2 +-
 3 files changed, 9 insertions(+), 5 deletions(-)
 rename lightllm/models/neo_chat/{infer_state.py => infer_struct.py} (91%)

diff --git a/lightllm/models/neo_chat/infer_state.py b/lightllm/models/neo_chat/infer_struct.py
similarity index 91%
rename from lightllm/models/neo_chat/infer_state.py
rename to lightllm/models/neo_chat/infer_struct.py
index 9a71c3ddb9..8e5347e8be 100644
--- a/lightllm/models/neo_chat/infer_state.py
+++ b/lightllm/models/neo_chat/infer_struct.py
@@ -29,7 +29,7 @@ def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor
                 position_delta += image["grid_thwd"][3]
             b_position_delta[batch_idx] = position_delta
             position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device)
-            self.position_ids = position_ids.unsqueeze(0).expand(3, -1)
+            self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone()
             self.position_ids[1:].zero_()
 
         self.position_ids = self.position_ids.contiguous()
@@ -43,7 +43,9 @@ def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor
 
     def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor:
         if len(multimodal_params) == 0:
-            return self.position_ids.unsqueeze(0).expand(3, -1)
+            position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+            position_ids[0].copy_(self.position_ids)
+            return position_ids
         b_image_start_idx = []
         b_image_nums = []
         b_image_start_num = []
@@ -71,7 +73,9 @@ def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor
 
         # no images at all
         if image_start_num == 0:
-            return self.position_ids.unsqueeze(0).expand(3, -1).contiguous()
+            position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+            position_ids[0].copy_(self.position_ids)
+            return position_ids.contiguous()
         b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True)
         b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True)  # image_num x 4
         b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True)
         b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True)
         b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True)
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
index b0ee428563..e6b0402bbb 100644
--- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -3,7 +3,7 @@
 from typing import Tuple
 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
-from lightllm.models.neo_chat.infer_state import NeoChatInferStateInfo
+from lightllm.models.neo_chat.infer_struct import NeoChatInferStateInfo
 from lightllm.models.neo_chat.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo
 from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd
 from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd
diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py
index edc7349864..0cc469cead 
100644 --- a/lightllm/models/neo_chat/model.py +++ b/lightllm/models/neo_chat/model.py @@ -19,7 +19,7 @@ from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer -from lightllm.models.neo_chat.infer_state import NeoChatInferStateInfo +from lightllm.models.neo_chat.infer_struct import NeoChatInferStateInfo IMG_START_TOKEN = "" IMG_END_TOKEN = "" From ba4498317f427f316c15da8a6301e72422d94136 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 31 Dec 2025 05:28:39 +0000 Subject: [PATCH 017/180] add-neo-chat --- .../context_attention_fwd_neo.py | 217 ++++++++---------- 1 file changed, 101 insertions(+), 116 deletions(-) diff --git a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py index 46376502f1..80fc2ea44e 100644 --- a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py +++ b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py @@ -1,6 +1,3 @@ -# context_attention_fwd_neo_pos1d.py -# From : https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html - import math import torch import triton @@ -16,8 +13,7 @@ def _fwd_kernel( V, sm_scale, Out, - position_ids, # 1D, concatenated by batch order, length = sum(B_Seqlen) - B_Pos_Start, # [batch], prefix sum of B_Seqlen (int32) + position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0] B_Start_Loc, B_Seqlen, Req_to_tokens, @@ -53,28 +49,26 @@ def _fwd_kernel( cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) total_len = tl.load(B_Seqlen + cur_batch) - cur_batch_seq_len = total_len - prompt_cache_len + cur_batch_seq_len = total_len - prompt_cache_len # NEW len cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - # where this request starts inside 1D position_ids - pos_base = tl.load(B_Pos_Start + cur_batch) - block_start_loc = BLOCK_M * start_m offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = block_start_loc + tl.arange(0, BLOCK_M) - # load Q for current block + # Q pointers off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd ) + q_valid = offs_m < cur_batch_seq_len q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) - # init online softmax + # online softmax state m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -82,44 +76,55 @@ def _fwd_kernel( block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = total_len - # query absolute pos inside request: [prompt_cache_len .. 
total_len-1] - q_pos = prompt_cache_len + offs_m + # absolute q positions in the request + q_pos = prompt_cache_len + offs_m # [M] - # gid by pos (NOT by mem_index) + # q_gid from packed position_ids (aligned with Q rows) q_gid = tl.load( - position_ids + pos_base + q_pos, + position_ids + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=-2147483648, ).to(tl.int32) - # main loop over keys by logical pos + BIG = tl.full([BLOCK_N], 1000000000, tl.int32) # ensure != any normal gid + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - k_pos = start_n + offs_n - k_valid = k_pos < block_end_loc - # gid by pos (NOT by mem_index) - k_gid = tl.load( - position_ids + pos_base + k_pos, - mask=k_valid, - other=-2147483647, - ).to(tl.int32) + k_pos = start_n + offs_n # [N] + k_valid = k_pos < block_end_loc - # map logical k_pos -> kv cache mem_index + # map logical pos -> mem_index (for K/V) kv_loc = tl.load( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, mask=k_valid, other=0, ).to(tl.int64) - # load K using mem_index + # load K off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) - # qk qk = tl.dot(q, k) - # mask: causal OR same gid (image block full-attn) + # k_gid: + # - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false + # - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len) + k_in_new = k_pos >= prompt_cache_len + k_new_idx = (k_pos - prompt_cache_len).to(tl.int32) # [N] valid only when k_in_new + k_gid_new = tl.load( + position_ids + cur_batch_in_all_start_index + k_new_idx, + mask=k_valid & k_in_new, + other=-2147483647, + ).to(tl.int32) + + k_gid = tl.where( + k_in_new, + k_gid_new, + (k_pos.to(tl.int32) + BIG), + ) + + # mask: causal OR same gid (only possible inside NEW part) mask = (q_pos[:, None] >= k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :]) mask = mask & q_valid[:, None] & k_valid[None, :] @@ -127,7 +132,7 @@ def _fwd_kernel( # online softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] + qk -= m_ij[:, None] p = tl.math.exp2(qk) l_ij = tl.sum(p, 1) @@ -135,7 +140,7 @@ def _fwd_kernel( l_i = l_i * alpha + l_ij acc = acc * alpha[:, None] - # load V using mem_index + # load V off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) @@ -146,14 +151,12 @@ def _fwd_kernel( acc = acc / l_i[:, None] - # store off_o = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=q_valid[:, None]) + tl.store(Out + off_o, acc, mask=q_valid[:, None]) @torch.no_grad() @@ -162,7 +165,7 @@ def context_attention_fwd_neo( k, v, o, - position_ids, # 1D concatenated for this batch + position_ids, # 1D packed like q (only NEW tokens) b_req_idx, b_start_loc, b_seq_len, @@ -170,17 +173,8 @@ def context_attention_fwd_neo( max_input_len, req_to_token_indexs, ): - # position_ids must cover sum of b_seq_len - # b_pos_start: prefix sum over b_seq_len, defines each request's start inside position_ids - # NOTE: assumes position_ids is concatenated in the SAME order as cur_batch = 0..batch-1 - batch = b_seq_len.shape[0] - device = b_seq_len.device - b_pos_start = torch.zeros((batch,), 
device=device, dtype=torch.int32) - if batch > 1: - b_pos_start[1:] = torch.cumsum(b_seq_len[:-1].to(torch.int32), dim=0) - - needed = int((b_pos_start[-1] + b_seq_len[-1]).item()) - assert position_ids.numel() >= needed, (position_ids.numel(), needed) + # minimal safety: position_ids must cover packed q rows + assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) BLOCK_M = 128 if not is_tesla() else 64 @@ -188,10 +182,9 @@ def context_attention_fwd_neo( assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128, 256} - # same trick as your original: exp2 + 1/log(2) sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634 - head = q.shape[1] + batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) @@ -207,7 +200,6 @@ def context_attention_fwd_neo( sm_scale, o, position_ids, - b_pos_start, b_start_loc, b_seq_len, req_to_token_indexs, @@ -241,18 +233,13 @@ def reference_attention( q, k, v, - position_ids, + position_ids_q, # 1D packed like q (only NEW tokens) b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, req_to_token_indexs, ): - """ - q: [sum_q, Hq, D] packed by b_start_loc - k/v: [KV_SIZE, Hk, D] by mem_index - position_ids: 1D concatenated by batch order, length = sum(b_seq_len) - """ device = q.device dtype = q.dtype sum_q, Hq, D = q.shape @@ -260,27 +247,20 @@ def reference_attention( kv_group_num = Hq // Hk batch = b_seq_len.shape[0] - b_pos_start = torch.zeros((batch,), device=device, dtype=torch.int64) - if batch > 1: - b_pos_start[1:] = torch.cumsum(b_seq_len[:-1].to(torch.int64), dim=0) - out = torch.empty_like(q) - scale = 1.0 / math.sqrt(D) for b in range(batch): req = int(b_req_idx[b].item()) total_len = int(b_seq_len[b].item()) prompt_len = int(b_prompt_cache_len[b].item()) - q_len = total_len - prompt_len + new_len = total_len - prompt_len q_start = int(b_start_loc[b].item()) - q_blk = q[q_start : q_start + q_len] # [M, Hq, D] - - pos_base = int(b_pos_start[b].item()) - gid = position_ids[pos_base : pos_base + total_len].to(torch.int64) # [L] + q_blk = q[q_start : q_start + new_len] # [M, Hq, D] + gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] - # gather K/V for this request by logical pos -> mem_index + # gather K/V for full request by logical pos -> mem_index token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] k_blk = k[token_locs] # [L, Hk, D] v_blk = v[token_locs] # [L, Hk, D] @@ -289,27 +269,39 @@ def reference_attention( k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] - # build mask by pos + # positions q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L] - allow = (k_pos[None, :] <= q_pos[:, None]) | (gid[q_pos][:, None] == gid[k_pos][None, :]) # [M, L] + + # build allow mask: + # causal always + allow = k_pos[None, :] <= q_pos[:, None] + + # full-attn only inside NEW part by gid + # compare only when k_pos in NEW + k_in_new = k_pos >= prompt_len + k_rel = (k_pos - prompt_len).clamp_min(0) # [L] + # map k_rel to gid_new, but only valid where k_in_new + k_gid = torch.empty((total_len,), device=device, dtype=torch.int64) + k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new + k_gid[k_in_new] = gid_new[k_rel[k_in_new]] + + allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) # 
scores: [Hq, M, L] q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] - # mask neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) scores = torch.where(allow[None, :, :], scores, neg) - # softmax + reduce p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] out_hq = torch.matmul(p, v_t) # [Hq, M, D] out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] - out[q_start : q_start + q_len] = out_blk + out[q_start : q_start + new_len] = out_blk return out @@ -322,39 +314,39 @@ def make_test_case( Hk=4, D=64, seed=0, - base_index=5000, + base_index=50000, ): torch.manual_seed(seed) - prompt_lens = torch.randint(low=1, high=5, size=(batch,), device=device) - q_lens = torch.randint(low=2, high=8, size=(batch,), device=device) - total_lens = (prompt_lens + q_lens).to(torch.int32) + # prompt (cached) len and new len + prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device) + new_lens = torch.randint(low=1, high=8, size=(batch,), device=device) + total_lens = (prompt_lens + new_lens).to(torch.int32) max_total_len = int(total_lens.max().item()) + max_new_len = int(new_lens.max().item()) - # b_start_loc for packed q (q_len per batch) + # packed q start b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) cur = 0 for b in range(batch): b_start_loc[b] = cur - cur += int(q_lens[b].item()) + cur += int(new_lens[b].item()) sum_q = cur b_seq_len = total_lens b_prompt_cache_len = prompt_lens.to(torch.int32) - # one req per batch for test + # one req per batch num_req = batch b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) - # build a global KV "mem_index" space with offset, to simulate large indices + # global KV space large, indices not small sum_kv = int(total_lens.sum().item()) - kv_size = base_index + sum_kv + 16 - - # allocate unique mem indices + kv_size = base_index + sum_kv + 1024 pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index - # Req_to_tokens: [num_req, max_total_len] + # Req_to_tokens [num_req, max_total_len] req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) p = 0 for r in range(num_req): @@ -362,26 +354,21 @@ def make_test_case( req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) p += L - # position_ids: 1D concatenated by batch order (length = sum_kv) - position_ids = torch.empty((sum_kv,), device=device, dtype=torch.int32) - off = 0 - for r in range(num_req): - L = int(total_lens[r].item()) - gid = torch.arange(L, device=device, dtype=torch.int32) + # position_ids_q: only NEW tokens, packed like q + position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) + for b in range(batch): + M = int(new_lens[b].item()) + start = int(b_start_loc[b].item()) - # make 0-2 repeated blocks (simulate image tokens) - if L >= 4: - # repeat a short block - s = int(torch.randint(0, max(1, L - 2), (1,), device=device).item()) - e = min(L, s + int(torch.randint(2, min(4, L - s) + 1, (1,), device=device).item())) - gid[s:e] = gid[s] - if L >= 8 and torch.rand((), device=device).item() > 0.5: - s = 4 - e = min(L, 7) + gid = torch.arange(M, device=device, dtype=torch.int32) + + # make one repeated block inside NEW part to simulate image tokens + if M >= 4 and torch.rand((), device=device).item() > 0.3: + s = int(torch.randint(0, M - 2, (1,), 
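+            # random start of the repeated-gid span (stands in for an image block)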
device=device).item()) + e = min(M, s + 3) gid[s:e] = gid[s] - position_ids[off : off + L] = gid - off += L + position_ids_q[start : start + M] = gid q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) @@ -393,12 +380,12 @@ def make_test_case( k, v, o, - position_ids, + position_ids_q, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, - max_total_len, + max_new_len, req_to_token_indexs, ) @@ -409,36 +396,34 @@ def check_once(device="cuda", dtype=torch.float16, seed=0): k, v, o, - position_ids, + position_ids_q, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, - max_total_len, + max_new_len, req_to_token_indexs, ) = make_test_case(device=device, dtype=dtype, seed=seed) - # triton context_attention_fwd_neo( q, k, v, o, - position_ids, + position_ids_q, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, - max_total_len, + max_new_len, req_to_token_indexs, ) - # reference ref = reference_attention( q, k, v, - position_ids, + position_ids_q, b_req_idx, b_start_loc, b_seq_len, @@ -457,11 +442,11 @@ def check_once(device="cuda", dtype=torch.float16, seed=0): print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2)) -# if __name__ == "__main__": -# if not torch.cuda.is_available(): -# print("No CUDA, skip Triton check.") -# else: -# torch.cuda.synchronize() -# check_once(dtype=torch.float16, seed=0) -# check_once(dtype=torch.float16, seed=1) -# check_once(dtype=torch.float16, seed=2) +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA, skip.") + else: + torch.cuda.synchronize() + check_once(dtype=torch.bfloat16, seed=0) + check_once(dtype=torch.bfloat16, seed=1) + check_once(dtype=torch.bfloat16, seed=2) From 4d41a33fcd6d3447661486b6c79c12f268410a5f Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 31 Dec 2025 10:27:23 +0000 Subject: [PATCH 018/180] add-neo-chat --- .../token_attention_nopad_att1.py | 3 +- .../layer_infer/transformer_layer_infer.py | 120 ++++++++---------- .../context_attention_fwd_neo.py | 4 +- 3 files changed, 57 insertions(+), 70 deletions(-) diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py index eb5af6fecd..02bd277adb 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py @@ -74,7 +74,8 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk assert Lk in {16, 32, 64, 128, 256} - sm_scale = 1.0 / (Lk ** 0.5) + Lk_scale = Lk // 2 + sm_scale = 1.0 / (Lk_scale ** 0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index e6b0402bbb..b0105131f3 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -26,36 +26,28 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal return - def _get_qkv( - self, - input: torch.Tensor, - infer_state: NeoChatInferStateInfo, - layer_weight: NeoChatMOETransformerLayerWeight, - ) -> Tuple[torch.Tensor, torch.Tensor]: + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): input = input.view(-1, self.embed_dim_) - q = 
layer_weight.q_proj.mm(input) + q = layer_weight.q_proj.mm(input) # [T, Hq*D] - q_h, q_w = layer_weight.q_hw_proj.mm(input).chunk(2, dim=-1) - k_h, k_w = layer_weight.k_hw_proj.mm(input).chunk(2, dim=-1) + q_hw = layer_weight.q_hw_proj.mm(input) + q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) + q_h, q_w = q_hw.chunk(2, dim=-1) - cache_kv = layer_weight.kv_proj.mm(input) - qk_rmsnorm_forward( - q, - weight=layer_weight.q_norm_weight_.weight, - eps=self.eps_, - ) + k_hw = layer_weight.k_hw_proj.mm(input) + k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) + k_h, k_w = k_hw.chunk(2, dim=-1) - qk_rmsnorm_forward( - q_h, - weight=layer_weight.q_norm_h_weight_.weight, - eps=self.eps_, - ) + cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - qk_rmsnorm_forward( - q_w, - weight=layer_weight.q_norm_w_weight_.weight, - eps=self.eps_, - ) + qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + + q_h_2d = q_h.reshape(q.shape[0], -1) + q_w_2d = q_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_) + q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) qk_rmsnorm_forward( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -63,17 +55,15 @@ def _get_qkv( eps=self.eps_, ) - qk_rmsnorm_forward( - k_h, - weight=layer_weight.k_norm_h_weight_.weight, - eps=self.eps_, - ) - qk_rmsnorm_forward( - k_w, - weight=layer_weight.k_norm_w_weight_.weight, - eps=self.eps_, - ) + k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] + k_w_2d = k_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, : self.tp_k_head_num_, :], @@ -81,33 +71,29 @@ def _get_qkv( infer_state.position_sin, ) rotary_emb_fwd( - q_h.view(-1, self.tp_q_head_num_, self.head_dim_ // 2), - k_h.view(-1, self.tp_k_head_num_, self.head_dim_ // 2), + q_h, + k_h, infer_state.position_cos_h, infer_state.position_sin_h, ) rotary_emb_fwd( - q_w.view(-1, self.tp_q_head_num_, self.head_dim_ // 2), - k_w.view(-1, self.tp_k_head_num_, self.head_dim_ // 2), + q_w, + k_w, infer_state.position_cos_w, infer_state.position_sin_w, ) - # 拼接q q_h q_w - q = torch.cat([q, q_h, q_w], dim=-1) - # 拼接k k_h k_w - seq_len = cache_kv.shape[0] - k_h = k_h.reshape(seq_len, self.tp_k_head_num_, self.head_dim_ // 2) - k_w = k_w.reshape(seq_len, self.tp_k_head_num_, self.head_dim_ // 2) + + q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) + q3 = torch.cat([q3, q_h, q_w], dim=-1) + q = q3.reshape(q3.shape[0], -1) + k = cache_kv[:, : self.tp_k_head_num_, :] k = torch.cat([k, k_h, k_w], dim=-1) - # 对齐V的shape + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] - v_pad = torch.zeros( - (seq_len, self.tp_v_head_num_, self.head_dim_), - device=v.device, - dtype=v.dtype, - ) + v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) v = torch.cat([v, v_pad], 
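+            # V is zero-padded to 2 * head_dim so K and V share one kv-cache slot
+            # layout; only the first head_dim lanes of V carry real values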
dim=-1) + cache_kv = torch.cat([k, v], dim=1) return q, cache_kv @@ -121,7 +107,7 @@ def _context_attention_kernel( kv[:, 0 : self.tp_k_head_num_, :], kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), - infer_state.position_ids[0], + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] infer_state.b_req_idx, infer_state.b_start_loc, infer_state.b_seq_len, @@ -136,13 +122,15 @@ def _context_attention_kernel( def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_ * 2) + + q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) + k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] token_att_fwd( - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + q_3d, + k_3d, att_m_tensor, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, @@ -150,24 +138,22 @@ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, infer_state.b_seq_len, infer_state.max_len_in_batch, ) - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) - o_tensor = o_tensor[:, :, : self.head_dim_].contiguous() - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd + + v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ + ] + + o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out - calcu_shape2 = (batch_size, self.tp_q_head_num_, self.head_dim_) token_softmax_reducev_fwd( att_m_tensor, - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - o_tensor.view(calcu_shape2), + v_3d, + o_3d, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_start_loc, infer_state.b_seq_len, ) - return o_tensor + return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py index 80fc2ea44e..f5dae493cb 100644 --- a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py +++ b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py @@ -181,8 +181,8 @@ def context_attention_fwd_neo( Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128, 256} - - sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634 + base_head_dim = Lq // 2 + sm_scale = 1.0 / (base_head_dim ** 0.5) * 1.4426950408889634 batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] From 0e8845c160ce2563815eb0dfc5d851d209b5dbd7 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 1 Jan 2026 16:38:33 +0000 Subject: [PATCH 019/180] fix-neo-chat --- lightllm/models/neo_chat/neo_visual.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git 
a/lightllm/models/neo_chat/neo_visual.py b/lightllm/models/neo_chat/neo_visual.py
index c9d4b81617..16b30511e2 100644
--- a/lightllm/models/neo_chat/neo_visual.py
+++ b/lightllm/models/neo_chat/neo_visual.py
@@ -247,7 +247,13 @@ def encode(self, images: List[ImageItem]):
                 uuids.append(img.uuid)
                 image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data))
-                pixel_values, image_grid_hw = load_image_native(image_data)
+                pixel_values, image_grid_hw = load_image_native(
+                    image_data,
+                    patch_size=self.patch_size,
+                    downsample_ratio=self.downsample_ratio,
+                    min_pixels=self.min_pixels,
+                    max_pixels=self.max_pixels,
+                )
                 img_tensors.append(pixel_values)
                 img_grids.append(image_grid_hw)
             else:

From b48cd499e1b661e55ca8f32c5f6f0164e1da7045 Mon Sep 17 00:00:00 2001
From: sangchengmeng
Date: Mon, 5 Jan 2026 05:00:10 +0000
Subject: [PATCH 020/180] fix-neo-chat-position-ids-h

---
 lightllm/models/neo_chat/triton_kernel/get_neo_position.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lightllm/models/neo_chat/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat/triton_kernel/get_neo_position.py
index 5cf270a127..955f48bd80 100644
--- a/lightllm/models/neo_chat/triton_kernel/get_neo_position.py
+++ b/lightllm/models/neo_chat/triton_kernel/get_neo_position.py
@@ -28,13 +28,13 @@ def _get_neo_position_triton(
         local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
         image_start_idx = start_loc + local_image_start_idx - cache_len
         image_len = tl.load(b_image_len + image_start_num + i)
-        image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
+        # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
         image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
         for j in range(0, image_len, BLOCK_SIZE):
             off = j + tl.arange(0, BLOCK_SIZE)
             # video is not supported yet, so t is always 0
             t_pos = local_image_start_idx + off * 0
-            h_pos = off // image_h
+            h_pos = off // image_w
             w_pos = off % image_w
             tl.store(

From 7a904f39d054ad3271683886f8fd0bece0cce665 Mon Sep 17 00:00:00 2001
From: sangchengmeng
Date: Tue, 6 Jan 2026 08:14:56 +0000
Subject: [PATCH 021/180] add-neo-chat-dense

---
 lightllm/models/__init__.py                          |   1 +
 lightllm/models/neo_chat/infer_struct.py             |  99 ----
 .../layer_infer/transformer_layer_infer.py           |  12 +-
 .../pre_and_post_layer_weight.py                     |   2 +-
 .../layer_weights/transformer_layer_weight.py        |   4 +-
 lightllm/models/neo_chat/model.py                    | 108 +----
 lightllm/models/neo_chat/neo_visual.py               | 279 -----------
 .../models/neo_chat/triton_kernel/__init__.py        |   0
 .../context_attention_fwd_neo.py                     | 452 ------------------
 .../triton_kernel/get_neo_position.py                | 174 -------
 lightllm/models/neo_chat/vision_process.py           | 141 ------
 lightllm/server/tokenizer.py                         |   2 +-
 .../visualserver/model_infer/model_rpc.py            |   2 +-
 13 files changed, 23 insertions(+), 1253 deletions(-)
 delete mode 100644 lightllm/models/neo_chat/infer_struct.py
 delete mode 100644 lightllm/models/neo_chat/neo_visual.py
 delete mode 100644 lightllm/models/neo_chat/triton_kernel/__init__.py
 delete mode 100644 lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py
 delete mode 100644 lightllm/models/neo_chat/triton_kernel/get_neo_position.py
 delete mode 100644 lightllm/models/neo_chat/vision_process.py

diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index 5618dfd0cd..9a51d9512f 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -38,5 +38,6 @@ 
Tarsier2LlamaTpPartModel, ) from lightllm.models.gpt_oss.model import GptOssTpPartModel +from lightllm.models.neo_chat_moe.model import NeoTpMOEPartModel from lightllm.models.neo_chat.model import NeoTpPartModel from .registry import get_model, get_model_class diff --git a/lightllm/models/neo_chat/infer_struct.py b/lightllm/models/neo_chat/infer_struct.py deleted file mode 100644 index 8e5347e8be..0000000000 --- a/lightllm/models/neo_chat/infer_struct.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Optional, List -import torch -import numpy as np -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager -from lightllm.models.neo_chat.triton_kernel.get_neo_position import get_neo_position_triton -from lightllm.models.llama.model import LlamaTpPartModel - - -class NeoChatInferStateInfo(LlamaInferStateInfo): - def __init__(self): - super().__init__() - self.position_cos = None - self.position_sin = None - self.position_cos_h = None - self.position_sin_h = None - self.position_cos_w = None - self.position_sin_w = None - - def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor): - LlamaInferStateInfo.init_some_extra_state(self, model, input_ids) - if self.is_prefill: - self.position_ids = self.get_neo_position(self.multimodal_params) - else: - b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] - for batch_idx, p in enumerate(self.multimodal_params): - position_delta = 0 - for image in p["images"]: - position_delta += image["grid_thwd"][3] - b_position_delta[batch_idx] = position_delta - position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) - self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone() - self.position_ids[1:].zero_() - - self.position_ids = self.position_ids.contiguous() - self.position_cos = model._cos_cached[self.position_ids[0]] - self.position_sin = model._sin_cached[self.position_ids[0]] - self.position_cos_h = model._hw_cos_cached[self.position_ids[1]] - self.position_sin_h = model._hw_sin_cached[self.position_ids[1]] - self.position_cos_w = model._hw_cos_cached[self.position_ids[2]] - self.position_sin_w = model._hw_sin_cached[self.position_ids[2]] - return - - def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: - if len(multimodal_params) == 0: - position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) - position_ids[0].copy_(self.position_ids) - return position_ids - b_image_start_idx = [] - b_image_nums = [] - b_image_start_num = [] - b_image_len = [] - image_start_num = 0 - b_image_thwd = [] - - # pad multimodal_params to batch size. 
- batch_size = self.b_q_seq_len.shape[0] - multimodal_params = multimodal_params + [ - {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params)) - ] - - for _, p in enumerate(multimodal_params): - images = p.get("images", []) - for img in images: - b_image_start_idx.append(img["start_idx"]) - a = img["start_idx"] - print(f"img start_idx: {a}") - b_image_len.append(img["token_num"]) - b_image_thwd.append(img["grid_thwd"]) - b_image_nums.append(len(images)) - b_image_start_num.append(image_start_num) - image_start_num += len(images) - - # 没有任何图片 - if image_start_num == 0: - position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) - position_ids[0].copy_(self.position_ids) - return position_ids.contiguous() - b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True) - b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4 - b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True) - b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True) - b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True) - - position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) - position_ids[0].copy_(self.position_ids) - - get_neo_position_triton( - b_image_start_idx=b_image_start_idx, - b_image_thwd=b_image_thwd, - b_image_nums=b_image_nums, - b_image_start_num=b_image_start_num, - b_image_len=b_image_len, - position_ids=position_ids, - b_ready_cache_len=self.b_ready_cache_len, - b_q_seq_len=self.b_q_seq_len, - b_start_loc=self.b_start_loc, - ) - return position_ids diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index b0105131f3..1cf13c4130 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -3,19 +3,19 @@ from typing import Tuple from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.neo_chat.infer_struct import NeoChatInferStateInfo -from lightllm.models.neo_chat.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo +from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd -from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer -from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight from lightllm.distributed import all_reduce import torch.distributed as dist from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward -class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): +class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer): def 
__init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return @@ -26,7 +26,7 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal return - def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight): input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) # [T, Hq*D] diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py index 7766a5d29f..c1f0638ac4 100644 --- a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py @@ -12,7 +12,7 @@ def rename_weight_keys(weights): weights[k.replace(prefix, "")] = weights.pop(k) -class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): +class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py index bc38f1adcb..e5e769a769 100644 --- a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py @@ -1,11 +1,11 @@ -from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( NormWeight, ROWMMWeight, ) -class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): +class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): super().__init__(layer_num, data_type, network_config, mode, quant_cfg) return diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py index 0cc469cead..14d9f96dc7 100644 --- a/lightllm/models/neo_chat/model.py +++ b/lightllm/models/neo_chat/model.py @@ -1,7 +1,7 @@ import os import json from lightllm.common.build_utils import repair_config -from lightllm.models.registry import ModelRegistry +from lightllm.models.registry import ModelRegistry, llm_model_type_is from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer @@ -11,111 +11,25 @@ from lightllm.server.core.objs import SamplingParams from lightllm.models.qwen3_moe.model import Qwen3MOEModel from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem -from lightllm.models.neo_chat.vision_process import smart_resize +from lightllm.models.neo_chat_moe.vision_process import smart_resize from lightllm.models.internvl.model import InternvlTokenizer from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer -from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer +from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import 
NeoChatTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight -from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight +from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer -from lightllm.models.neo_chat.infer_struct import NeoChatInferStateInfo +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo -IMG_START_TOKEN = "" -IMG_END_TOKEN = "" -IMG_TOKEN = "" -AUDIO_START_TOKEN = "" - -class NeoChatTokenizer(BaseMultiModalTokenizer): - def __init__(self, tokenizer, model_cfg, **kwargs): - super().__init__(tokenizer) - self.tokenizer = tokenizer - self.min_pixel = model_cfg.get("vision_config").get("min_pixels") - self.max_pixel = model_cfg.get("vision_config").get("max_pixels") - self.patch_size = model_cfg.get("vision_config").get("patch_size") - self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio") - - self.image_token_id = model_cfg.get("image_token_id") - self.image_start_tag = IMG_START_TOKEN - self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag) - self.image_end_tag = IMG_END_TOKEN - self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag) - - def init_imageitem_extral_params( - self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams - ): - return - - def init_audioitem_extral_params( - self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams - ): - raise NotImplementedError - - def get_audio_token_length(self, audio: AudioItem): - raise NotImplementedError - - def get_image_token_length(self, img: ImageItem): - width, height = img.image_w, img.image_h - resized_height, resized_width = smart_resize( - height=height, - width=width, - factor=int(self.patch_size // self.downsample_ratio), - min_pixels=self.min_pixel, - max_pixels=self.max_pixel, - ) - grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size - token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2)) - # 这里的grid_h和grid_w需要* self.downsample_ratio么?再仔细看下代码 - img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num) - return token_num - - # only change the impl of the encode func: - def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): - # TEXTTEXTTEXT --> TEXTTEXTTEXT - image_tokens = IMG_START_TOKEN + IMG_END_TOKEN - if multimodal_params is None: - add_special_tokens = kwargs.get("add_special_tokens", True) - return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) - image_count = len(multimodal_params.images) - prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) - - origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) - # --> id,id+1...id+num - input_ids = [] - image_id = 0 - start_idx = 0 - while True: - try: - start_idx = origin_ids.index(self.image_start_id) - if start_idx + 1 >= len(origin_ids): - break - if origin_ids[start_idx + 1] == self.image_end_id: - input_ids.extend(origin_ids[: start_idx + 1]) - token_id = multimodal_params.images[image_id].token_id - 
token_num = multimodal_params.images[image_id].token_num - multimodal_params.images[image_id].start_idx = len(input_ids) - input_ids.extend(range(token_id, token_id + token_num)) - input_ids.append(self.image_end_id) - origin_ids = origin_ids[start_idx + 2 :] - image_id += 1 - else: - raise ValueError("image token error") - except ValueError: - break - input_ids.extend(origin_ids) - return input_ids - - -@ModelRegistry(["neo_chat"], is_multimodal=True) -class NeoTpPartModel(Qwen3MOEModel): +@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3")) +class NeoTpPartModel(Qwen3TpPartModel): pre_layer_infer_class = LlamaMultimodalPreLayerInfer - transformer_layer_infer_class = NeoChatMOETransformerLayerInfer + transformer_layer_infer_class = NeoChatTransformerLayerInfer - pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight - transformer_weight_class = NeoChatMOETransformerLayerWeight + pre_and_post_weight_class = NeoChatPreAndPostLayerWeight + transformer_weight_class = NeoChatTransformerLayerWeight infer_state_class = NeoChatInferStateInfo diff --git a/lightllm/models/neo_chat/neo_visual.py b/lightllm/models/neo_chat/neo_visual.py deleted file mode 100644 index 16b30511e2..0000000000 --- a/lightllm/models/neo_chat/neo_visual.py +++ /dev/null @@ -1,279 +0,0 @@ -import os -import torch -import torch.nn.functional as F -from PIL import Image -from typing import List -from io import BytesIO -import torch.nn as nn -from transformers.activations import ACT2FN -from safetensors import safe_open -from lightllm.server.multimodal_params import ImageItem -from transformers.modeling_outputs import BaseModelOutputWithPooling -from transformers.modeling_utils import PreTrainedModel -from lightllm.models.neo_chat.vision_process import load_image_native -from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data - - -def apply_rotary_emb_1d( - x: torch.Tensor, - cos_cached: torch.Tensor, - sin_cached: torch.Tensor, - positions: torch.Tensor, -): - """对输入张量的一部分应用1D RoPE。""" - # x: (..., seq_len, dim_part) - # positions: (..., seq_len) - # cos_cached: (max_pos, dim_part / 2) - cos_cached = cos_cached.to(device=positions.device) - sin_cached = sin_cached.to(device=positions.device) - - cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2) - sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2) - - x1 = x[..., 0::2] - x2 = x[..., 1::2] - - rotated_x1 = x1 * cos - x2 * sin - rotated_x2 = x1 * sin + x2 * cos - - x_rotated = torch.empty_like(x) - x_rotated[..., 0::2] = rotated_x1 - x_rotated[..., 1::2] = rotated_x2 - return x_rotated - - -def apply_2d_rotary_pos_emb( - x: torch.Tensor, - cos_cached_x: torch.Tensor, - sin_cached_x: torch.Tensor, - cos_cached_y: torch.Tensor, - sin_cached_y: torch.Tensor, - abs_positions_x: torch.Tensor, - abs_positions_y: torch.Tensor, -): - """应用2D RoPE到输入张量x。""" - dim = x.shape[-1] - dim_half = dim // 2 - - # 假设我们将embedding的前半部分用于一个方向的RoPE,后半部分用于另一个方向 - # 例如,前一半给X坐标,后一半给Y坐标 (或者反过来,但要保持一致) - x_part_1 = x[..., :dim_half] - x_part_2 = x[..., dim_half:] - - # 将与 abs_positions_x 相关的旋转应用于 x_part_1 - rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x) - # 将与 abs_positions_y 相关的旋转应用于 x_part_2 - rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y) - - # 将它们重新拼接起来。确保顺序与你分割时一致。 - return torch.cat((rotated_part_1, rotated_part_2), dim=-1) - - -def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): - """ - Compute 
patch coordinates (x, y) - - Args: - grid_hw: (B, 2) tensor representing (H, W) per image - """ - device = grid_hw.device - B = grid_hw.shape[0] - - # Get the number of patches per image - H = grid_hw[:, 0] - W = grid_hw[:, 1] - N = H * W - N_total = N.sum() - - # Create the batch index for each patch (B x patch count) - patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,) - - # Generate intra-image patch index (row-major order) - patch_id_within_image = torch.arange(N_total, device=device) - patch_id_within_image = ( - patch_id_within_image - - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample] - ) - - # Get H/W for each patch according to its image - W_per_patch = W[patch_to_sample] - abs_x = patch_id_within_image % W_per_patch - abs_y = patch_id_within_image // W_per_patch - - return abs_x, abs_y - - -class NeoVisionTransformerPretrainedModel(nn.Module): - def __init__( - self, - kvargs, - hidden_size: int = 1024, - llm_hidden_size: int = 2048, - downsample_ratio: float = 0.5, - patch_size: int = 16, - num_channels: int = 3, - max_position_embeddings_vision: int = 10000, - rope_theta_vision: float = 10000.0, - min_pixels: int = 65536, - max_pixels: int = 2408448, - **kwargs, - ): - super().__init__() - self.weight_dir = kvargs["weight_dir"] - self.data_type = kvargs.get("data_type", "bfloat16") - self.embed_dim = hidden_size - self.llm_hidden_size = llm_hidden_size - self.patch_size = patch_size - self.num_channels = num_channels - self.downsample_ratio = downsample_ratio - self.downsample_factor = int(1 / downsample_ratio) - self.max_position_embeddings_vision = max_position_embeddings_vision - self.rope_theta_vision = rope_theta_vision - self.rope_dim_part = self.embed_dim // 2 - self.min_pixels = min_pixels - self.max_pixels = max_pixels - - self.patch_embedding = nn.Conv2d( - in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size - ) - - self.dense_embedding = nn.Conv2d( - in_channels=self.embed_dim, - out_channels=self.llm_hidden_size, - kernel_size=self.downsample_factor, - stride=self.downsample_factor, - ) - self.gelu = nn.GELU() - - self.repe_dim_part = self.embed_dim // 2 - self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos() - self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos() - self._init_datatype() - - def _init_datatype(self): - if isinstance(self.data_type, torch.dtype): - return - if self.data_type in ["fp16", "float16"]: - self.data_type = torch.float16 - elif self.data_type in ["bf16", "bfloat16"]: - self.data_type = torch.bfloat16 - elif self.data_type in ["fp32", "float32"]: - self.data_type = torch.float32 - else: - raise ValueError(f"Unsupport datatype {self.data_type}!") - return - - def load_model(self, weight_dir): - bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] - if bin_weight_files: - weight_dict = {} - for file_ in bin_weight_files: - f = torch.load(os.path.join(weight_dir, file_), "cpu") - for k, v in f.items(): - if "vision_model" in k: - weight_dict[k[len("vision_model.embeddings.") :]] = v - else: - hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] - weight_dict = {} - for file_ in hf_weight_files: - f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") - for k in f.keys(): - if "vision_model" in k: - weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k) - self.load_state_dict(weight_dict) - - def 
precompute_rope_freqs_sincos(self): - inv_freq = 1.0 / ( - self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part) - ) - t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq) - freqs = torch.outer(t, inv_freq) - return torch.cos(freqs), torch.sin(freqs) - - def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw): - """ - Apply 2D Rotary Position Embedding to the patch embeddings. - """ - abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device) - embeddings = apply_2d_rotary_pos_emb( - patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 - self.cos_x, - self.sin_x, - self.cos_y, - self.sin_y, - abs_pos_x, - abs_pos_y, - ).to(self.patch_embedding.weight.dtype) - return embeddings - - def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: - pixel_values = pixel_values.view( - -1, - 3, - self.patch_size, - self.patch_size, - ) - patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim) - patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) - assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[ - 0 - ], "Grid size and patch embeds size mismatch." - - patches_list = [] - cur_position = 0 - for i in range(grid_hw.shape[0]): - h, w = grid_hw[i] - patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0) - patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2)) - patches_per_img = patches_per_img.permute(0, 2, 3, 1) - patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1])) - cur_position += h * w - - embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C) - assert cur_position == patch_embeds.shape[0] - assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2) - - return embeddings - - def encode(self, images: List[ImageItem]): - img_tensors = [] - valid_ids = [] - valid_id = 0 - img_grids = [] - uuids = [] - - for i, img in enumerate(images): - if isinstance(img, ImageItem): - uuids.append(img.uuid) - image_data = read_shm(get_shm_name_data(img.uuid)) - image_data = Image.open(BytesIO(image_data)) - pixel_values, image_grid_hw = load_image_native( - image_data, - patch_size=self.patch_size, - downsample_ratio=self.downsample_ratio, - min_pixels=self.min_pixels, - max_pixels=self.max_pixels, - ) - img_tensors.append(pixel_values) - img_grids.append(image_grid_hw) - else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) - - # must devide merge_length - cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2)) - print(f"cur_num is {cur_num}") - valid_ids.append([valid_id, valid_id + cur_num]) - valid_id += cur_num - - if len(img_tensors) <= 0: - return None - - imgs = torch.cat(img_tensors, dim=0) - grid_hw = torch.cat(img_grids, dim=0) - - pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) - image_grid_hw = grid_hw.to("cuda", non_blocking=True) - - all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw) - - return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/neo_chat/triton_kernel/__init__.py b/lightllm/models/neo_chat/triton_kernel/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py deleted file mode 100644 
index f5dae493cb..0000000000 --- a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py +++ /dev/null @@ -1,452 +0,0 @@ -import math -import torch -import triton -import triton.language as tl - -from lightllm.utils.device_utils import is_tesla - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - Out, - position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0] - B_Start_Loc, - B_Seqlen, - Req_to_tokens, - B_req_idx, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - kv_group_num, - b_prompt_cache_len, - H: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - start_m = tl.program_id(0) - cur_bh = tl.program_id(1) - cur_batch = cur_bh // H - cur_head = cur_bh % H - - cur_kv_head = cur_head // kv_group_num - - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) - total_len = tl.load(B_Seqlen + cur_batch) - cur_batch_seq_len = total_len - prompt_cache_len # NEW len - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - block_start_loc = BLOCK_M * start_m - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = block_start_loc + tl.arange(0, BLOCK_M) - - # Q pointers - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - - q_valid = offs_m < cur_batch_seq_len - q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) - - # online softmax state - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - block_end_loc = total_len - - # absolute q positions in the request - q_pos = prompt_cache_len + offs_m # [M] - - # q_gid from packed position_ids (aligned with Q rows) - q_gid = tl.load( - position_ids + cur_batch_in_all_start_index + offs_m, - mask=q_valid, - other=-2147483648, - ).to(tl.int32) - - BIG = tl.full([BLOCK_N], 1000000000, tl.int32) # ensure != any normal gid - - for start_n in range(0, block_mask * block_end_loc, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - k_pos = start_n + offs_n # [N] - k_valid = k_pos < block_end_loc - - # map logical pos -> mem_index (for K/V) - kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, - mask=k_valid, - other=0, - ).to(tl.int64) - - # load K - off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) - - qk = tl.dot(q, k) - - # k_gid: - # - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false - # - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len) - k_in_new = k_pos >= prompt_cache_len - k_new_idx = (k_pos - prompt_cache_len).to(tl.int32) # [N] valid only when k_in_new - k_gid_new = tl.load( - position_ids + cur_batch_in_all_start_index + k_new_idx, - mask=k_valid & k_in_new, - other=-2147483647, - ).to(tl.int32) - - k_gid = tl.where( - k_in_new, - k_gid_new, - (k_pos.to(tl.int32) + BIG), - ) - - # mask: causal OR same gid (only possible inside NEW part) - mask = (q_pos[:, None] >= 
k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :]) - mask = mask & q_valid[:, None] & k_valid[None, :] - - qk = tl.where(mask, qk * sm_scale, -1.0e8) - - # online softmax - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - - alpha = tl.math.exp2(m_i - m_ij) - l_i = l_i * alpha + l_ij - acc = acc * alpha[:, None] - - # load V - off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) - - p = p.to(v.dtype) - acc = tl.dot(p, v, acc) - - m_i = m_ij - - acc = acc / l_i[:, None] - - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - tl.store(Out + off_o, acc, mask=q_valid[:, None]) - - -@torch.no_grad() -def context_attention_fwd_neo( - q, - k, - v, - o, - position_ids, # 1D packed like q (only NEW tokens) - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_input_len, - req_to_token_indexs, -): - # minimal safety: position_ids must cover packed q rows - assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) - - BLOCK_M = 128 if not is_tesla() else 64 - - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} - base_head_dim = Lq // 2 - sm_scale = 1.0 / (base_head_dim ** 0.5) * 1.4426950408889634 - - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) - - BLOCK_N = BLOCK_M - num_warps = 4 if Lk <= 64 else 8 - num_stages = 1 - - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - o, - position_ids, - b_start_loc, - b_seq_len, - req_to_token_indexs, - b_req_idx, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - req_to_token_indexs.stride(0), - req_to_token_indexs.stride(1), - kv_group_num=kv_group_num, - b_prompt_cache_len=b_prompt_cache_len, - H=head, - BLOCK_DMODEL=Lk, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - ) - - -def reference_attention( - q, - k, - v, - position_ids_q, # 1D packed like q (only NEW tokens) - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - req_to_token_indexs, -): - device = q.device - dtype = q.dtype - sum_q, Hq, D = q.shape - Hk = k.shape[1] - kv_group_num = Hq // Hk - - batch = b_seq_len.shape[0] - out = torch.empty_like(q) - scale = 1.0 / math.sqrt(D) - - for b in range(batch): - req = int(b_req_idx[b].item()) - total_len = int(b_seq_len[b].item()) - prompt_len = int(b_prompt_cache_len[b].item()) - new_len = total_len - prompt_len - - q_start = int(b_start_loc[b].item()) - q_blk = q[q_start : q_start + new_len] # [M, Hq, D] - gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] - - # gather K/V for full request by logical pos -> mem_index - token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] - k_blk = k[token_locs] # [L, Hk, D] - v_blk = v[token_locs] # [L, Hk, D] - - # expand kv heads to q heads (GQA) - k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] - v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] - - # positions - q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] - k_pos = torch.arange(0, total_len, device=device, 
dtype=torch.int64) # [L] - - # build allow mask: - # causal always - allow = k_pos[None, :] <= q_pos[:, None] - - # full-attn only inside NEW part by gid - # compare only when k_pos in NEW - k_in_new = k_pos >= prompt_len - k_rel = (k_pos - prompt_len).clamp_min(0) # [L] - # map k_rel to gid_new, but only valid where k_in_new - k_gid = torch.empty((total_len,), device=device, dtype=torch.int64) - k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new - k_gid[k_in_new] = gid_new[k_rel[k_in_new]] - - allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) - - # scores: [Hq, M, L] - q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] - k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] - scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] - - neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) - scores = torch.where(allow[None, :, :], scores, neg) - - p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] - v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] - out_hq = torch.matmul(p, v_t) # [Hq, M, D] - out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] - - out[q_start : q_start + new_len] = out_blk - - return out - - -def make_test_case( - device="cuda", - dtype=torch.float16, - batch=3, - Hq=8, - Hk=4, - D=64, - seed=0, - base_index=50000, -): - torch.manual_seed(seed) - - # prompt (cached) len and new len - prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device) - new_lens = torch.randint(low=1, high=8, size=(batch,), device=device) - total_lens = (prompt_lens + new_lens).to(torch.int32) - - max_total_len = int(total_lens.max().item()) - max_new_len = int(new_lens.max().item()) - - # packed q start - b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) - cur = 0 - for b in range(batch): - b_start_loc[b] = cur - cur += int(new_lens[b].item()) - sum_q = cur - - b_seq_len = total_lens - b_prompt_cache_len = prompt_lens.to(torch.int32) - - # one req per batch - num_req = batch - b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) - - # global KV space large, indices not small - sum_kv = int(total_lens.sum().item()) - kv_size = base_index + sum_kv + 1024 - pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index - - # Req_to_tokens [num_req, max_total_len] - req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) - p = 0 - for r in range(num_req): - L = int(total_lens[r].item()) - req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) - p += L - - # position_ids_q: only NEW tokens, packed like q - position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) - for b in range(batch): - M = int(new_lens[b].item()) - start = int(b_start_loc[b].item()) - - gid = torch.arange(M, device=device, dtype=torch.int32) - - # make one repeated block inside NEW part to simulate image tokens - if M >= 4 and torch.rand((), device=device).item() > 0.3: - s = int(torch.randint(0, M - 2, (1,), device=device).item()) - e = min(M, s + 3) - gid[s:e] = gid[s] - - position_ids_q[start : start + M] = gid - - q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) - k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) - v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) - o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) - - return ( - q, - k, - v, - o, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_new_len, - req_to_token_indexs, - ) - - 
-def check_once(device="cuda", dtype=torch.float16, seed=0): - ( - q, - k, - v, - o, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_new_len, - req_to_token_indexs, - ) = make_test_case(device=device, dtype=dtype, seed=seed) - - context_attention_fwd_neo( - q, - k, - v, - o, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_new_len, - req_to_token_indexs, - ) - - ref = reference_attention( - q, - k, - v, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - req_to_token_indexs, - ) - - diff = (o - ref).abs() - max_abs = diff.max().item() - denom = ref.abs().max().item() + 1e-6 - max_rel = max_abs / denom - - print(f"seed={seed}, dtype={dtype}") - print(f"max_abs_error = {max_abs:.6e}") - print(f"max_rel_error = {max_rel:.6e}") - print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2)) - - -if __name__ == "__main__": - if not torch.cuda.is_available(): - print("No CUDA, skip.") - else: - torch.cuda.synchronize() - check_once(dtype=torch.bfloat16, seed=0) - check_once(dtype=torch.bfloat16, seed=1) - check_once(dtype=torch.bfloat16, seed=2) diff --git a/lightllm/models/neo_chat/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat/triton_kernel/get_neo_position.py deleted file mode 100644 index 955f48bd80..0000000000 --- a/lightllm/models/neo_chat/triton_kernel/get_neo_position.py +++ /dev/null @@ -1,174 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _get_neo_position_triton( - b_image_start_idx: torch.Tensor, - b_image_thwd: torch.Tensor, - b_image_thwd_stride0: torch.Tensor, - b_image_nums: torch.Tensor, - b_image_start_num: torch.Tensor, - b_image_len: torch.Tensor, - position_ids: torch.Tensor, - position_ids_stride0: torch.Tensor, - b_ready_cache_len: torch.Tensor, - b_q_seq_len: torch.Tensor, - b_start_loc: torch.Tensor, - BLOCK_SIZE: tl.constexpr, -) -> torch.Tensor: - cur_batch = tl.program_id(0) - cache_len = tl.load(b_ready_cache_len + cur_batch) - q_seq_len = tl.load(b_q_seq_len + cur_batch) - image_num = tl.load(b_image_nums + cur_batch) - image_start_num = tl.load(b_image_start_num + cur_batch) - start_loc = tl.load(b_start_loc + cur_batch) - for i in range(image_num): - local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) - image_start_idx = start_loc + local_image_start_idx - cache_len - image_len = tl.load(b_image_len + image_start_num + i) - # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) - image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) - for j in range(0, image_len, BLOCK_SIZE): - off = j + tl.arange(0, BLOCK_SIZE) - # 目前没考虑视频,所以t 恒为 0 - t_pos = local_image_start_idx + off * 0 - h_pos = off // image_w - w_pos = off % image_w - tl.store( - position_ids + off + image_start_idx, - t_pos, - mask=(off < image_len) - & (off + local_image_start_idx - cache_len < q_seq_len) - & (local_image_start_idx - cache_len + off >= 0), - ) - tl.store( - position_ids + position_ids_stride0 + off + image_start_idx, - h_pos, - mask=(off < image_len) - & (off + local_image_start_idx - cache_len < q_seq_len) - & (local_image_start_idx - cache_len + off >= 0), - ) - tl.store( - position_ids + position_ids_stride0 * 2 + off + image_start_idx, - w_pos, - mask=(off < image_len) - & (off + local_image_start_idx - cache_len < q_seq_len) - & (local_image_start_idx - cache_len + off >= 0), - ) - - for i in range(image_num): - 
local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) - image_len = tl.load(b_image_len + image_start_num + i) - image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3) - image_end = local_image_start_idx + image_len - cache_len - text_start = tl.maximum(0, image_end) - for j in range(text_start, q_seq_len, BLOCK_SIZE): - off = j + tl.arange(0, BLOCK_SIZE) - t_pos = tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta - h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0) - w_pos = tl.load( - position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0 - ) - tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len)) - tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len)) - tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len)) - return - - -def get_neo_position_triton( - b_image_start_idx: torch.Tensor, - b_image_thwd: torch.Tensor, - b_image_nums: torch.Tensor, - b_image_start_num: torch.Tensor, - b_image_len: torch.Tensor, - position_ids: torch.Tensor, - b_ready_cache_len: torch.Tensor, - b_q_seq_len: torch.Tensor, - b_start_loc: torch.Tensor, -) -> torch.Tensor: - - batch_size = b_q_seq_len.shape[0] - assert batch_size == b_image_nums.shape[0] - grid = (batch_size,) - BLOCK_SIZE = 64 - _get_neo_position_triton[grid]( - b_image_start_idx=b_image_start_idx, - b_image_thwd=b_image_thwd, - b_image_thwd_stride0=b_image_thwd.stride(0), - b_image_nums=b_image_nums, - b_image_start_num=b_image_start_num, - b_image_len=b_image_len, - position_ids=position_ids, - position_ids_stride0=position_ids.stride(0), - b_ready_cache_len=b_ready_cache_len, - b_q_seq_len=b_q_seq_len, - b_start_loc=b_start_loc, - BLOCK_SIZE=BLOCK_SIZE, - ) - - -def test(): - b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda") - b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda") - b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda") - b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda") - b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda") - position_ids = ( - torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") - .unsqueeze(0) - .expand(3, -1) - .contiguous() - ) - position_ids[1:].zero_() - b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") - b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") - b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") - get_neo_position_triton( - b_image_start_idx, - b_image_thwd, - b_image_nums, - b_image_start_num, - b_image_len, - position_ids, - b_ready_cache_len, - b_q_seq_len, - b_start_loc, - ) - - print(position_ids) - # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) - - # position_ids = ( - # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") - # .unsqueeze(0) - # .expand(3, -1) - # .contiguous() - # ) - # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda") - # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda") - # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda") - - # get_neo_position_triton( - # b_image_start_idx, - # 
b_image_thwd, - # b_image_nums, - # b_image_start_num, - # b_image_len, - # position_ids, - # b_ready_cache_len, - # b_q_seq_len, - # b_start_loc, - # ) - - # print(f"old_value:\n{old_value}") - # print(f"position_ids:\n{position_ids}") - # assert torch.equal(old_value, position_ids) - - """ - tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8], - [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8], - [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]], - device='cuda:0', dtype=torch.int32) - """ diff --git a/lightllm/models/neo_chat/vision_process.py b/lightllm/models/neo_chat/vision_process.py deleted file mode 100644 index aa008e18fb..0000000000 --- a/lightllm/models/neo_chat/vision_process.py +++ /dev/null @@ -1,141 +0,0 @@ -import re -import math -import torch -import string -import numpy as np -import pandas as pd -from PIL import Image -import torch.distributed as dist -import torchvision.transforms as T - -IMAGENET_MEAN = (0.485, 0.456, 0.406) -IMAGENET_STD = (0.229, 0.224, 0.225) - - -def round_by_factor(number: int, factor: int) -> int: - """Returns the closest integer to 'number' that is divisible by 'factor'.""" - return round(number / factor) * factor - - -def ceil_by_factor(number: int, factor: int) -> int: - """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" - return math.ceil(number / factor) * factor - - -def floor_by_factor(number: int, factor: int) -> int: - """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" - return math.floor(number / factor) * factor - - -# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60 -def smart_resize( - height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304 -) -> tuple[int, int]: - """ - Rescales the image so that the following conditions are met: - - 1. Both dimensions (height and width) are divisible by 'factor'. - - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. - - 3. The aspect ratio of the image is maintained as closely as possible. 
- """ - if max(height, width) / min(height, width) > 200: - raise ValueError( - f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}" - ) - h_bar = max(factor, round_by_factor(height, factor)) - w_bar = max(factor, round_by_factor(width, factor)) - if h_bar * w_bar > max_pixels: - beta = math.sqrt((height * width) / max_pixels) - h_bar = max(factor, floor_by_factor(height / beta, factor)) - w_bar = max(factor, floor_by_factor(width / beta, factor)) - elif h_bar * w_bar < min_pixels: - beta = math.sqrt(min_pixels / (height * width)) - h_bar = ceil_by_factor(height * beta, factor) - w_bar = ceil_by_factor(width * beta, factor) - return h_bar, w_bar - - -def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs): - width, height = image.size - resized_height, resized_width = smart_resize( - height, - width, - factor=size_factor, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - image = image.resize((resized_width, resized_height)) - - return image - - -def preprocess_pixel_values(pixel_values, patch_size=16): - c, h, w = pixel_values.shape - grid_h = h // patch_size - grid_w = w // patch_size - - flatten_pixel_values = ( - pixel_values.view(c, grid_h, patch_size, grid_w, patch_size) - .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size] - .reshape(grid_h * grid_w, c * patch_size ** 2) - ) - - grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device) - - return flatten_pixel_values, grid_hw - - -def get_contrasting_background(image): - """ - Calculate the color (white or black) that is different from the average foreground color - to use as the background color - """ - image_np = np.array(image) - if (image_np[:, :, 3] == 0).any(): - non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0] - if non_transparent_pixels.size == 0: - return None - pixel_mean = non_transparent_pixels.mean() - contrasting_color = (0, 0, 0) if pixel_mean > 382.5 else (255, 255, 255) - return contrasting_color - else: - return None - - -def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False): - """ - Load and preprocess an image file, converting it to RGB mode, - resizing, normalizing, and optionally adding a thumbnail version. 
- """ - if image.mode == "RGBA": - bg_color = get_contrasting_background(image) - if bg_color: - background = Image.new("RGB", image.size, bg_color) - background.paste(image, mask=image.split()[3]) - image = background.convert("RGB") - else: - image = image.convert("RGB") - else: - image = image.convert("RGB") - - if upscale: - image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) - - transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), - T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), - ] - ) - - new_image = dynamic_preprocess_native_resolution( - image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels - ) - pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size) - - print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") - - return pixel_values, grid_hw diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 17f5a741ac..3563739f79 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -30,7 +30,7 @@ from ..models.qwen2_vl.model import QWen2VLTokenizer from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer -from ..models.neo_chat.model import NeoChatTokenizer +from ..models.neo_chat_moe.model import NeoChatTokenizer from ..models.gemma3.model import Gemma3Tokenizer # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d77271af86..df5d66bcbc 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -19,7 +19,7 @@ from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel -from lightllm.models.neo_chat.neo_visual import NeoVisionTransformerPretrainedModel +from lightllm.models.neo_chat_moe.neo_visual import NeoVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry From 4b757ddb4150b6f0498d7b97b9a7dcd9de71d4a5 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 6 Jan 2026 08:36:24 +0000 Subject: [PATCH 022/180] add-neo-chat-dense --- lightllm/models/neo_chat_moe/__init__.py | 0 lightllm/models/neo_chat_moe/infer_struct.py | 99 ++++ .../neo_chat_moe/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 159 ++++++ .../neo_chat_moe/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 23 + .../layer_weights/transformer_layer_weight.py | 51 ++ lightllm/models/neo_chat_moe/model.py | 139 ++++++ lightllm/models/neo_chat_moe/neo_visual.py | 279 +++++++++++ .../neo_chat_moe/triton_kernel/__init__.py | 0 .../context_attention_fwd_neo.py | 452 ++++++++++++++++++ .../triton_kernel/get_neo_position.py | 174 +++++++ .../models/neo_chat_moe/vision_process.py | 141 ++++++ 13 files changed, 1517 insertions(+) create mode 100644 lightllm/models/neo_chat_moe/__init__.py create mode 100644 lightllm/models/neo_chat_moe/infer_struct.py create mode 100644 
lightllm/models/neo_chat_moe/layer_infer/__init__.py create mode 100644 lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/neo_chat_moe/layer_weights/__init__.py create mode 100644 lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/neo_chat_moe/model.py create mode 100644 lightllm/models/neo_chat_moe/neo_visual.py create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/__init__.py create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py create mode 100644 lightllm/models/neo_chat_moe/vision_process.py diff --git a/lightllm/models/neo_chat_moe/__init__.py b/lightllm/models/neo_chat_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py new file mode 100644 index 0000000000..0c7d9372e2 --- /dev/null +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -0,0 +1,99 @@ +from typing import Optional, List +import torch +import numpy as np +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.common.req_manager import ReqManager +from lightllm.models.neo_chat_moe.triton_kernel.get_neo_position import get_neo_position_triton +from lightllm.models.llama.model import LlamaTpPartModel + + +class NeoChatInferStateInfo(LlamaInferStateInfo): + def __init__(self): + super().__init__() + self.position_cos = None + self.position_sin = None + self.position_cos_h = None + self.position_sin_h = None + self.position_cos_w = None + self.position_sin_w = None + + def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor): + LlamaInferStateInfo.init_some_extra_state(self, model, input_ids) + if self.is_prefill: + self.position_ids = self.get_neo_position(self.multimodal_params) + else: + b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] + for batch_idx, p in enumerate(self.multimodal_params): + position_delta = 0 + for image in p["images"]: + position_delta += image["grid_thwd"][3] + b_position_delta[batch_idx] = position_delta + position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) + self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone() + self.position_ids[1:].zero_() + + self.position_ids = self.position_ids.contiguous() + self.position_cos = model._cos_cached[self.position_ids[0]] + self.position_sin = model._sin_cached[self.position_ids[0]] + self.position_cos_h = model._hw_cos_cached[self.position_ids[1]] + self.position_sin_h = model._hw_sin_cached[self.position_ids[1]] + self.position_cos_w = model._hw_cos_cached[self.position_ids[2]] + self.position_sin_w = model._hw_sin_cached[self.position_ids[2]] + return + + def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: + if len(multimodal_params) == 0: + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + return position_ids + b_image_start_idx = [] + b_image_nums = [] + b_image_start_num = [] + b_image_len = [] + image_start_num = 0 + b_image_thwd = [] + + # pad multimodal_params to batch size. 
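+        # e.g. a batch of 4 requests where only the first 2 carry images gets
+        # two empty {"images": [], "audios": []} entries appended, so the loop
+        # below walks exactly one entry per request in b_q_seq_len.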
+        batch_size = self.b_q_seq_len.shape[0]
+        multimodal_params = multimodal_params + [
+            {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params))
+        ]
+
+        for _, p in enumerate(multimodal_params):
+            images = p.get("images", [])
+            for img in images:
+                b_image_start_idx.append(img["start_idx"])
+                a = img["start_idx"]
+                print(f"img start_idx: {a}")
+                b_image_len.append(img["token_num"])
+                b_image_thwd.append(img["grid_thwd"])
+            b_image_nums.append(len(images))
+            b_image_start_num.append(image_start_num)
+            image_start_num += len(images)
+
+        # no images in this batch
+        if image_start_num == 0:
+            position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+            position_ids[0].copy_(self.position_ids)
+            return position_ids.contiguous()
+        b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True)
+        b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True)  # image_num x 4
+        b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True)
+        b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True)
+        b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True)
+
+        position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+        position_ids[0].copy_(self.position_ids)
+
+        get_neo_position_triton(
+            b_image_start_idx=b_image_start_idx,
+            b_image_thwd=b_image_thwd,
+            b_image_nums=b_image_nums,
+            b_image_start_num=b_image_start_num,
+            b_image_len=b_image_len,
+            position_ids=position_ids,
+            b_ready_cache_len=self.b_ready_cache_len,
+            b_q_seq_len=self.b_q_seq_len,
+            b_start_loc=self.b_start_loc,
+        )
+        return position_ids
diff --git a/lightllm/models/neo_chat_moe/layer_infer/__init__.py b/lightllm/models/neo_chat_moe/layer_infer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
new file mode 100644
index 0000000000..ed48a9c6f1
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,159 @@
+import torch
+from functools import partial
+from typing import Tuple
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo
+from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd
+from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd
+from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
+from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
+from lightllm.distributed import all_reduce
+import torch.distributed as dist
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward
+
+
+class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
+    def __init__(self, data_type, network_config, mode):
+        super().__init__(data_type, network_config, mode)
+        return
+
+    def _bind_attention(self):
+        self._context_attention_kernel = self._context_attention_kernel
+        self._token_attention_kernel = 
self._token_decode_attention_normal + self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal + return + + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): + input = input.view(-1, self.embed_dim_) + q = layer_weight.q_proj.mm(input) # [T, Hq*D] + + q_hw = layer_weight.q_hw_proj.mm(input) + q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) + q_h, q_w = q_hw.chunk(2, dim=-1) + + k_hw = layer_weight.k_hw_proj.mm(input) + k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) + k_h, k_w = k_hw.chunk(2, dim=-1) + + cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] + + qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + + q_h_2d = q_h.reshape(q.shape[0], -1) + q_w_2d = q_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_) + q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + + qk_rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + weight=layer_weight.k_norm_weight_.weight, + eps=self.eps_, + ) + + k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] + k_w_2d = k_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h, + k_h, + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w, + k_w, + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + + q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) + q3 = torch.cat([q3, q_h, q_w], dim=-1) + q = q3.reshape(q3.shape[0], -1) + + k = cache_kv[:, : self.tp_k_head_num_, :] + k = torch.cat([k, k_h, k_w], dim=-1) + + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) + v = torch.cat([v, v_pad], dim=-1) + + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + + def _context_attention_kernel( + self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd_neo( + q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + infer_state.req_manager.req_to_token_indexs, + ) + o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + o3 = o3[:, :, : self.head_dim_].contiguous() + 
return o3.view(o3.shape[0], -1) + + def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + + q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) + + att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) + + k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + token_att_fwd( + q_3d, + k_3d, + att_m_tensor, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd + + v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ + ] + + o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out + + token_softmax_reducev_fwd( + att_m_tensor, + v_3d, + o_3d, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat_moe/layer_weights/__init__.py b/lightllm/models/neo_chat_moe/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..7766a5d29f --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,23 @@ +import torch +import numpy as np +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight + +# add key: language_model.xxx -> xxx +# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now +def rename_weight_keys(weights): + prefix = "language_model." 
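+    # snapshot the key list first: the loop below mutates the dict while renaming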
+ keys = list(weights.keys()) + for k in keys: + if prefix in k: + weights[k.replace(prefix, "")] = weights.pop(k) + + +class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..bc38f1adcb --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py @@ -0,0 +1,51 @@ +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + NormWeight, + ROWMMWeight, +) + + +class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + def _init_weight_names(self): + super()._init_weight_names() + self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight" + self._q_bias_hw_name = None + self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" + self._k_bias_hw_name = None + + self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" + self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" + + self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" + self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" + + def _init_qkv(self): + super()._init_qkv() + self.q_hw_proj = ROWMMWeight( + weight_names=self._q_weight_hw_name, + data_type=self.data_type_, + bias_names=self._q_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="q_hw_proj", + ) + self.k_hw_proj = ROWMMWeight( + weight_names=self._k_weight_hw_name, + data_type=self.data_type_, + bias_names=self._k_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="k_hw_proj", + ) + + def _init_norm(self): + super()._init_norm() + + self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) + self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) + self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) + self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py new file mode 100644 index 0000000000..e4123d1097 --- /dev/null +++ b/lightllm/models/neo_chat_moe/model.py @@ -0,0 +1,139 @@ +import os +import json +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry, llm_model_type_is +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen2_vl.model import QWen2VLTokenizer 
+from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.server.core.objs import SamplingParams
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
+from lightllm.models.neo_chat_moe.vision_process import smart_resize
+from lightllm.models.internvl.model import InternvlTokenizer
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.models.neo_chat_moe.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
+from lightllm.models.neo_chat_moe.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight
+from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+
+IMG_START_TOKEN = ""
+IMG_END_TOKEN = ""
+IMG_TOKEN = ""
+AUDIO_START_TOKEN = ""
+
+
+class NeoChatTokenizer(BaseMultiModalTokenizer):
+    def __init__(self, tokenizer, model_cfg, **kwargs):
+        super().__init__(tokenizer)
+        self.tokenizer = tokenizer
+        self.min_pixel = model_cfg.get("vision_config").get("min_pixels")
+        self.max_pixel = model_cfg.get("vision_config").get("max_pixels")
+        self.patch_size = model_cfg.get("vision_config").get("patch_size")
+        self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio")
+
+        self.image_token_id = model_cfg.get("image_token_id")
+        self.image_start_tag = IMG_START_TOKEN
+        self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag)
+        self.image_end_tag = IMG_END_TOKEN
+        self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
+
+    def init_imageitem_extral_params(
+        self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+    ):
+        return
+
+    def init_audioitem_extral_params(
+        self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+    ):
+        raise NotImplementedError
+
+    def get_audio_token_length(self, audio: AudioItem):
+        raise NotImplementedError
+
+    def get_image_token_length(self, img: ImageItem):
+        width, height = img.image_w, img.image_h
+        resized_height, resized_width = smart_resize(
+            height=height,
+            width=width,
+            factor=int(self.patch_size // self.downsample_ratio),
+            min_pixels=self.min_pixel,
+            max_pixels=self.max_pixel,
+        )
+        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+        token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2))
+        # TODO: double-check whether grid_h and grid_w here need the extra
+        # multiplication by self.downsample_ratio
+        img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num)
+        return token_num
+
+    # only the encode() implementation differs from the base tokenizer:
+    def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
+        # TEXTTEXTTEXT --> TEXTTEXTTEXT
+        image_tokens = IMG_START_TOKEN + IMG_END_TOKEN
+        if multimodal_params is None:
+            add_special_tokens = kwargs.get("add_special_tokens", True)
+            return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        image_count = len(multimodal_params.images)
+        prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count)
+
+        origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"])
+        #  --> id,id+1...id+num
+        input_ids = []
+        image_id = 0
+        start_idx = 0
+        while True:
+            try:
+                start_idx = origin_ids.index(self.image_start_id)
+                if start_idx + 1 >= len(origin_ids):
+                    break
+                if origin_ids[start_idx + 1] == self.image_end_id:
+                    input_ids.extend(origin_ids[: start_idx + 1])
+                    token_id = multimodal_params.images[image_id].token_id
+                    token_num = multimodal_params.images[image_id].token_num
+                    multimodal_params.images[image_id].start_idx = len(input_ids)
+                    input_ids.extend(range(token_id, token_id + token_num))
+                    input_ids.append(self.image_end_id)
+                    origin_ids = origin_ids[start_idx + 2 :]
+                    image_id += 1
+                else:
+                    raise ValueError("image token error")
+            except ValueError:
+                break
+        input_ids.extend(origin_ids)
+        return input_ids
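+    # A minimal sketch of encode()'s effect, with illustrative numbers only
+    # (token_id=1000 and token_num=4 are not taken from any real model):
+    #   before: [..., image_start_id, image_end_id, ...]
+    #   after:  [..., image_start_id, 1000, 1001, 1002, 1003, image_end_id, ...]
+    # Each per-image placeholder range is contiguous, so downstream KV/radix
+    # layers can treat it like ordinary token ids; img.start_idx records where
+    # the range begins in input_ids.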
+
+
+@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe"))
+class NeoTpMOEPartModel(Qwen3MOEModel):
+
+    pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+    transformer_layer_infer_class = NeoChatMOETransformerLayerInfer
+
+    pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight
+    transformer_weight_class = NeoChatMOETransformerLayerWeight
+
+    infer_state_class = NeoChatInferStateInfo
+
+    def __init__(self, kvargs):
+        super().__init__(kvargs)
+        return
+
+    def _init_inferstate_cls(self):
+        pass
+
+    def _init_config(self):
+        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+            all_config = json.load(json_file)
+        self.config = all_config["llm_config"]
+        # rename keys
+        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+        if self.finetune_config:
+            self.config["vocab_size"] = self.finetune_config.vocab_size
+        return
diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py
new file mode 100644
index 0000000000..852ddc0952
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/neo_visual.py
@@ -0,0 +1,279 @@
+import os
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from typing import List
+from io import BytesIO
+import torch.nn as nn
+from transformers.activations import ACT2FN
+from safetensors import safe_open
+from lightllm.server.multimodal_params import ImageItem
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from lightllm.models.neo_chat_moe.vision_process import load_image_native
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+
+
+def apply_rotary_emb_1d(
+    x: torch.Tensor,
+    cos_cached: torch.Tensor,
+    sin_cached: torch.Tensor,
+    positions: torch.Tensor,
+):
+    """Apply 1D RoPE to a slice of the input tensor."""
+    # x: (..., seq_len, dim_part)
+    # positions: (..., seq_len)
+    # cos_cached: (max_pos, dim_part / 2)
+    cos_cached = cos_cached.to(device=positions.device)
+    sin_cached = sin_cached.to(device=positions.device)
+
+    cos = cos_cached[positions]  # Shape: (positions.shape, dim_part / 2)
+    sin = sin_cached[positions]  # Shape: (positions.shape, dim_part / 2)
+
+    x1 = x[..., 0::2]
+    x2 = x[..., 1::2]
+
+    rotated_x1 = x1 * cos - x2 * sin
+    rotated_x2 = x1 * sin + x2 * cos
+
+    x_rotated = torch.empty_like(x)
+    x_rotated[..., 0::2] = rotated_x1
+    x_rotated[..., 1::2] = rotated_x2
+    return x_rotated
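+# A quick sanity property for apply_rotary_emb_1d (illustrative, not executed
+# at import time): RoPE only rotates (even, odd) channel pairs, so it must
+# preserve per-token norms. Given cached tables cos_t / sin_t and integer
+# positions pos:
+#   y = apply_rotary_emb_1d(x, cos_t, sin_t, pos)
+#   assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)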
+
+
+def apply_2d_rotary_pos_emb(
+    x: torch.Tensor,
+    cos_cached_x: torch.Tensor,
+    sin_cached_x: torch.Tensor,
+    cos_cached_y: torch.Tensor,
+    sin_cached_y: torch.Tensor,
+    abs_positions_x: torch.Tensor,
+    abs_positions_y: torch.Tensor,
+):
+    """Apply 2D RoPE to the input tensor x."""
+    dim = x.shape[-1]
+    dim_half = dim // 2
+
+    # Assume the first half of the embedding carries RoPE for one axis and the
+    # second half for the other (e.g. X first, Y second; either assignment
+    # works as long as it stays consistent).
+    x_part_1 = x[..., :dim_half]
+    x_part_2 = x[..., dim_half:]
+
+    # apply the rotation derived from abs_positions_x to x_part_1
+    rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x)
+    # apply the rotation derived from abs_positions_y to x_part_2
+    rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y)
+
+    # concatenate the halves again, in the same order used for the split
+    return torch.cat((rotated_part_1, rotated_part_2), dim=-1)
+
+
+def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
+    """
+    Compute patch coordinates (x, y)
+
+    Args:
+        grid_hw: (B, 2) tensor representing (H, W) per image
+    """
+    device = grid_hw.device
+    B = grid_hw.shape[0]
+
+    # Get the number of patches per image
+    H = grid_hw[:, 0]
+    W = grid_hw[:, 1]
+    N = H * W
+    N_total = N.sum()
+
+    # Create the batch index for each patch (B x patch count)
+    patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N)  # (N_total,)
+
+    # Generate intra-image patch index (row-major order)
+    patch_id_within_image = torch.arange(N_total, device=device)
+    patch_id_within_image = (
+        patch_id_within_image
+        - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample]
+    )
+
+    # Get H/W for each patch according to its image
+    W_per_patch = W[patch_to_sample]
+    abs_x = patch_id_within_image % W_per_patch
+    abs_y = patch_id_within_image // W_per_patch
+
+    return abs_x, abs_y
+
+
+class NeoVisionTransformerPretrainedModel(nn.Module):
+    def __init__(
+        self,
+        kvargs,
+        hidden_size: int = 1024,
+        llm_hidden_size: int = 2048,
+        downsample_ratio: float = 0.5,
+        patch_size: int = 16,
+        num_channels: int = 3,
+        max_position_embeddings_vision: int = 10000,
+        rope_theta_vision: float = 10000.0,
+        min_pixels: int = 65536,
+        max_pixels: int = 2408448,
+        **kwargs,
+    ):
+        super().__init__()
+        self.weight_dir = kvargs["weight_dir"]
+        self.data_type = kvargs.get("data_type", "bfloat16")
+        self.embed_dim = hidden_size
+        self.llm_hidden_size = llm_hidden_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.downsample_ratio = downsample_ratio
+        self.downsample_factor = int(1 / downsample_ratio)
+        self.max_position_embeddings_vision = max_position_embeddings_vision
+        self.rope_theta_vision = rope_theta_vision
+        self.rope_dim_part = self.embed_dim // 2
+        self.min_pixels = min_pixels
+        self.max_pixels = max_pixels
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size
+        )
+
+        self.dense_embedding = nn.Conv2d(
+            in_channels=self.embed_dim,
+            out_channels=self.llm_hidden_size,
+            kernel_size=self.downsample_factor,
+            stride=self.downsample_factor,
+        )
+        self.gelu = nn.GELU()
+
+        self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos()
+        self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos()
+        self._init_datatype()
+
+    def _init_datatype(self):
+        if isinstance(self.data_type, torch.dtype):
+            return
+        if self.data_type in ["fp16", "float16"]:
+            self.data_type = torch.float16
+        elif self.data_type in ["bf16", "bfloat16"]:
+            self.data_type = torch.bfloat16
+        elif self.data_type in ["fp32", "float32"]:
+            self.data_type = torch.float32
+        else:
+            raise ValueError(f"Unsupported datatype {self.data_type}!")
+        return
+
+    def load_model(self, weight_dir):
+        bin_weight_files = [file_ for file_ in 
os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k) + self.load_state_dict(weight_dict) + + def precompute_rope_freqs_sincos(self): + inv_freq = 1.0 / ( + self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part) + ) + t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq) + freqs = torch.outer(t, inv_freq) + return torch.cos(freqs), torch.sin(freqs) + + def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw): + """ + Apply 2D Rotary Position Embedding to the patch embeddings. + """ + abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device) + embeddings = apply_2d_rotary_pos_emb( + patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 + self.cos_x, + self.sin_x, + self.cos_y, + self.sin_y, + abs_pos_x, + abs_pos_y, + ).to(self.patch_embedding.weight.dtype) + return embeddings + + def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: + pixel_values = pixel_values.view( + -1, + 3, + self.patch_size, + self.patch_size, + ) + patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim) + patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) + assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[ + 0 + ], "Grid size and patch embeds size mismatch." 
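+        # dense_embedding is a strided conv over each image's (H, W) patch grid:
+        # every downsample_factor x downsample_factor block of ViT patches is
+        # fused into a single llm_hidden_size embedding, one image at a time.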
+ + patches_list = [] + cur_position = 0 + for i in range(grid_hw.shape[0]): + h, w = grid_hw[i] + patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0) + patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2)) + patches_per_img = patches_per_img.permute(0, 2, 3, 1) + patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1])) + cur_position += h * w + + embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C) + assert cur_position == patch_embeds.shape[0] + assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2) + + return embeddings + + def encode(self, images: List[ImageItem]): + img_tensors = [] + valid_ids = [] + valid_id = 0 + img_grids = [] + uuids = [] + + for i, img in enumerate(images): + if isinstance(img, ImageItem): + uuids.append(img.uuid) + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + pixel_values, image_grid_hw = load_image_native( + image_data, + patch_size=self.patch_size, + downsample_ratio=self.downsample_ratio, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + img_tensors.append(pixel_values) + img_grids.append(image_grid_hw) + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + + # must devide merge_length + cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2)) + print(f"cur_num is {cur_num}") + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + imgs = torch.cat(img_tensors, dim=0) + grid_hw = torch.cat(img_grids, dim=0) + + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_hw = grid_hw.to("cuda", non_blocking=True) + + all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw) + + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/neo_chat_moe/triton_kernel/__init__.py b/lightllm/models/neo_chat_moe/triton_kernel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py new file mode 100644 index 0000000000..f5dae493cb --- /dev/null +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -0,0 +1,452 @@ +import math +import torch +import triton +import triton.language as tl + +from lightllm.utils.device_utils import is_tesla + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + Out, + position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0] + B_Start_Loc, + B_Seqlen, + Req_to_tokens, + B_req_idx, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + kv_group_num, + b_prompt_cache_len, + H: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + cur_bh = tl.program_id(1) + cur_batch = cur_bh // H + cur_head = cur_bh % H + + cur_kv_head = cur_head // kv_group_num + + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + total_len = tl.load(B_Seqlen + cur_batch) + cur_batch_seq_len = total_len - prompt_cache_len # NEW len + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + 
block_start_loc = BLOCK_M * start_m + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = block_start_loc + tl.arange(0, BLOCK_M) + + # Q pointers + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + + q_valid = offs_m < cur_batch_seq_len + q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) + + # online softmax state + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + block_end_loc = total_len + + # absolute q positions in the request + q_pos = prompt_cache_len + offs_m # [M] + + # q_gid from packed position_ids (aligned with Q rows) + q_gid = tl.load( + position_ids + cur_batch_in_all_start_index + offs_m, + mask=q_valid, + other=-2147483648, + ).to(tl.int32) + + BIG = tl.full([BLOCK_N], 1000000000, tl.int32) # ensure != any normal gid + + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + k_pos = start_n + offs_n # [N] + k_valid = k_pos < block_end_loc + + # map logical pos -> mem_index (for K/V) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_valid, + other=0, + ).to(tl.int64) + + # load K + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) + + qk = tl.dot(q, k) + + # k_gid: + # - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false + # - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len) + k_in_new = k_pos >= prompt_cache_len + k_new_idx = (k_pos - prompt_cache_len).to(tl.int32) # [N] valid only when k_in_new + k_gid_new = tl.load( + position_ids + cur_batch_in_all_start_index + k_new_idx, + mask=k_valid & k_in_new, + other=-2147483647, + ).to(tl.int32) + + k_gid = tl.where( + k_in_new, + k_gid_new, + (k_pos.to(tl.int32) + BIG), + ) + + # mask: causal OR same gid (only possible inside NEW part) + mask = (q_pos[:, None] >= k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :]) + mask = mask & q_valid[:, None] & k_valid[None, :] + + qk = tl.where(mask, qk * sm_scale, -1.0e8) + + # online softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # load V + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) + + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + + m_i = m_ij + + acc = acc / l_i[:, None] + + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + tl.store(Out + off_o, acc, mask=q_valid[:, None]) + + +@torch.no_grad() +def context_attention_fwd_neo( + q, + k, + v, + o, + position_ids, # 1D packed like q (only NEW tokens) + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, +): + # minimal safety: position_ids must cover packed q rows + assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) + + BLOCK_M = 128 if not is_tesla() else 64 + + Lq, Lk, Lv = 
q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128, 256} + base_head_dim = Lq // 2 + sm_scale = 1.0 / (base_head_dim ** 0.5) * 1.4426950408889634 + + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + + BLOCK_N = BLOCK_M + num_warps = 4 if Lk <= 64 else 8 + num_stages = 1 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + o, + position_ids, + b_start_loc, + b_seq_len, + req_to_token_indexs, + b_req_idx, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + kv_group_num=kv_group_num, + b_prompt_cache_len=b_prompt_cache_len, + H=head, + BLOCK_DMODEL=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def reference_attention( + q, + k, + v, + position_ids_q, # 1D packed like q (only NEW tokens) + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, +): + device = q.device + dtype = q.dtype + sum_q, Hq, D = q.shape + Hk = k.shape[1] + kv_group_num = Hq // Hk + + batch = b_seq_len.shape[0] + out = torch.empty_like(q) + scale = 1.0 / math.sqrt(D) + + for b in range(batch): + req = int(b_req_idx[b].item()) + total_len = int(b_seq_len[b].item()) + prompt_len = int(b_prompt_cache_len[b].item()) + new_len = total_len - prompt_len + + q_start = int(b_start_loc[b].item()) + q_blk = q[q_start : q_start + new_len] # [M, Hq, D] + gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] + + # gather K/V for full request by logical pos -> mem_index + token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] + k_blk = k[token_locs] # [L, Hk, D] + v_blk = v[token_locs] # [L, Hk, D] + + # expand kv heads to q heads (GQA) + k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + + # positions + q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] + k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L] + + # build allow mask: + # causal always + allow = k_pos[None, :] <= q_pos[:, None] + + # full-attn only inside NEW part by gid + # compare only when k_pos in NEW + k_in_new = k_pos >= prompt_len + k_rel = (k_pos - prompt_len).clamp_min(0) # [L] + # map k_rel to gid_new, but only valid where k_in_new + k_gid = torch.empty((total_len,), device=device, dtype=torch.int64) + k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new + k_gid[k_in_new] = gid_new[k_rel[k_in_new]] + + allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) + + # scores: [Hq, M, L] + q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] + k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] + scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] + + neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) + scores = torch.where(allow[None, :, :], scores, neg) + + p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] + v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] + out_hq = torch.matmul(p, v_t) # [Hq, M, D] + out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] + + out[q_start : q_start + new_len] = out_blk + + return out + + +def make_test_case( + device="cuda", + 
dtype=torch.float16, + batch=3, + Hq=8, + Hk=4, + D=64, + seed=0, + base_index=50000, +): + torch.manual_seed(seed) + + # prompt (cached) len and new len + prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device) + new_lens = torch.randint(low=1, high=8, size=(batch,), device=device) + total_lens = (prompt_lens + new_lens).to(torch.int32) + + max_total_len = int(total_lens.max().item()) + max_new_len = int(new_lens.max().item()) + + # packed q start + b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) + cur = 0 + for b in range(batch): + b_start_loc[b] = cur + cur += int(new_lens[b].item()) + sum_q = cur + + b_seq_len = total_lens + b_prompt_cache_len = prompt_lens.to(torch.int32) + + # one req per batch + num_req = batch + b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) + + # global KV space large, indices not small + sum_kv = int(total_lens.sum().item()) + kv_size = base_index + sum_kv + 1024 + pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index + + # Req_to_tokens [num_req, max_total_len] + req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) + p = 0 + for r in range(num_req): + L = int(total_lens[r].item()) + req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) + p += L + + # position_ids_q: only NEW tokens, packed like q + position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) + for b in range(batch): + M = int(new_lens[b].item()) + start = int(b_start_loc[b].item()) + + gid = torch.arange(M, device=device, dtype=torch.int32) + + # make one repeated block inside NEW part to simulate image tokens + if M >= 4 and torch.rand((), device=device).item() > 0.3: + s = int(torch.randint(0, M - 2, (1,), device=device).item()) + e = min(M, s + 3) + gid[s:e] = gid[s] + + position_ids_q[start : start + M] = gid + + q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) + k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) + + return ( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) + + +def check_once(device="cuda", dtype=torch.float16, seed=0): + ( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) = make_test_case(device=device, dtype=dtype, seed=seed) + + context_attention_fwd_neo( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) + + ref = reference_attention( + q, + k, + v, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, + ) + + diff = (o - ref).abs() + max_abs = diff.max().item() + denom = ref.abs().max().item() + 1e-6 + max_rel = max_abs / denom + + print(f"seed={seed}, dtype={dtype}") + print(f"max_abs_error = {max_abs:.6e}") + print(f"max_rel_error = {max_rel:.6e}") + print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2)) + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA, skip.") + else: + torch.cuda.synchronize() + check_once(dtype=torch.bfloat16, seed=0) + check_once(dtype=torch.bfloat16, seed=1) + check_once(dtype=torch.bfloat16, seed=2) diff --git 
a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
new file mode 100644
index 0000000000..955f48bd80
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
@@ -0,0 +1,174 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _get_neo_position_triton(
+    b_image_start_idx: torch.Tensor,
+    b_image_thwd: torch.Tensor,
+    b_image_thwd_stride0: torch.Tensor,
+    b_image_nums: torch.Tensor,
+    b_image_start_num: torch.Tensor,
+    b_image_len: torch.Tensor,
+    position_ids: torch.Tensor,
+    position_ids_stride0: torch.Tensor,
+    b_ready_cache_len: torch.Tensor,
+    b_q_seq_len: torch.Tensor,
+    b_start_loc: torch.Tensor,
+    BLOCK_SIZE: tl.constexpr,
+) -> torch.Tensor:
+    cur_batch = tl.program_id(0)
+    cache_len = tl.load(b_ready_cache_len + cur_batch)
+    q_seq_len = tl.load(b_q_seq_len + cur_batch)
+    image_num = tl.load(b_image_nums + cur_batch)
+    image_start_num = tl.load(b_image_start_num + cur_batch)
+    start_loc = tl.load(b_start_loc + cur_batch)
+    for i in range(image_num):
+        local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
+        image_start_idx = start_loc + local_image_start_idx - cache_len
+        image_len = tl.load(b_image_len + image_start_num + i)
+        # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
+        image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
+        for j in range(0, image_len, BLOCK_SIZE):
+            off = j + tl.arange(0, BLOCK_SIZE)
+            # video is not handled yet, so the t position stays constant per image
+            t_pos = local_image_start_idx + off * 0
+            h_pos = off // image_w
+            w_pos = off % image_w
+            tl.store(
+                position_ids + off + image_start_idx,
+                t_pos,
+                mask=(off < image_len)
+                & (off + local_image_start_idx - cache_len < q_seq_len)
+                & (local_image_start_idx - cache_len + off >= 0),
+            )
+            tl.store(
+                position_ids + position_ids_stride0 + off + image_start_idx,
+                h_pos,
+                mask=(off < image_len)
+                & (off + local_image_start_idx - cache_len < q_seq_len)
+                & (local_image_start_idx - cache_len + off >= 0),
+            )
+            tl.store(
+                position_ids + position_ids_stride0 * 2 + off + image_start_idx,
+                w_pos,
+                mask=(off < image_len)
+                & (off + local_image_start_idx - cache_len < q_seq_len)
+                & (local_image_start_idx - cache_len + off >= 0),
+            )
+
+    for i in range(image_num):
+        local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
+        image_len = tl.load(b_image_len + image_start_num + i)
+        image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3)
+        image_end = local_image_start_idx + image_len - cache_len
+        text_start = tl.maximum(0, image_end)
+        for j in range(text_start, q_seq_len, BLOCK_SIZE):
+            off = j + tl.arange(0, BLOCK_SIZE)
+            t_pos = tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta
+            h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0)
+            w_pos = tl.load(
+                position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0
+            )
+            tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len))
+            tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len))
+            tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len))
+    return
+
+
+def get_neo_position_triton(
+    b_image_start_idx: torch.Tensor,
+    b_image_thwd: torch.Tensor,
+    b_image_nums: torch.Tensor,
+    b_image_start_num: torch.Tensor,
+    b_image_len: 
torch.Tensor, + position_ids: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, +) -> torch.Tensor: + + batch_size = b_q_seq_len.shape[0] + assert batch_size == b_image_nums.shape[0] + grid = (batch_size,) + BLOCK_SIZE = 64 + _get_neo_position_triton[grid]( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_thwd_stride0=b_image_thwd.stride(0), + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + position_ids_stride0=position_ids.stride(0), + b_ready_cache_len=b_ready_cache_len, + b_q_seq_len=b_q_seq_len, + b_start_loc=b_start_loc, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +def test(): + b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda") + b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda") + b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda") + b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda") + position_ids = ( + torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + .unsqueeze(0) + .expand(3, -1) + .contiguous() + ) + position_ids[1:].zero_() + b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") + b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") + b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") + get_neo_position_triton( + b_image_start_idx, + b_image_thwd, + b_image_nums, + b_image_start_num, + b_image_len, + position_ids, + b_ready_cache_len, + b_q_seq_len, + b_start_loc, + ) + + print(position_ids) + # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) + + # position_ids = ( + # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + # .unsqueeze(0) + # .expand(3, -1) + # .contiguous() + # ) + # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda") + # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda") + # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda") + + # get_neo_position_triton( + # b_image_start_idx, + # b_image_thwd, + # b_image_nums, + # b_image_start_num, + # b_image_len, + # position_ids, + # b_ready_cache_len, + # b_q_seq_len, + # b_start_loc, + # ) + + # print(f"old_value:\n{old_value}") + # print(f"position_ids:\n{position_ids}") + # assert torch.equal(old_value, position_ids) + + """ + tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8], + [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8], + [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]], + device='cuda:0', dtype=torch.int32) + """ diff --git a/lightllm/models/neo_chat_moe/vision_process.py b/lightllm/models/neo_chat_moe/vision_process.py new file mode 100644 index 0000000000..aa008e18fb --- /dev/null +++ b/lightllm/models/neo_chat_moe/vision_process.py @@ -0,0 +1,141 @@ +import re +import math +import torch +import string +import numpy as np +import pandas as pd +from PIL import Image +import torch.distributed as dist +import torchvision.transforms as T + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return 
round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60 +def smart_resize( + height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304 +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, floor_by_factor(height / beta, factor)) + w_bar = max(factor, floor_by_factor(width / beta, factor)) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs): + width, height = image.size + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def preprocess_pixel_values(pixel_values, patch_size=16): + c, h, w = pixel_values.shape + grid_h = h // patch_size + grid_w = w // patch_size + + flatten_pixel_values = ( + pixel_values.view(c, grid_h, patch_size, grid_w, patch_size) + .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size] + .reshape(grid_h * grid_w, c * patch_size ** 2) + ) + + grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device) + + return flatten_pixel_values, grid_hw + + +def get_contrasting_background(image): + """ + Calculate the color (white or black) that is different from the average foreground color + to use as the background color + """ + image_np = np.array(image) + if (image_np[:, :, 3] == 0).any(): + non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0] + if non_transparent_pixels.size == 0: + return None + pixel_mean = non_transparent_pixels.mean() + contrasting_color = (0, 0, 0) if pixel_mean > 382.5 else (255, 255, 255) + return contrasting_color + else: + return None + + +def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False): + """ + Load and preprocess an image file, converting it to RGB mode, + resizing, normalizing, and optionally adding a thumbnail version. 
+ """ + if image.mode == "RGBA": + bg_color = get_contrasting_background(image) + if bg_color: + background = Image.new("RGB", image.size, bg_color) + background.paste(image, mask=image.split()[3]) + image = background.convert("RGB") + else: + image = image.convert("RGB") + else: + image = image.convert("RGB") + + if upscale: + image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) + + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), + T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ] + ) + + new_image = dynamic_preprocess_native_resolution( + image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels + ) + pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size) + + print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") + + return pixel_values, grid_hw From e208733708e86834fafd21f17adc818e4145fa3e Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Thu, 8 Jan 2026 16:42:33 +0800 Subject: [PATCH 023/180] support verl. --- lightllm/server/api_http.py | 1 + lightllm/server/api_start.py | 16 +++++++++++++++- lightllm/server/core/objs/start_args_type.py | 2 -- lightllm/server/httpserver/manager.py | 4 ++-- lightllm/utils/device_utils.py | 3 ++- lightllm/utils/serializer.py | 2 ++ lightllm/utils/torch_memory_saver_utils.py | 2 +- 7 files changed, 23 insertions(+), 7 deletions(-) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index ff9acafc94..a933a79477 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -335,6 +335,7 @@ async def handle_request_common(request_obj, handler): else: return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg) except Exception as e: + logger.error("handle_request_common (%s) error occurred: %s", str(request_obj), str(e), exc_info=True) return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index ffd794b2d6..f47d0ddd42 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -1,3 +1,4 @@ +import multiprocessing as mp import os import sys import time @@ -83,7 +84,13 @@ def signal_handler(sig, frame): return -def normal_or_p_d_start(args: StartArgs): +def _set_envs_and_config(args: StartArgs): + mp.set_start_method("spawn", force=True) + + +def _launch_subprocesses(args: StartArgs): + + _set_envs_and_config(args) set_unique_server_name(args) if not args.disable_shm_warning: @@ -350,6 +357,13 @@ def normal_or_p_d_start(args: StartArgs): ], ) + return process_manager + + +def normal_or_p_d_start(args: StartArgs): + + process_manager = _launch_subprocesses(args) + # 启动 gunicorn command = [ "gunicorn", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 0af795a096..80719745aa 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -151,5 +151,3 @@ class StartArgs: weight_version: str = "default" - enable_torch_memory_saver: bool = field(default=False) - enable_weight_cpu_backup: bool = field(default=False) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 083e939bac..5136032adb 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -924,7 +924,7 @@ async def 
update_weights_from_distributed(self, request: UpdateWeightsFromDistri await self.abort_request(AbortReq(abort_all=True)) if request.flush_cache: - await self.flush_cache() + await self.flush_cache(FlushCacheReq()) return await self.http_to_model_special_request( GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request) @@ -935,7 +935,7 @@ async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) await self.abort_request(AbortReq(abort_all=True)) if request.flush_cache: - await self.flush_cache() + await self.flush_cache(FlushCacheReq()) return await self.http_to_model_special_request( GeneralHttpToModelRpcReq(func_name="update_weights_from_tensor", func_args=request) diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index d2b6d06a8a..5d58d5d7b6 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -85,7 +85,8 @@ def get_current_device_name(): gpu_name = torch.cuda.get_device_name(device).replace(" ", "_") return gpu_name else: - raise RuntimeError("No GPU available") + return "unknown" # need fix + # raise RuntimeError("No GPU available") @lru_cache(maxsize=None) diff --git a/lightllm/utils/serializer.py b/lightllm/utils/serializer.py index e0b5233032..ae6f418df6 100644 --- a/lightllm/utils/serializer.py +++ b/lightllm/utils/serializer.py @@ -88,6 +88,8 @@ class SafeUnpickler(pickle.Unpickler): "sglang.srt.model_executor.model_runner.", "sglang.srt.layers.", "sglang.srt.utils.", + # --- LightLLM --- + "lightllm.utils.", } DENY_CLASSES = { diff --git a/lightllm/utils/torch_memory_saver_utils.py b/lightllm/utils/torch_memory_saver_utils.py index edf15fa837..c1184ef30c 100644 --- a/lightllm/utils/torch_memory_saver_utils.py +++ b/lightllm/utils/torch_memory_saver_utils.py @@ -20,7 +20,7 @@ class MemoryTag(Enum): KV_CACHE = "kv_cache" - WEIGHT = "weight" + WEIGHT = "weights" GRAPH = "graph" def is_kv_cache(self): From 245357cc25d503e63fc9f6f2690769dd2cfd48a3 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 8 Jan 2026 11:40:51 +0000 Subject: [PATCH 024/180] improve0108 --- .../common/kv_cache_mem_manager/__init__.py | 2 + .../kv_cache_mem_manager/mem_manager.py | 2 +- .../common/kv_cache_mem_manager/mem_utils.py | 7 ++- .../kv_cache_mem_manager/neo_mem_manager.py | 46 +++++++++++++++++++ lightllm/utils/kv_cache_utils.py | 20 ++++++-- 5 files changed, 71 insertions(+), 6 deletions(-) create mode 100755 lightllm/common/kv_cache_mem_manager/neo_mem_manager.py diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 66caf5d789..a780261447 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -6,6 +6,7 @@ from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager from .deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager +from .neo_mem_manager import NeoMemoryManager __all__ = [ "MemoryManager", @@ -17,4 +18,5 @@ "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", "Deepseek2FP8KVMemoryManager", + "NeoMemoryManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index b599bedfc0..64483a79dc 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -28,7 +28,7 @@ class MemoryManager: def __init__(self, size, dtype, head_num, head_dim, layer_num, 
always_copy=False, mem_fraction=0.9):
         self.size = size
         self.head_num = head_num
-        self.head_dim = head_dim * 2  # neo kv 是[k, k_h, k_w]拼在一起的
+        self.head_dim = head_dim
         self.layer_num = layer_num
         self.always_copy = always_copy
         self.dtype = dtype
diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py
index 259c5a56f8..b655a274b7 100644
--- a/lightllm/common/kv_cache_mem_manager/mem_utils.py
+++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py
@@ -7,6 +7,7 @@
     PPLINT4KVMemoryManager,
     Deepseek2MemoryManager,
     Deepseek2FP8KVMemoryManager,
+    NeoMemoryManager,
 )
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
@@ -23,7 +24,7 @@ def select_mem_manager_class():
     # case 1
     # 先判断是否是 deepseek 系列的模型
     model_class = get_llm_model_class()
-    from lightllm.models import Deepseek2TpPartModel
+    from lightllm.models import Deepseek2TpPartModel, NeoTpMOEPartModel, NeoTpPartModel

     if issubclass(model_class, Deepseek2TpPartModel):
         mem_class = Deepseek2MemoryManager
@@ -32,6 +33,10 @@ def select_mem_manager_class():
         logger.info(f"Model kv cache using mode {mode}, mem_manager class: {mem_class}")
         return mem_class

+    # check whether this is a neo-family model
+    elif issubclass(model_class, NeoTpMOEPartModel) or issubclass(model_class, NeoTpPartModel):
+        mem_class = NeoMemoryManager
+        return mem_class
     # case normal
     logger.info(f"mode setting params: {mode}")
diff --git a/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py b/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py
new file mode 100755
index 0000000000..0a79aa072b
--- /dev/null
+++ b/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py
@@ -0,0 +1,46 @@
+import torch
+from lightllm.utils.dist_utils import get_current_rank_in_node
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
+
+
+class NeoMemoryManager(MemoryManager):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+        self.size = size
+        self.head_num = head_num
+        self.head_dim = head_dim * 2  # neo kv packs [k, k_h, k_w] together, so head_dim doubles
+        self.layer_num = layer_num
+        self.always_copy = always_copy
+        self.dtype = dtype
+        # profile the max total token num if the size is None
+        self.profile_size(mem_fraction)
+
+        self.mem_state = torch.arange(
+            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self._mem_state_return = torch.arange(
+            0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self._return_start = 0
+        self.mark_start = 0
+        self.mark_end = self.size
+
+        self.can_use_mem_size = self.size
+
+        # shared via shared memory so the router can read it for precise scheduling
+        # estimates; the nccl port tags this single instance on the machine to
+        # avoid collisions between co-located servers.
+        from lightllm.utils.envs_utils import get_unique_server_name
+
+        rank_in_node = get_current_rank_in_node()
+        self.shared_can_use_token_num = SharedInt(
+            f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
+        )
+
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self._init_buffers(
+            self.size,
+            dtype,
+            head_num,
+            self.head_dim,
+            layer_num,
+        )
+        self.HOLD_TOKEN_MEMINDEX = self.size
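The doubled head_dim is the whole point of this subclass: every cached K slot stores the regular K plus the h/w RoPE halves. A rough sizing sketch in Python, with all numbers illustrative rather than read from any real config:

    import torch

    def kv_bytes_per_token(head_num, head_dim, layer_num, dtype=torch.bfloat16, neo=False):
        # NeoMemoryManager stores [k | k_h | k_w] per head, so head_dim doubles.
        eff_head_dim = head_dim * 2 if neo else head_dim
        return head_num * eff_head_dim * layer_num * (torch.finfo(dtype).bits // 8)

    base = kv_bytes_per_token(head_num=16, head_dim=128, layer_num=48)
    neo = kv_bytes_per_token(head_num=16, head_dim=128, layer_num=48, neo=True)
    assert neo == 2 * base  # the same GPU budget holds half as many neo tokens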
diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py
index ed183e3936..26d50f810b 100644
--- a/lightllm/utils/kv_cache_utils.py
+++ b/lightllm/utils/kv_cache_utils.py
@@ -21,6 +21,7 @@
     PPLINT4KVMemoryManager,
     Deepseek2MemoryManager,
     Deepseek2FP8KVMemoryManager,
+    NeoMemoryManager,
 )
 from typing import List, Tuple, Optional
@@ -83,26 +84,37 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta":
             scale_head_dim=0,
             scale_data_type=get_llm_data_type(),
         )
-    elif mem_manager_class is MemoryManager:
+    elif mem_manager_class is PPLINT8KVMemoryManager:
         cpu_cache_meta = CpuKVCacheMeta(
             page_num=0,
             token_page_size=args.cpu_cache_token_page_size,
             layer_num=get_layer_num(args.model_dir),
             num_heads=get_num_key_value_heads(args.model_dir) * 2,
             head_dim=get_head_dim(args.model_dir),
+            data_type=torch.int8,
+            scale_head_dim=get_head_dim(args.model_dir) // 8,
+            scale_data_type=get_llm_data_type(),
+        )
+    elif mem_manager_class is NeoMemoryManager:
+        cpu_cache_meta = CpuKVCacheMeta(
+            page_num=0,
+            token_page_size=args.cpu_cache_token_page_size,
+            layer_num=get_layer_num(args.model_dir),
+            num_heads=get_num_key_value_heads(args.model_dir) * 2,
+            head_dim=get_head_dim(args.model_dir) * 2,
             data_type=get_llm_data_type(),
             scale_head_dim=0,
             scale_data_type=get_llm_data_type(),
         )
-    elif mem_manager_class is PPLINT8KVMemoryManager:
+    elif mem_manager_class is MemoryManager:
         cpu_cache_meta = CpuKVCacheMeta(
             page_num=0,
             token_page_size=args.cpu_cache_token_page_size,
             layer_num=get_layer_num(args.model_dir),
             num_heads=get_num_key_value_heads(args.model_dir) * 2,
             head_dim=get_head_dim(args.model_dir),
-            data_type=torch.int8,
-            scale_head_dim=get_head_dim(args.model_dir) // 8,
+            data_type=get_llm_data_type(),
+            scale_head_dim=0,
             scale_data_type=get_llm_data_type(),
         )
     else:
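Each CpuKVCacheMeta branch above implies a per-page byte budget that is just the product of its fields. A rough sketch for the Neo branch, with purely illustrative numbers (none come from a real config):

    token_page_size = 64    # args.cpu_cache_token_page_size (illustrative)
    layer_num = 48          # get_layer_num(...) (illustrative)
    num_heads = 8 * 2       # get_num_key_value_heads(...) * 2: K and V planes
    head_dim = 128 * 2      # get_head_dim(...) * 2 in the Neo branch
    dtype_bytes = 2         # bf16

    page_bytes = token_page_size * layer_num * num_heads * head_dim * dtype_bytes
    print(page_bytes / 2 ** 20, "MiB per CPU cache page")  # ~24 MiB here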
            image_data = read_shm(get_shm_name_data(img.uuid))
             image_data = Image.open(BytesIO(image_data))
             pixel_values, image_grid_hw = load_image_native(
                 image_data,
                 patch_size=self.patch_size,
                 downsample_ratio=self.downsample_ratio,
-                min_pixels=self.min_pixels,
-                max_pixels=self.max_pixels,
+                min_pixels=img.extra_params["min_pixels"],
+                max_pixels=img.extra_params["max_pixels"],
             )
             img_tensors.append(pixel_values)
             img_grids.append(image_grid_hw)
@@ -261,7 +261,6 @@ def encode(self, images: List[ImageItem]):

             # must divide merge_length
             cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2))
-            print(f"cur_num is {cur_num}")

             valid_ids.append([valid_id, valid_id + cur_num])
             valid_id += cur_num
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index f073319d79..3ab2c36c46 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -293,6 +293,8 @@ class SamplingParams(ctypes.Structure):
         ("ignore_eos", ctypes.c_bool),
         # the max number of image patches to be used in the internvl model, for the test
         ("image_max_patch_num", ctypes.c_int),
+        ("min_pixels", ctypes.c_int),
+        ("max_pixels", ctypes.c_int),
         ("max_new_tokens", ctypes.c_int),
         ("min_new_tokens", ctypes.c_int),
         # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
@@ -343,7 +345,9 @@ def init(self, tokenizer, **kwargs):
         self.top_p = kwargs.get("top_p", SamplingParams._top_p)
         self.top_k = kwargs.get("top_k", SamplingParams._top_k)
         self.ignore_eos = kwargs.get("ignore_eos", False)
         self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
+        self.min_pixels = kwargs.get("min_pixels", -1)
+        self.max_pixels = kwargs.get("max_pixels", -1)
         self.max_new_tokens = kwargs.get("max_new_tokens", 16)
         self.min_new_tokens = kwargs.get("min_new_tokens", 1)
         self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
@@ -482,6 +486,8 @@ def to_dict(self):
             "image_max_patch_num": self.image_max_patch_num,
             "max_new_tokens": self.max_new_tokens,
             "min_new_tokens": self.min_new_tokens,
+            "min_pixels": self.min_pixels,
+            "max_pixels": self.max_pixels,
             "exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(),
             "stop_sequences": self.stop_sequences.to_list(),
             "best_of": self.best_of,

From 07df460fae6dacb64d9304583ecd842cf5435f08 Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Mon, 12 Jan 2026 15:07:17 +0800
Subject: [PATCH 026/180] fix fused_moe not being installed by pip.
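
Without an __init__.py, setuptools' find_packages() does not treat the
fused_moe directory as a package, so wheels built and installed through pip
ship without the fused_moe modules and importing them fails at runtime. The
empty file added below is what makes the subpackage discoverable. A minimal
sketch of the check, assuming a setuptools-based build (illustrative only,
not the project's actual packaging script):

    # Python: list the packages a setuptools build would pick up.
    from setuptools import find_packages

    pkgs = find_packages()
    # find_packages() only returns directories that contain an __init__.py,
    # so before this commit no name ending in ".fused_moe" appears here.
    print(any(p.endswith(".fused_moe") for p in pkgs))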
---
 .../basemodel/layer_weights/meta_weights/fused_moe/__init__.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py
new file mode 100644
index 0000000000..e69de29bb2

From a6f00fbe4a2a9f6149675557d64e51b25f24b024 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Thu, 15 Jan 2026 11:51:50 +0000
Subject: [PATCH 027/180] add visual nccl port alloc

---
 lightllm/server/api_start.py                 | 20 +++++++-------------
 lightllm/server/core/objs/start_args_type.py |  3 +--
 2 files changed, 8 insertions(+), 15 deletions(-)

diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index f47d0ddd42..c64502e585 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -187,15 +187,6 @@ def _launch_subprocesses(args: StartArgs):
     else:
         args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus]

-    # 检查visual_nccl_port数量是否足够
-    # if len(args.visual_nccl_ports) < args.visual_dp:
-    #     raise ValueError(
-    #         f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, "
-    #         f"but got ({len(args.visual_nccl_ports)})."
-    #     )
-    # else:
-    #     args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
-
     if args.visual_dp <= 0:
         raise ValueError("visual_dp must be a positive integer.")

@@ -240,9 +231,9 @@ def _launch_subprocesses(args: StartArgs):
         args.data_type = get_dtype(args.model_dir)
     assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]

-    already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port]
+    already_uesd_ports = [args.nccl_port, args.port]
     if args.run_mode == "decode":
-        already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port, args.pd_decode_rpyc_port]
+        already_uesd_ports = [args.nccl_port, args.port, args.pd_decode_rpyc_port]

     # lock the ports in advance, so that when several instances are started on one machine,
     # port conflicts are caught here instead of only when the model starts up
@@ -251,7 +242,7 @@ def _launch_subprocesses(args: StartArgs):
     node_world_size = args.tp // args.nnodes
     can_use_ports = alloc_can_use_network_port(
-        num=9 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
+        num=9 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp, used_nccl_ports=already_uesd_ports
     )
     logger.info(f"alloced ports: {can_use_ports}")
     (
         router_port,
@@ -268,11 +259,14 @@ def _launch_subprocesses(args: StartArgs):
     can_use_ports = can_use_ports[9:]

     visual_model_tp_ports = []
+    visual_nccl_ports = []
     for _ in range(args.visual_dp):
         tp_ports_for_dp = can_use_ports[0 : args.visual_tp]
-        can_use_ports = can_use_ports[args.visual_tp :]
+        visual_nccl_ports.append(can_use_ports[args.visual_tp])
+        can_use_ports = can_use_ports[args.visual_tp + 1 :]
         visual_model_tp_ports.append(tp_ports_for_dp)

+    args.visual_nccl_ports = visual_nccl_ports
     # put the allocated ports into the args
     args.router_port = router_port
     args.router_rpc_port = router_rpc_port
     args.detokenization_port = detokenization_port
     args.http_server_port = http_server_port
     args.visual_port = visual_port
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 80719745aa..6834518a1e 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -85,7 +85,7 @@ class StartArgs:
     visual_gpu_ids: Optional[List[int]] = field(default=None)
     visual_tp: int = field(default=1)
     visual_dp: int = field(default=1)
-    visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
+    visual_nccl_ports: Optional[List[int]] = field(default=None)
     enable_monitor_auth: bool = field(default=False)
     disable_cudagraph: bool = field(default=False)
     graph_max_batch_size: int = field(default=256)
@@ -150,4 +150,3 @@ class StartArgs:
     enable_weight_cpu_backup: bool = field(default=False)

     weight_version: str = "default"
-

From 9360197e355499bb63dda8829c4d9ad61e6896e8 Mon Sep 17 00:00:00 2001
From: sangchengmeng
Date: Thu, 15 Jan 2026 12:48:41 +0000
Subject: [PATCH 028/180] fix0115

---
 .../layer_infer/transformer_layer_infer.py |  4 +-
 .../models/llama/triton_kernel/rmsnorm.py  | 34 ++++++-------------
 2 files changed, 13 insertions(+), 25 deletions(-)

diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index 8c60156775..2cf37a10a3 100755
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -191,14 +191,14 @@ def _att_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
         out = self.alloc_tensor(input.shape, input.dtype)
-        rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, out=out)
+        out = rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, out=out)
         return out

     def _ffn_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
         out = self.alloc_tensor(input.shape, input.dtype)
-        rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out)
+        out = rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out)
         return out

     def _get_qkv(
diff --git a/lightllm/models/llama/triton_kernel/rmsnorm.py b/lightllm/models/llama/triton_kernel/rmsnorm.py
index 0140847afc..de60891593 100644
--- a/lightllm/models/llama/triton_kernel/rmsnorm.py
+++ b/lightllm/models/llama/triton_kernel/rmsnorm.py
@@ -4,7 +4,7 @@
 import triton.language as tl
 import os

-rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8"))
+rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "4"))


 @triton.jit
@@ -36,15 +36,15 @@ def _rms_norm_fwd_fused(
     for off in range(0, N, BLOCK_SIZE):
         cols = off + tl.arange(0, BLOCK_SIZE)
         mask = cols < N
-        w = tl.load(W + cols, mask=mask).to(tl.float32)
+        w = tl.load(W + cols, mask=mask)
         x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
-        x_hat = x * rstd
-        y = x_hat * w
+        x_hat = (x * rstd).to(Y.dtype.element_ty)
+        y = x_hat * w
         # Write output
-        tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)
+        tl.store(Y + cols * y_stride1, y, mask=mask)


-def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
+def rmsnorm_forward1(x: torch.Tensor, weight, eps, out=None):
     # allocate output
     y = torch.empty_like(x) if out is None else out
     # reshape input data into 2D tensor
@@ -78,22 +78,10 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
     return y


-def torch_rms_norm(x, weight, eps):
-    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
+def rmsnorm_forward(hidden_states, weight, eps, out=None):
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + eps)
+    out = weight * hidden_states.to(input_dtype)
+    return out


-def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"):
-    # create data
-    x_shape = (M, N)
-    w_shape = (x_shape[-1],)
-    weight = torch.rand(w_shape, dtype=dtype, device="cuda")
-    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-    # forward pass
-    y_tri = rmsnorm_forward(x, weight, eps)
-    y_ref = torch_rms_norm(x.to(torch.float32), weight.to(torch.float32), eps).to(dtype)
-
-    # compare
-    print("type:", y_tri.dtype, y_ref.dtype)
-    print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
-    assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
-    return

From 920a741d1b7e300dd8e9b64d6db02969f2e10bfe Mon Sep 17 00:00:00 2001
From: sangchengmeng
Date: Thu, 15 Jan 2026 13:11:44 +0000
Subject: [PATCH 029/180] fix0115

---
 .../models/qwen3/triton_kernel/qk_norm.py | 27 ++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/lightllm/models/qwen3/triton_kernel/qk_norm.py b/lightllm/models/qwen3/triton_kernel/qk_norm.py
index 40322e5093..e58cce0f2e 100644
--- a/lightllm/models/qwen3/triton_kernel/qk_norm.py
+++ b/lightllm/models/qwen3/triton_kernel/qk_norm.py
@@ -34,7 +34,7 @@ def _rms_norm_fwd_fused(
     tl.store(X + cols, y.to(X.dtype.element_ty))


-def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps):
+def qk_rmsnorm_forward1(x: torch.Tensor, weight: torch.Tensor, eps):
     """
     This function is used to perform in-place RMSNorm on the input tensor,
     and to adapt the head_dim norm for Qwen3 MoE and the split qk tensor layout.
@@ -64,3 +64,28 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps):
         num_warps=4,
     )
     return x
+
+
+@torch.no_grad()
+def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float):
+    assert torch.is_tensor(x) and torch.is_tensor(weight)
+    # assert weight.ndim == 1, weight.shape
+    # assert x.is_contiguous(), "x.is_contiguous()"
+
+    head_dim = weight.numel()
+    x2d = x.view(-1, x.shape[-1])  # (M2, N)
+    M2, N = x2d.shape
+    assert N % head_dim == 0, (N, head_dim)
+    H = N // head_dim
+
+    x3 = x2d.view(M2, H, head_dim)  # (M2, H, D)
+
+    x_fp32 = x3.to(torch.float32)
+    w = weight.view(1, 1, head_dim)
+
+    var = x_fp32.pow(2).mean(dim=-1, keepdim=True)
+    rstd = torch.rsqrt(var + eps)
+    y = (x_fp32 * rstd).to(x.dtype) * w
+
+    x3.copy_(y)
+    return x

From 3aa5e18ef20b4fbae8f65b4c3c359cc90bb26b82 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Fri, 16 Jan 2026 06:01:00 +0000
Subject: [PATCH 030/180] fp8 online quant for moe

---
 .../fused_moe/fused_moe_weight_tp.py            | 32 +++++++++++++++----
 lightllm/common/quantization/quantize_method.py |  2 ++
 lightllm/common/quantization/w8a8_quant.py      | 16 ++++++++--
 3 files changed, 40 insertions(+), 10 deletions(-)

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py
index bf7b218b71..023c7ba634 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py
@@ -10,6 +10,7 @@
     get_row_slice_mixin,
     get_col_slice_mixin,
 )
+from threading import Lock


 def create_tp_moe_wegiht_obj(
@@ -80,6 +81,7 @@ def __init__(
         self.quantized_weight = quant_cfg.quantized_weight
         if self.quant_method.method_name != "none":
             self.weight_scale_suffix = self.quant_method.weight_scale_suffix
+            self.quant_method.is_moe = True

         self.w1_weight_name = gate_proj_name
         self.w2_weight_name = down_proj_name
@@ -103,6 +105,9 @@ def __init__(
         self.col_slicer = get_col_slice_mixin(
             self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=get_dp_world_size()
         )
+        self.lock = Lock()
+        # for online per-tensor quantization
+        self.gate_up_buffer = [[None, None] for _ in range(self.n_routed_experts)]
         self._create_weight()

     def _create_weight(self):
@@ -206,16 +211,16 @@ def load_hf_weights(self, weights):
         # Load each expert with TP slicing
         for i_experts in range(self.n_routed_experts):
             self._load_expert(i_experts, weights, type="weight", suffix=self.quant_method.weight_suffix)
-            if self.w13.weight_scale is not None:
+            if self.w13.weight_scale is not None and self.quantized_weight:
                 self._load_expert(i_experts, weights, type="weight_scale", suffix=self.quant_method.weight_scale_suffix)
-            if self.w13.weight_zero_point is not None:
+            if self.w13.weight_zero_point is not None and self.quantized_weight:
                 self._load_expert(
                     i_experts, weights, type="weight_zero_point", suffix=self.quant_method.weight_zero_point_suffix
                 )

     def _load_weight_func(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0):
         if self.quant_method.weight_need_quanted(weight):
-            self.quant_method.quantize(weight, weight_pack, start_idx)
+            self.quant_method.quantize_moe(weight, weight_pack, start_idx)
         else:
             self.quant_method.load_weight(weight, weight_pack, start_idx)

@@ -225,10 +230,23 @@ def _load_expert(self, expert_idx, weights, type: str, suffix: str = "weight"):
f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{suffix}" intermediate_size = self.split_inter_size load_func, slice_func = self._get_load_and_slice_func(type, is_row=True) - if w1_weight in weights: - load_func(slice_func(weights[w1_weight]), self.w13.get_expert(expert_idx), start_idx=0) - if w3_weight in weights: - load_func(slice_func(weights[w3_weight]), self.w13.get_expert(expert_idx), start_idx=intermediate_size) + if suffix == "weight": + with self.lock: + if w1_weight in weights: + self.gate_up_buffer[expert_idx][0] = slice_func(weights[w1_weight]) + if w3_weight in weights: + self.gate_up_buffer[expert_idx][1] = slice_func(weights[w3_weight]) + if None not in self.gate_up_buffer[expert_idx]: + tmp_weight = torch.cat( + [self.gate_up_buffer[expert_idx][0], self.gate_up_buffer[expert_idx][1]], dim=0 + ) + load_func(tmp_weight, self.w13.get_expert(expert_idx), start_idx=0) + self.gate_up_buffer[expert_idx] = [None, None] + else: + if w1_weight in weights: + load_func(slice_func(weights[w1_weight]), self.w13.get_expert(expert_idx), start_idx=0) + if w3_weight in weights: + load_func(slice_func(weights[w3_weight]), self.w13.get_expert(expert_idx), start_idx=intermediate_size) load_func, slice_func = self._get_load_and_slice_func(type, is_row=False) if w2_weight in weights: diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 77e59465ee..971ea20a12 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -15,6 +15,8 @@ def get_expert(self, expert_idx: int): assert self.weight.ndim == 3, f"weight must be a 3D tensor, but got {self.weight.ndim}" weight = self.weight[expert_idx] weight_scale = self.weight_scale[expert_idx] if self.weight_scale is not None else None + if weight_scale is not None and weight_scale.ndim == 0: + weight_scale = weight_scale.unsqueeze(0) weight_zero_point = self.weight_zero_point[expert_idx] if self.weight_zero_point is not None else None return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py index e4f7b552aa..b2a5ce6ede 100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8_quant.py @@ -116,10 +116,9 @@ def __init__(self): self.is_moe = False self.has_weight_scale = True self.has_weight_zero_point = False + self.is_moe = False def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: - if self.is_moe: - return self.quantize_moe(weight, output, offset) qweight, weight_scale = scaled_fp8_quant( weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True ) @@ -127,6 +126,14 @@ def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1)) return + def quantize_moe(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + qweight, weight_scale = scaled_fp8_quant( + weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False + ) + output.weight[:, :].copy_(qweight) + output.weight_scale[:].copy_(weight_scale) + return + def apply( self, input_tensor: torch.Tensor, @@ -160,7 +167,11 @@ def create_weight( ) -> WeightPack: expert_prefix = (num_experts,) if num_experts > 1 else () weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - 
weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + if self.is_moe: + # per-tensor for moe + weight_scale = torch.empty((num_experts,), dtype=torch.float32).cuda(device_id) + else: + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) return WeightPack(weight=weight, weight_scale=weight_scale) From 7cb890b4c6d97632068fc18550a746e55ec53fcc Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 16 Jan 2026 12:56:56 +0000 Subject: [PATCH 031/180] hotfix for fa3 of llama --- .../layer_infer/transformer_layer_infer.py | 144 +++++++++++++----- 1 file changed, 110 insertions(+), 34 deletions(-) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index bb38c45bb5..640d04d6df 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -41,6 +41,14 @@ from lightllm.utils.sgl_utils import flash_attn_with_kvcache +try: + import flash_attn + import flash_attn_3_cuda + +except ImportError: + flash_attn_3_cuda = None + logger.warning("flash_attn is not installed, you can't use the api of it. ") + class LlamaTransformerLayerInfer(TransformerLayerInferTpl): """ """ @@ -326,25 +334,59 @@ def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionSt :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization + # k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, + # o = flash_attn_with_kvcache( + # q=q, + # k_cache=cache_k, + # v_cache=cache_v, + # page_table=infer_state.page_table, + # cache_seqlens=infer_state.b_seq_len, + # cu_seqlens_q=infer_state.cu_seqlens_q, + # cu_seqlens_k_new=infer_state.cu_seqlens_k, + # max_seqlen_q=infer_state.q_max_seq_len, + # softmax_scale=sm_scale, + # causal=True, + # window_size=(-1, -1), + # softcap=0.0, + # k_descale=k_descale, + # v_descale=v_descale, + # return_softmax_lse=False, + # ) + o, softmax_lse, *rest = flash_attn_3_cuda.fwd( + q, + cache_k, + cache_v, + None, + None, + None, # qv + None, # out + infer_state.cu_seqlens_q, + None, + infer_state.cu_seqlens_k, + None, + infer_state.b_seq_len, + infer_state.max_q_seq_len, + None, + infer_state.page_table, + None, + None, + None, + None, + None, + None, + None, + sm_scale, + True, # causal + -1, # window_size + -1, # window_size_right + 0.0, + True, + None, + 0, + None, + 0, ) return o @@ -839,25 +881,59 @@ def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionS :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization + # k_descale, v_descale = None, None # disable quantization Lq 
= q.shape[-1] sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, - softmax_scale=sm_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, + # o = flash_attn_with_kvcache( + # q=q, + # k_cache=cache_k, + # v_cache=cache_v, + # page_table=infer_state.page_table, + # cache_seqlens=infer_state.b_seq_len, + # cu_seqlens_q=infer_state.cu_seqlens_q, + # cu_seqlens_k_new=infer_state.cu_seqlens_k, + # max_seqlen_q=1, + # softmax_scale=sm_scale, + # causal=True, + # window_size=(-1, -1), + # softcap=0.0, + # k_descale=k_descale, + # v_descale=v_descale, + # return_softmax_lse=False, + # ) + o, softmax_lse, *rest = flash_attn_3_cuda.fwd( + q, + cache_k, + cache_v, + None, + None, + None, # qv + None, # out + infer_state.cu_seqlens_q, + None, + infer_state.cu_seqlens_k, + None, + infer_state.b_seq_len, + infer_state.max_q_seq_len, + None, + infer_state.page_table, + None, + None, + None, + None, + None, + None, + None, + sm_scale, + True, # causal + -1, # window_size + -1, # window_size_right + 0.0, + True, + None, + 0, + None, + 0, ) return o From c242a75a7286267a31dc7746ad8b983d419a9fc5 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 19 Jan 2026 03:23:06 +0000 Subject: [PATCH 032/180] fp8w8a8 triton config --- ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ ..._fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ ..._fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ .../{topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...t16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 74 ++++++++++++ ...t16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 74 ++++++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 74 ++++++++++++ ...=torch.float16}_NVIDIA_H100_80GB_HBM3.json | 74 ++++++++++++ 9 files changed, 786 insertions(+) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ee316f610b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + 
} +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ddda23d257 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..560ca6c09d --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..e950ff0954 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 
3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..7f479b8382 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 256, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "8448": { + "BLOCK_SIZE": 256, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..b3051c6584 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + 
"NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "8448": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..fdb3212216 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..a94e669353 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + 
"2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..441421fd5d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file From a0195aa5e89c9d1f5302872aa24bad88690d84dd Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 19 Jan 2026 09:00:44 +0000 Subject: [PATCH 033/180] fp16 config --- lightllm/common/quantization/w8a8_quant.py | 2 +- ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ 3 files changed, 221 insertions(+), 1 deletion(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py index b2a5ce6ede..7e819ccd24 
100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8_quant.py @@ -127,7 +127,7 @@ def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> return def quantize_moe(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: - qweight, weight_scale = scaled_fp8_quant( + qweight, weight_scale = vllm_ops.scaled_fp8_quant( weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False ) output.weight[:, :].copy_(qweight) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..e027701092 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..0713de7996 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file From 7f0c43756d647e26938a6b6e9ac919a70e5f3165 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 21 Jan 2026 14:05:27 +0800 Subject: [PATCH 034/180] release ipc tensor early. 
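
A tensor rebuilt from a CUDA IPC handle keeps the producer process's memory
mapped into the consumer for as long as any reference to it is alive, so
cloning onto the local device and dropping the reference releases the
mapping right after the copy instead of waiting for garbage collection.
A sketch of the pattern (the function name is illustrative; the real call
site is _unwrap_tensor in the diff below):

    import torch

    def consume_ipc_tensor(ipc_tensor: torch.Tensor, device: str) -> torch.Tensor:
        local = ipc_tensor.to(device).clone()  # take a private copy of the data
        del ipc_tensor  # drop this frame's reference so the IPC mapping can be freed
        torch.cuda.ipc_collect()  # optional: eagerly reclaim memory released over IPC
        return local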
--- .../server/router/model_infer/mode_backend/base_backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 4be944584d..28faa6a830 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -462,7 +462,9 @@ def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq): def _unwrap_tensor(tensor, tp_rank, device): if isinstance(tensor, LocalSerializedTensor): tensor = tensor.get(tp_rank) - return tensor.to(device) + clone = tensor.to(device).clone() + del tensor # free the ipc tensor + return clone named_tensors = { name: _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=infered_device) From 5738d9ee53078d5ca839ac5842ed4ca2f5691498 Mon Sep 17 00:00:00 2001 From: sound Date: Wed, 21 Jan 2026 17:15:18 +0800 Subject: [PATCH 035/180] bugfix: fix flattened_bucket update weights --- .../server/router/model_infer/mode_backend/base_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 28faa6a830..c198e083d3 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -451,7 +451,8 @@ def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq): monkey_patch_torch_reductions() if request.load_format == "flattened_bucket": # Handle flattened bucket format - return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=request.named_tensors) + serialized_named_tensors = MultiprocessingSerializer.deserialize(request.serialized_named_tensors[self.rank_in_dp]) + return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=serialized_named_tensors) # We need to get device after patch otherwise the device would be wrong self.device_module = torch.get_device_module("cuda") From e11bf58707cec1e1bd06076f81e22d4aa655e659 Mon Sep 17 00:00:00 2001 From: sound Date: Thu, 22 Jan 2026 11:07:01 +0800 Subject: [PATCH 036/180] bugfix: fix update_weights from tensor --- .../server/router/model_infer/mode_backend/base_backend.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index c198e083d3..099bd20da5 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -440,9 +440,14 @@ def _update_weights_from_flattened_bucket( # Create bucket and reconstruct tensors bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=converted_metadata) reconstructed_tensors = bucket.reconstruct_tensors() + + named_tensors = { + name: tensor + for name, tensor in reconstructed_tensors + } # Load the reconstructed tensors using the standard method - self.model.load_weights(reconstructed_tensors) + self.model.load_weights(named_tensors) return True, "Succeeded to update parameter online from flattened bucket tensor." 
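
Taken together, the two fixes above make the flattened-bucket path mirror
the plain named-tensor path: each rank first deserializes only its own slice
of the request, then rebuilds full tensors from the flat buffer and hands
the model a name-to-tensor mapping. A condensed sketch of the resulting
flow, using only names that appear in the diffs above (error handling
omitted):

    # Per-rank handling of a flattened-bucket weight update (sketch).
    payload = MultiprocessingSerializer.deserialize(
        request.serialized_named_tensors[self.rank_in_dp]  # this rank's slice
    )
    # ... rebuild the FlattenedTensorBucket from `payload` as done in
    # _update_weights_from_flattened_bucket ...
    named_tensors = {name: tensor for name, tensor in bucket.reconstruct_tensors()}
    self.model.load_weights(named_tensors)  # load_weights expects a dict of tensors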
From ce76f8a4c2a40243633c3fffa8f1744f18c02491 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 29 Jan 2026 07:20:42 +0000 Subject: [PATCH 037/180] fix start --- lightllm/common/basemodel/basemodel.py | 2 +- lightllm/server/api_start.py | 2 -- lightllm/server/core/objs/start_args_type.py | 4 +++- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index a0d2f41ead..c7cbb1f27d 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -184,7 +184,7 @@ def _init_weights(self, start_layer_index=0): return def load_weights(self, weight_dict: dict): - assert isinstance(weight_dict, dict), "weight_dict must be a dict" + assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None" load_hf_weights( self.data_type, self.weight_dir_, diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index c918832a90..121f3b9a61 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -244,7 +244,6 @@ def _launch_subprocesses(args: StartArgs): ( nccl_port, router_port, - router_rpc_port, detokenization_port, http_server_port, visual_port, @@ -272,7 +271,6 @@ def _launch_subprocesses(args: StartArgs): if args.pd_decode_rpyc_port is None: args.pd_decode_rpyc_port = pd_decode_rpyc_port args.router_port = router_port - args.router_rpc_port = router_rpc_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port args.visual_port = visual_port diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 365db97369..2cb12ed89d 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -156,7 +156,6 @@ class StartArgs: enable_dp_prompt_cache_fetch: bool = field(default=False) # zmp ports router_port: int = field(default=None) - router_rpc_port: int = field(default=None) detokenization_port: int = field(default=None) http_server_port: int = field(default=None) visual_port: int = field(default=None) @@ -181,5 +180,8 @@ class StartArgs: disable_custom_allreduce: bool = field(default=False) enable_torch_memory_saver: bool = field(default=False) enable_weight_cpu_backup: bool = field(default=False) + hardware_platform: str = field(default="cuda", metadata={"choices": ["cuda", "musa"]}) + enable_torch_fallback: bool = field(default=False) + enable_triton_fallback: bool = field(default=False) weight_version: str = "default" From 45259ec01584800d2c11ae64a8936e57ee2a1b5b Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 29 Jan 2026 09:12:09 +0000 Subject: [PATCH 038/180] add-merge-kv-mode --- .../layer_infer/transformer_layer_infer.py | 71 +++++++++++++++++++ .../layer_weights/transformer_layer_weight.py | 26 ++++--- .../models/qwen3/triton_kernel/qk_norm.py | 38 +++++----- 3 files changed, 107 insertions(+), 28 deletions(-) diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index ed48a9c6f1..4adf0e506b 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -17,6 +17,7 @@ class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): def __init__(self, data_type, network_config, mode): + self._is_merge_kv = network_config["is_merge_kv"] 
super().__init__(data_type, network_config, mode) return @@ -27,6 +28,14 @@ def _bind_attention(self): return def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): + if self._is_merge_kv: + return self._get_qkv_mergekv(input, infer_state, layer_weight) + else: + return self._get_qkv_not_mergekv(input, infer_state, layer_weight) + + def _get_qkv_not_mergekv( + self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight + ): input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) # [T, Hq*D] @@ -97,6 +106,68 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC cache_kv = torch.cat([k, v], dim=1) return q, cache_kv + def _get_qkv_mergekv( + self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight + ): + input = input.view(-1, self.embed_dim_) + + q = layer_weight.q_proj.mm(input) # [T, Hq*D] + q_hw = layer_weight.q_hw_proj.mm(input) + k_hw = layer_weight.k_hw_proj.mm(input) + + cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] + + qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(q_hw, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(k_hw, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + + q_hw = q_hw.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) + q_h, q_w = q_hw.chunk(2, dim=-1) + + qk_rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + weight=layer_weight.k_norm_weight_.weight, + eps=self.eps_, + ) + + k_hw = k_hw.view(q.shape[0], self.tp_k_head_num_, self.head_dim_) + k_h, k_w = k_hw.chunk(2, dim=-1) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h, + k_h, + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w, + k_w, + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + + q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) + q3 = torch.cat([q3, q_h, q_w], dim=-1) + q = q3.reshape(q3.shape[0], -1) + + k = cache_kv[:, : self.tp_k_head_num_, :] + k = torch.cat([k, k_h, k_w], dim=-1) + + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) + v = torch.cat([v, v_pad], dim=-1) + + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + def _context_attention_kernel( self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None ) -> torch.Tensor: diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py index bc38f1adcb..d8c842bb99 100644 --- a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py @@ -7,6 +7,7 @@ class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + self._is_merge_kv = network_config["merge_kv"] super().__init__(layer_num, data_type, network_config, mode, quant_cfg) return @@ -17,11 +18,15 @@ def _init_weight_names(self): self._k_weight_hw_name = 
f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" self._k_bias_hw_name = None - self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" - self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" + if self._is_merge_kv: + self._q_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_hw.weight" + self._k_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_hw.weight" + else: + self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" + self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" - self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" - self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" + self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" + self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" def _init_qkv(self): super()._init_qkv() @@ -44,8 +49,11 @@ def _init_qkv(self): def _init_norm(self): super()._init_norm() - - self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) - self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) - self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) - self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) + if self._is_merge_kv: + self.q_norm_hw_weight_ = NormWeight(weight_name=self._q_norm_hw_name, data_type=self.data_type_) + self.k_norm_hw_weight_ = NormWeight(weight_name=self._k_norm_hw_name, data_type=self.data_type_) + else: + self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) + self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) + self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) + self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) diff --git a/lightllm/models/qwen3/triton_kernel/qk_norm.py b/lightllm/models/qwen3/triton_kernel/qk_norm.py index e58cce0f2e..8e0de6a6e3 100644 --- a/lightllm/models/qwen3/triton_kernel/qk_norm.py +++ b/lightllm/models/qwen3/triton_kernel/qk_norm.py @@ -34,7 +34,7 @@ def _rms_norm_fwd_fused( tl.store(X + cols, y.to(X.dtype.element_ty)) -def qk_rmsnorm_forward1(x: torch.Tensor, weight: torch.Tensor, eps): +def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps): """ This function is used to perform in-place RMSNorm on the input tensor, and to adapt the head_dim norm for Qwen3 MoE and the splited qk tensor layout. 
@@ -66,26 +66,26 @@ def qk_rmsnorm_forward1(x: torch.Tensor, weight: torch.Tensor, eps): return x -@torch.no_grad() -def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float): - assert torch.is_tensor(x) and torch.is_tensor(weight) - # assert weight.ndim == 1, weight.shape - # assert x.is_contiguous(), "x.is_contiguous()" +# @torch.no_grad() +# def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float): +# assert torch.is_tensor(x) and torch.is_tensor(weight) +# # assert weight.ndim == 1, weight.shape +# # assert x.is_contiguous(), "x.is_contiguous()" - head_dim = weight.numel() - x2d = x.view(-1, x.shape[-1]) # (M2, N) - M2, N = x2d.shape - assert N % head_dim == 0, (N, head_dim) - H = N // head_dim +# head_dim = weight.numel() +# x2d = x.view(-1, x.shape[-1]) # (M2, N) +# M2, N = x2d.shape +# assert N % head_dim == 0, (N, head_dim) +# H = N // head_dim - x3 = x2d.view(M2, H, head_dim) # (M2, H, D) +# x3 = x2d.view(M2, H, head_dim) # (M2, H, D) - x_fp32 = x3.to(torch.float32) - w = weight.view(1, 1, head_dim) +# x_fp32 = x3.to(torch.float32) +# w = weight.view(1, 1, head_dim) - var = x_fp32.pow(2).mean(dim=-1, keepdim=True) - rstd = torch.rsqrt(var + eps) - y = (x_fp32 * rstd).to(torch.bfloat16) * w +# var = x_fp32.pow(2).mean(dim=-1, keepdim=True) +# rstd = torch.rsqrt(var + eps) +# y = (x_fp32 * rstd).to(torch.bfloat16) * w - x3.copy_(y.to(dtype=x3.dtype)) - return x +# x3.copy_(y.to(dtype=x3.dtype)) +# return x From da3b53db4024786a0a98457f88ddbf0f7d716a0d Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 29 Jan 2026 09:36:15 +0000 Subject: [PATCH 039/180] add-neo-chat0129 --- .../neo_chat_moe/layer_infer/transformer_layer_infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 4adf0e506b..8e5c8e5cd7 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -118,8 +118,8 @@ def _get_qkv_mergekv( cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(q_hw, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(k_hw, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(q_hw, weight=layer_weight.q_norm_hw_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(k_hw, weight=layer_weight.k_norm_hw_weight_.weight, eps=self.eps_) q_hw = q_hw.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) q_h, q_w = q_hw.chunk(2, dim=-1) From 043e898e6589ab40c652f3e6e930ac951c83523a Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 05:19:37 +0000 Subject: [PATCH 040/180] moe fused weight --- .../meta_weights/fused_moe/fused_moe_weight.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 6bcf7fc03c..5d6519de4f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -295,6 +295,7 @@ def _create_weight(self): device_id=self.device_id_, num_experts=self.local_n_routed_experts, ) + self.w1, self.w3 = 
w13_param_list self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0]) self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1]) self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2) @@ -312,6 +313,8 @@ def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[st for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): with self.lock: self._load_expert(expert_idx, local_expert_idx, weights) + # for rl updated weight + self._load_merge_weight(weights) self._load_expert_scale( expert_idx, local_expert_idx, @@ -332,6 +335,7 @@ def _load_expert( w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}" w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}" w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}" + row_slice_func = self.row_slicer._slice_weight col_slice_func = self.col_slicer._slice_weight if w1_weight in weights: @@ -341,6 +345,17 @@ def _load_expert( if w2_weight in weights: self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx]) + def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): + w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}" + w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}" + w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}" + if w1_merge_weight in weights: + self.quant_method.load_weight(weights[w1_merge_weight], self.w1) + if w2_merge_weight in weights: + self.quant_method.load_weight(weights[w2_merge_weight], self.w2) + if w3_merge_weight in weights: + self.quant_method.load_weight(weights[w3_merge_weight], self.w3) + def _load_expert_scale( self, expert_idx: int, From 80cfcc4f7984fbb66d5a6ac0168d33a36ef6be83 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 06:32:12 +0000 Subject: [PATCH 041/180] fix neo --- lightllm/models/llama/model.py | 2 - .../layer_infer/transformer_layer_infer.py | 36 ++++++------ .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 42 +++++++++----- lightllm/models/neo_chat_moe/infer_struct.py | 6 +- .../layer_infer/transformer_layer_infer.py | 51 +++++++++-------- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 56 +++++++++++++------ 8 files changed, 119 insertions(+), 82 deletions(-) diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 7064edae87..f86bd5f83d 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -110,7 +110,6 @@ def _init_to_get_rotary(self, default_base=10000): rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) base = self.config.get("rope_theta", float(default_base)) - print(f"base is {base}") if "max_sequence_length" in self.config: max_seq_len = self.config["max_sequence_length"] else: @@ -151,7 +150,6 @@ def _init_to_get_hw_rotary(self, default_base=10000): rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) base = self.config.get("rope_theta_hw", float(default_base)) - print(f"hw_base is {base}") if "max_sequence_length" in self.config: max_seq_len = self.config["max_sequence_length"] else: diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index 1cf13c4130..a3436b28ee 
100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -1,23 +1,20 @@ import torch from functools import partial from typing import Tuple -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight from lightllm.distributed import all_reduce import torch.distributed as dist from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def _bind_attention(self): @@ -40,25 +37,24 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + layer_weight.q_norm_weight_(q, eps=self.eps_) q_h_2d = q_h.reshape(q.shape[0], -1) q_w_2d = q_w.reshape(q.shape[0], -1) - qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_) + layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_) + layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_) q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] k_w_2d = k_w.reshape(q.shape[0], -1) - qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_) + layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_) k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) @@ -119,7 +115,7 @@ def _context_attention_kernel( o3 = o3[:, :, : self.head_dim_].contiguous() return o3.view(o3.shape[0], -1) - def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): + def _token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size @@ -134,18 +130,22 @@ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, att_m_tensor, 
infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_kv_start_loc, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd + from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( + token_attention_softmax_and_reducev, + ) + + token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ ] - o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out + o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) token_softmax_reducev_fwd( att_m_tensor, @@ -153,7 +153,7 @@ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, o_3d, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_kv_start_loc, infer_state.b_seq_len, ) return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py index c1f0638ac4..e6489f39af 100644 --- a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py @@ -13,8 +13,8 @@ def rename_weight_keys(weights): class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py index e5e769a769..e62afae9bc 100644 --- a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py @@ -1,13 +1,13 @@ from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( - NormWeight, + QKRMSNORMWeight, ROWMMWeight, ) class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): @@ -26,26 +26,42 @@ def _init_weight_names(self): def _init_qkv(self): super()._init_qkv() self.q_hw_proj = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.q_head_num_ * self.head_dim], weight_names=self._q_weight_hw_name, data_type=self.data_type_, bias_names=self._q_bias_hw_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_hw_proj", + quant_method=self.get_quant_method("q_hw_proj"), ) self.k_hw_proj = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.k_head_num_ * self.head_dim], weight_names=self._k_weight_hw_name, data_type=self.data_type_, bias_names=self._k_bias_hw_name, 
- quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="k_hw_proj", + quant_method=self.get_quant_method("k_hw_proj"), ) def _init_norm(self): super()._init_norm() - self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) - self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) - self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) - self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) + self.q_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_h_name, + data_type=self.data_type_, + ) + self.q_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_w_name, + data_type=self.data_type_, + ) + self.k_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_h_name, + data_type=self.data_type_, + ) + self.k_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_w_name, + data_type=self.data_type_, + ) diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py index 0c7d9372e2..13d1ba5fc2 100644 --- a/lightllm/models/neo_chat_moe/infer_struct.py +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -17,8 +17,8 @@ def __init__(self): self.position_cos_w = None self.position_sin_w = None - def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor): - LlamaInferStateInfo.init_some_extra_state(self, model, input_ids) + def init_some_extra_state(self, model: LlamaTpPartModel): + LlamaInferStateInfo.init_some_extra_state(self, model) if self.is_prefill: self.position_ids = self.get_neo_position(self.multimodal_params) else: @@ -94,6 +94,6 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: position_ids=position_ids, b_ready_cache_len=self.b_ready_cache_len, b_q_seq_len=self.b_q_seq_len, - b_start_loc=self.b_start_loc, + b_start_loc=self.b_q_start_loc, ) return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 8e5c8e5cd7..ad891e6ca9 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -1,24 +1,21 @@ import torch from functools import partial from typing import Tuple -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight from lightllm.distributed import all_reduce import torch.distributed as dist from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward class 
NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): - def __init__(self, data_type, network_config, mode): - self._is_merge_kv = network_config["is_merge_kv"] - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + self._is_merge_kv = network_config.get("merge_kv", True) + super().__init__(data_type, network_config) return def _bind_attention(self): @@ -49,25 +46,24 @@ def _get_qkv_not_mergekv( cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + layer_weight.q_norm_weight_(q, eps=self.eps_) q_h_2d = q_h.reshape(q.shape[0], -1) q_w_2d = q_w.reshape(q.shape[0], -1) - qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_) + layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_) + layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_) q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] k_w_2d = k_w.reshape(q.shape[0], -1) - qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_) + layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_) k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) @@ -117,16 +113,15 @@ def _get_qkv_mergekv( cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(q_hw, weight=layer_weight.q_norm_hw_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(k_hw, weight=layer_weight.k_norm_hw_weight_.weight, eps=self.eps_) + layer_weight.q_norm_weight_(q, eps=self.eps_) + layer_weight.q_norm_hw_weight_(q_hw, eps=self.eps_) + layer_weight.k_norm_hw_weight_(k_hw, eps=self.eps_) q_hw = q_hw.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) q_h, q_w = q_hw.chunk(2, dim=-1) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) @@ -180,17 +175,17 @@ def _context_attention_kernel( o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_q_start_loc, infer_state.b_seq_len, infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, + infer_state.max_q_seq_len, infer_state.req_manager.req_to_token_indexs, ) o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) o3 = o3[:, :, : self.head_dim_].contiguous() return o3.view(o3.shape[0], -1) - def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): + def _token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size @@ -205,18 +200,22 @@ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, 
att_m_tensor, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_kv_start_loc, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd + from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( + token_attention_softmax_and_reducev, + ) + + token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ ] - o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out + o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) token_softmax_reducev_fwd( att_m_tensor, @@ -224,7 +223,7 @@ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, o_3d, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_kv_start_loc, infer_state.b_seq_len, ) return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py index 7766a5d29f..4b0eae91c3 100644 --- a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py @@ -13,8 +13,8 @@ def rename_weight_keys(weights): class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py index d8c842bb99..26e986cdd7 100644 --- a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py @@ -1,14 +1,14 @@ from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( - NormWeight, + QKRMSNORMWeight, ROWMMWeight, ) class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - self._is_merge_kv = network_config["merge_kv"] - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + self._is_merge_kv = network_config.get("merge_kv", True) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): @@ -31,29 +31,53 @@ def _init_weight_names(self): def _init_qkv(self): super()._init_qkv() self.q_hw_proj = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.q_head_num_ * self.head_dim], weight_names=self._q_weight_hw_name, data_type=self.data_type_, bias_names=self._q_bias_hw_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_hw_proj", + quant_method=self.get_quant_method("q_hw_proj"), ) self.k_hw_proj = ROWMMWeight( + 
in_dim=self.network_config_["hidden_size"], + out_dims=[self.k_head_num_ * self.head_dim], weight_names=self._k_weight_hw_name, data_type=self.data_type_, bias_names=self._k_bias_hw_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="k_hw_proj", + quant_method=self.get_quant_method("k_hw_proj"), ) def _init_norm(self): super()._init_norm() if self._is_merge_kv: - self.q_norm_hw_weight_ = NormWeight(weight_name=self._q_norm_hw_name, data_type=self.data_type_) - self.k_norm_hw_weight_ = NormWeight(weight_name=self._k_norm_hw_name, data_type=self.data_type_) + self.q_norm_hw_weight_ = QKRMSNORMWeight( + dim=self.head_dim, + weight_name=self._q_norm_hw_name, + data_type=self.data_type_, + ) + self.k_norm_hw_weight_ = QKRMSNORMWeight( + dim=self.head_dim, + weight_name=self._k_norm_hw_name, + data_type=self.data_type_, + ) else: - self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) - self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) - self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) - self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) + self.q_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_h_name, + data_type=self.data_type_, + ) + self.q_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_w_name, + data_type=self.data_type_, + ) + self.k_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_h_name, + data_type=self.data_type_, + ) + self.k_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_w_name, + data_type=self.data_type_, + ) From 6bbdb4feaa3523b25cc8221290a77130a933f204 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 07:58:42 +0000 Subject: [PATCH 042/180] fix launch --- lightllm/models/neo_chat_moe/neo_visual.py | 6 +++--- lightllm/server/req_id_generator.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py index 59bd23e2bb..60fa82f2b9 100644 --- a/lightllm/models/neo_chat_moe/neo_visual.py +++ b/lightllm/models/neo_chat_moe/neo_visual.py @@ -247,9 +247,9 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - a = img.extra_params["min_pixels"] - b = img.extra_params["max_pixels"] - print(f"self.min_pixels is {a} ,max_pixelx is {b}") + # a = img.extra_params["min_pixels"] + # b = img.extra_params["max_pixels"] + # print(f"self.min_pixels is {a} ,max_pixelx is {b}") pixel_values, image_grid_hw = load_image_native( image_data, patch_size=self.patch_size, diff --git a/lightllm/server/req_id_generator.py b/lightllm/server/req_id_generator.py index 9bf9040c30..da1fade0dd 100644 --- a/lightllm/server/req_id_generator.py +++ b/lightllm/server/req_id_generator.py @@ -30,7 +30,7 @@ def __init__(self): self.current_id.arr[0] = 0 self.current_id.arr[1] = 0 self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock") - self._wait_all_workers_ready() + # self._wait_all_workers_ready() logger.info("ReqIDGenerator init finished") def _wait_all_workers_ready(self): From e436ba565ed2ce2daa10d45df79ecac535c5c72b Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 08:02:32 +0000 
Subject: [PATCH 043/180] fix launch --- lightllm/server/req_id_generator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/server/req_id_generator.py b/lightllm/server/req_id_generator.py index da1fade0dd..20da121dc0 100644 --- a/lightllm/server/req_id_generator.py +++ b/lightllm/server/req_id_generator.py @@ -30,7 +30,8 @@ def __init__(self): self.current_id.arr[0] = 0 self.current_id.arr[1] = 0 self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock") - # self._wait_all_workers_ready() + if self.args.httpserver_workers > 1: + self._wait_all_workers_ready() logger.info("ReqIDGenerator init finished") def _wait_all_workers_ready(self): From aef65bcbef6d6ce637b4c4a1862871c479a3029d Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 08:14:05 +0000 Subject: [PATCH 044/180] fix tp slice for merged moe weight --- .../meta_weights/mm_weight/mm_slicer.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index ddbf98a866..15f050c14a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -47,17 +47,17 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten # 默认weight 的shape是 outxin,这也是目前最通用的约定。 -# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。 +# 这里约定row-wise沿着倒数第二维切分,col-wise沿着第一维切分。 class RowSliceMixin(SliceMixinTpl): def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: assert ( - weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[0]) - return weight[start:end, :] + weight.shape[-2] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-2] * self.repeat_times_} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-2]) + return weight[..., start:end, :] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: assert ( @@ -75,17 +75,17 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( - weight_scale.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[0]) - return weight_scale[start:end] + weight_scale.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[-2]} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-2]) + return weight_scale[..., start:end, :] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_zero_point.shape[0]) - return weight_zero_point[start:end] + weight_zero_point.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-2]} % {self.tp_world_size_}" + start, end = 
self._get_slice_start_end(weight_zero_point.shape[-2]) + return weight_zero_point[..., start:end, :] class ColSliceMixin(SliceMixinTpl): @@ -94,10 +94,10 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: assert ( - weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[1]) - return weight[:, start:end] + weight.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-1]) + return weight[..., start:end] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return bias / self.tp_world_size_ * self.repeat_times_ @@ -110,16 +110,16 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[1]) - return weight_scale[:, start:end] + ), f"tp slice error {weight_scale.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-1]) + return weight_scale[..., start:end] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_zero_point.shape[1]) - return weight_zero_point[:, start:end] + weight_zero_point.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[-1]) + return weight_zero_point[..., start:end] # awq 的量化权重是inxout存储格式,需要定制实现。 From bc87692403b1c62380a46978630c5c71c9a0d657 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 08:18:50 +0000 Subject: [PATCH 045/180] fix fusemoe weight --- .../meta_weights/fused_moe/fused_moe_weight.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 5d6519de4f..77d6d40e9f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -349,12 +349,14 @@ def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}" w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}" w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}" + row_slice_func = self.row_slicer._slice_weight + col_slice_func = self.col_slicer._slice_weight if w1_merge_weight in weights: - self.quant_method.load_weight(weights[w1_merge_weight], self.w1) + self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1) if w2_merge_weight in weights: - 
self.quant_method.load_weight(weights[w2_merge_weight], self.w2) + self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2) if w3_merge_weight in weights: - self.quant_method.load_weight(weights[w3_merge_weight], self.w3) + self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3) def _load_expert_scale( self, From cf5bcbf4b99045cddca61d12f3236e760cc00baa Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 10:12:17 +0000 Subject: [PATCH 046/180] fa3 for neo --- .../layer_infer/transformer_layer_infer.py | 96 +++++++++++-------- 1 file changed, 54 insertions(+), 42 deletions(-) diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index ad891e6ca9..3670dac687 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -185,45 +185,57 @@ def _context_attention_kernel( o3 = o3[:, :, : self.head_dim_].contiguous() return o3.view(o3.shape[0], -1) - def _token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - - q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) - - att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) - - k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - token_att_fwd( - q_3d, - k_3d, - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_kv_start_loc, - infer_state.b_seq_len, - infer_state.max_kv_seq_len, - ) - - from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( - token_attention_softmax_and_reducev, - ) - - token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd - - v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ - ] - - o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) - - token_softmax_reducev_fwd( - att_m_tensor, - v_3d, - o_3d, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_kv_start_loc, - infer_state.b_seq_len, - ) - return o_3d.view(batch_size, -1) + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: NeoChatInferStateInfo, + layer_weight: NeoChatMOETransformerLayerWeight, + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + o_tensor = infer_state.decode_att_state.decode_att(q=_q, k=_k, v=_v, alloc_func=self.alloc_tensor) + o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)[:, :, : self.head_dim_].contiguous() + return o_tensor + + # def _token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): + # total_token_num = infer_state.total_token_num + # batch_size = infer_state.batch_size + + # q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) + + # att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) + + # k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + # token_att_fwd( + # q_3d, + # k_3d, + # att_m_tensor, + # 
infer_state.req_manager.req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_kv_start_loc, + # infer_state.b_seq_len, + # infer_state.max_kv_seq_len, + # ) + + # from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( + # token_attention_softmax_and_reducev, + # ) + + # token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd + + # v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ + # :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ + # ] + + # o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) + + # token_softmax_reducev_fwd( + # att_m_tensor, + # v_3d, + # o_3d, + # infer_state.req_manager.req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_kv_start_loc, + # infer_state.b_seq_len, + # ) + # return o_3d.view(batch_size, -1) From a23288b489369b57c278e5af9b609a025e0c705d Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 10:43:43 +0000 Subject: [PATCH 047/180] fix dead visual process --- lightllm/server/visualserver/model_infer/model_rpc.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 5e07b162d4..22dfa915ba 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -4,6 +4,7 @@ import torch import socket import inspect +import setproctitle from datetime import timedelta from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig @@ -26,6 +27,8 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient from lightllm.server.visualserver import set_vit_att_backend +from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_unique_server_name class VisualModelRpcServer(rpyc.Service): @@ -175,6 +178,8 @@ async def encode(self, images: List[ImageItem]): def _init_env(port, device_id): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server::RANK{device_id}") + start_parent_check_thread() import lightllm.utils.rpyc_fix_utils as _ From f5585404ab489521dc1a96ce7ba017d70d244668 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 10:59:37 +0000 Subject: [PATCH 048/180] auto visual dp --- lightllm/server/api_start.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 121f3b9a61..58dac941b0 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -164,6 +164,10 @@ def _launch_subprocesses(args: StartArgs): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + # automatically set visual_dp based on visual_tp and tp + if args.visual_tp < args.tp and args.tp % args.visual_tp == 0: + args.visual_dp = args.tp // args.visual_tp + # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) From 12c6c6b274f598c6e14484e0a026b4bdbad47125 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 12:15:36 +0000 Subject: [PATCH 049/180] fix format --- lightllm/utils/patch_torch.py | 8 +++----- lightllm/utils/serializer.py | 17 
+++++++---------- lightllm/utils/tensor_bucket.py | 16 ++++++---------- 3 files changed, 16 insertions(+), 25 deletions(-) diff --git a/lightllm/utils/patch_torch.py b/lightllm/utils/patch_torch.py index c504e4bbc9..9f51edeb64 100644 --- a/lightllm/utils/patch_torch.py +++ b/lightllm/utils/patch_torch.py @@ -9,7 +9,8 @@ def monkey_patch_torch_reductions(): """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed""" - # Currently, NPU does not support UUID. This has been temporarily commented out, with support expected in the fourth quarter. + # Currently, NPU does not support UUID. This has been temporarily commented out, + # with support expected in the fourth quarter. # if _is_npu: # return @@ -32,9 +33,7 @@ def monkey_patch_torch_reductions(): def _reduce_tensor_modified(*args, **kwargs): output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) - output_args = _modify_tuple( - output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid - ) + output_args = _modify_tuple(output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid) return output_fn, output_args @@ -62,4 +61,3 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int: def _modify_tuple(t, index: int, modifier: Callable): return *t[:index], modifier(t[index]), *t[index + 1 :] - diff --git a/lightllm/utils/serializer.py b/lightllm/utils/serializer.py index ae6f418df6..d8180aeb0c 100644 --- a/lightllm/utils/serializer.py +++ b/lightllm/utils/serializer.py @@ -1,4 +1,3 @@ - # copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py import base64 @@ -108,27 +107,25 @@ def find_class(self, module, name): # Block deterministic attacks if (module, name) in self.DENY_CLASSES: raise RuntimeError( - f"Blocked unsafe class loading ({module}.{name}), " - f"to prevent exploitation of CVE-2025-10164" + f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164" ) # Allowlist of safe-to-load modules. - if any( - (module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES - ): + if any((module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES): return super().find_class(module, name) # Block everything else. (Potential attack surface) raise RuntimeError( - f"Blocked unsafe class loading ({module}.{name}), " - f"to prevent exploitation of CVE-2025-10164" + f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164" ) + @dataclass class LocalSerializedTensor: - """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data). + """torch.Tensor that gets serialized by MultiprocessingSerializer + (which only serializes a pointer and not the data). 
The i-th element in the list corresponds to i-th rank's GPU.""" values: List[bytes] def get(self, rank: int): - return MultiprocessingSerializer.deserialize(self.values[rank]) \ No newline at end of file + return MultiprocessingSerializer.deserialize(self.values[rank]) diff --git a/lightllm/utils/tensor_bucket.py b/lightllm/utils/tensor_bucket.py index 762bd0dd06..a9d7a367dd 100644 --- a/lightllm/utils/tensor_bucket.py +++ b/lightllm/utils/tensor_bucket.py @@ -1,4 +1,6 @@ -# copy from https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/srt/weight_sync/tensor_bucket.py +# copy from +# https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/ +# srt/weight_sync/tensor_bucket.py from dataclasses import dataclass from typing import List, Tuple @@ -74,9 +76,7 @@ def __init__( else: # Initialize from pre-flattened data if flattened_tensor is None or metadata is None: - raise ValueError( - "Must provide either named_tensors or both flattened_tensor and metadata" - ) + raise ValueError("Must provide either named_tensors or both flattened_tensor and metadata") self.flattened_tensor = flattened_tensor self.metadata = metadata @@ -97,12 +97,8 @@ def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: reconstructed = [None] * len(self.metadata) for i, meta in enumerate(self.metadata): - tensor = ( - self.flattened_tensor[meta.start_idx : meta.end_idx] - .view(meta.dtype) - .reshape(meta.shape) - ) + tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].view(meta.dtype).reshape(meta.shape) reconstructed[i] = (meta.name, tensor) - return reconstructed \ No newline at end of file + return reconstructed From fd91cad792ce26318538ef8051a627b9f5b5474c Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 2 Feb 2026 07:46:54 +0000 Subject: [PATCH 050/180] fix decode scale --- lightllm/common/basemodel/attention/base_att.py | 1 + lightllm/common/basemodel/attention/fa3/fp.py | 7 +++++-- .../neo_chat_moe/layer_infer/transformer_layer_infer.py | 9 ++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 859d97ca84..6429bce9a0 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -65,6 +65,7 @@ class AttControl: mla_prefill_dict: Dict = None mla_decode: bool = False mla_decode_dict: Dict = None + scale: float = None @dataclass diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..2f5fccd57b 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -220,8 +220,11 @@ def _normal_decode_att( sink_weight = None k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) + if att_control.scale is not None: + sm_scale = att_control.scale + else: + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) o = flash_attn_with_kvcache( q=q, k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 3670dac687..c5efe1eef8 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -10,6 +10,7 @@ from lightllm.distributed import all_reduce import 
torch.distributed as dist from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.common.basemodel.attention.base_att import AttControl class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): @@ -193,7 +194,13 @@ def _token_attention_kernel( ) -> torch.Tensor: _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) _q = q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) - o_tensor = infer_state.decode_att_state.decode_att(q=_q, k=_k, v=_v, alloc_func=self.alloc_tensor) + att_control = AttControl() + if att_control.scale is None: + att_control.scale = 1.0 / (self.head_dim_ ** 0.5) + # att_control.mla_decode_dict["softmax_scale"] = 1.0 / (self.head_dim_ ** 0.5) + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, k=_k, v=_v, att_control=att_control, alloc_func=self.alloc_tensor + ) o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)[:, :, : self.head_dim_].contiguous() return o_tensor From 26812639a79be7094510032eda1bb0d035f6be9b Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 2 Feb 2026 08:58:27 +0000 Subject: [PATCH 051/180] add new mode support text_ids+image_ids --- lightllm/models/neo_chat_moe/model.py | 8 +++++--- lightllm/server/httpserver/manager.py | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py index ce2093d450..cf4404090f 100644 --- a/lightllm/models/neo_chat_moe/model.py +++ b/lightllm/models/neo_chat_moe/model.py @@ -88,9 +88,11 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): add_special_tokens = kwargs.get("add_special_tokens", True) return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) image_count = len(multimodal_params.images) - prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) - - origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) + if not kwargs.get("already_tokenized", False): + prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) + origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) + else: + origin_ids = prompt # --> id,id+1...id+num input_ids = [] image_id = 0 diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index f48b9d04c5..dde1b51891 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -493,7 +493,20 @@ async def _encode( # 这里的校验对多模态不是很充分, to do if all(isinstance(e, int) for e in prompt): - if not self.enable_multimodal and not self.pd_mode.is_D(): + if self.enable_multimodal: + assert ( + len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + ), "too many multimodal items!" 
+ if multimodal_params.audios: + assert self.args.enable_multimodal_audio, "audio multimodal not enabled" + await self._alloc_multimodal_resources(multimodal_params, sampling_params) + prompt_ids = self.tokenizer.encode( + prompt, + multimodal_params, + add_special_tokens=sampling_params.add_special_tokens, + already_tokenized=True, + ) + elif not self.enable_multimodal and not self.pd_mode.is_D(): if all(e < self.vocab_size for e in prompt): return prompt else: From fd17aa083555ae9551c919c58cf7ed075d84f255 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 2 Feb 2026 11:04:47 +0000 Subject: [PATCH 052/180] add new mode support text_ids+image_ids --- lightllm/server/httpserver/manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index dde1b51891..b6a7b0d127 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -506,6 +506,7 @@ async def _encode( add_special_tokens=sampling_params.add_special_tokens, already_tokenized=True, ) + return prompt_ids elif not self.enable_multimodal and not self.pd_mode.is_D(): if all(e < self.vocab_size for e in prompt): return prompt From e516bd9eabbbaefdabf33612406cc8547d7e7b36 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 2 Feb 2026 13:10:11 +0000 Subject: [PATCH 053/180] add cuda empty cache --- lightllm/common/basemodel/basemodel.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index c7cbb1f27d..4f37ebf247 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1070,25 +1070,32 @@ def release_weight(self): def release_kv_cache(self): self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + torch.cuda.empty_cache() def release_graph(self): self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + torch.cuda.empty_cache() def release_all(self): self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + torch.cuda.empty_cache() def resume_weight(self): + torch.cuda.empty_cache() self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) def resume_kv_cache(self): + torch.cuda.empty_cache() self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) def resume_graph(self): + torch.cuda.empty_cache() self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) def resume_all(self): + torch.cuda.empty_cache() self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) From 81a0c1282b89c5247139af2669d451bb4a2bd0b9 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 2 Feb 2026 14:42:10 +0000 Subject: [PATCH 054/180] add invalid token ids to sampling_param for rl training --- .../triton_kernel/post_process/__init__.py | 0 .../post_process/apply_invalid_token.py | 36 +++++++++++++++++++ .../{ => post_process}/apply_penalty.py | 0 .../apply_penalty_gpu_cache.py | 0 .../server/core/objs/py_sampling_params.py | 4 +++ lightllm/server/core/objs/sampling_params.py | 28 +++++++++++++++ .../server/router/model_infer/infer_batch.py | 6 ++++ .../mode_backend/generic_post_process.py | 34 ++++++++++++++++-- 8 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/post_process/__init__.py create mode 100644 
lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py rename lightllm/common/basemodel/triton_kernel/{ => post_process}/apply_penalty.py (100%) rename lightllm/common/basemodel/triton_kernel/{ => post_process}/apply_penalty_gpu_cache.py (100%) diff --git a/lightllm/common/basemodel/triton_kernel/post_process/__init__.py b/lightllm/common/basemodel/triton_kernel/post_process/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py new file mode 100644 index 0000000000..353affd8ed --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py @@ -0,0 +1,36 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_apply_invalid_token( + Logits, + invalid_token_ids, + cu_invalid_token_num, + stride_logit_b, +): + cur_batch = tl.program_id(0) + start_index = tl.load(cu_invalid_token_num + cur_batch) + end_index = tl.load(cu_invalid_token_num + cur_batch + 1) + for i in range(start_index, end_index): + cur_invalid_token_id = tl.load(invalid_token_ids + i) + cur_logit_ptr = Logits + cur_batch * stride_logit_b + cur_invalid_token_id + tl.store(cur_logit_ptr, float("-inf")) + return + + +def apply_invalid_token_ids( + Logits: torch.Tensor, + invalid_token_ids: torch.Tensor, + cu_invalid_token_num: torch.Tensor, +): + batch_size = Logits.shape[0] + grid = (batch_size,) + _fwd_kernel_apply_invalid_token[grid]( + Logits=Logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + stride_logit_b=Logits.stride(0), + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/apply_penalty.py rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index 887f360c84..9194a235da 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -54,6 +54,8 @@ def __init__( # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. 
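
# For contrast with the new invalid_token_ids below: allowed_token_ids whitelists a
# fixed answer set, e.g. a request payload like (illustrative only; outlines
# constraint mode assumed per the comment above):
#     {"inputs": "Is the sky blue?", "parameters": {"allowed_token_ids": [9454, 2753]}}
# whereas invalid_token_ids blacklists ids, masking their logits at sampling time.
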
allowed_token_ids: Optional[List[int]] = None, + # if provided, the invalid token ids will be ignored during generation + invalid_token_ids: Optional[List[int]] = None, # p d mode used params group_request_id: Optional[int] = None, # move kv to deocde node, only used in pd mode @@ -88,6 +90,7 @@ def __init__( self.guided_grammar = guided_grammar self.guided_json = guided_json self.allowed_token_ids = allowed_token_ids + self.invalid_token_ids = invalid_token_ids self.group_request_id = group_request_id self.move_kv_to_decode_node = move_kv_to_decode_node self.suggested_dp_index = suggested_dp_index @@ -267,6 +270,7 @@ def to_dict(self): ret["guided_grammar"] = self.guided_grammar ret["guided_json"] = self.guided_json ret["allowed_token_ids"] = self.allowed_token_ids + ret["invalid_token_ids"] = self.invalid_token_ids ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node return ret diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 7d4d2531b4..650d15512f 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -17,6 +17,7 @@ REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048)) GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048)) JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048)) +INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10)) class StopSequence(ctypes.Structure): @@ -205,6 +206,25 @@ def to_list(self): return list(self.ids[: self.size]) +class InvalidTokenIds(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("ids", ctypes.c_int * INVALID_TOKEN_IDS_MAX_LENGTH), + ("size", ctypes.c_int), + ] + + def initialize(self, ids: List[int]): + self.size = len(ids) + assert ( + self.size <= INVALID_TOKEN_IDS_MAX_LENGTH + ), f"Too many invalid token IDs {self.size} > {INVALID_TOKEN_IDS_MAX_LENGTH}." + self.ids[: self.size] = ids[:] + return + + def to_list(self): + return list(self.ids[: self.size]) + + class ExponentialDecayLengthPenalty(ctypes.Structure): _pack_ = 4 _fields_ = [ @@ -306,6 +326,8 @@ class SamplingParams(ctypes.Structure): # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. 
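
# The InvalidTokenIds struct added above follows the same fixed-size ctypes pattern
# as AllowedTokenIds; a standalone sketch of the idiom (sizes are illustrative):
#     class IntList(ctypes.Structure):
#         _pack_ = 4
#         _fields_ = [("ids", ctypes.c_int * 10), ("size", ctypes.c_int)]
#     lst = IntList()
#     lst.size = 2
#     lst.ids[:2] = [5, 9]
#     assert list(lst.ids[: lst.size]) == [5, 9]
# Keeping the payload inline (a plain array, no pointers) is what lets the struct
# live in shared memory and be read back from another process.
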
("allowed_token_ids", AllowedTokenIds), + # if provided, the invalid token ids will be ignored during generation + ("invalid_token_ids", InvalidTokenIds), ("stop_sequences", StopSequenceGroups), ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty), ("group_request_id", ctypes.c_int64), # p d mode used params @@ -395,6 +417,11 @@ def init(self, tokenizer, **kwargs): self.allowed_token_ids = AllowedTokenIds() self.allowed_token_ids.initialize(allowed_token_ids) + # Initialize invalid_token_ids + invalid_token_ids = kwargs.get("invalid_token_ids", []) + self.invalid_token_ids = InvalidTokenIds() + self.invalid_token_ids.initialize(invalid_token_ids) + if self.do_sample is False: self.temperature = 1.0 self.top_p = 1.0 @@ -495,6 +522,7 @@ def to_dict(self): "guided_grammar": self.guided_grammar.to_str(), "guided_json": self.guided_json.to_str(), "allowed_token_ids": self.allowed_token_ids.to_list(), + "invalid_token_ids": self.invalid_token_ids.to_list(), "group_request_id": self.group_request_id, "move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(), "skip_special_tokens": self.skip_special_tokens, diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 89230c92de..2b35fad053 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -266,6 +266,7 @@ def __init__( self.fsm_current_state: int = 0 self.allowed_token_ids = self.shm_param.allowed_token_ids.to_list() + self.invalid_token_ids = self.shm_param.invalid_token_ids.to_list() if len(self.allowed_token_ids) == 0: self.allowed_token_ids = None @@ -281,6 +282,11 @@ def __init__( logger.error("allowed_token_ids contain tokenid >= vobsize, we remove these token ids") self.allowed_token_ids = [e for e in self.allowed_token_ids if e < vocab_size] + if len(self.invalid_token_ids) > 0: + if not all(e < vocab_size for e in self.invalid_token_ids): + logger.error("invalid_token_ids contain tokenid >= vobsize, we remove these token ids") + self.invalid_token_ids = [e for e in self.invalid_token_ids if e < vocab_size] + # nixl decode node information if self.shm_param.nixl_params.data_len > 0: self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get()) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index e2ccf290e8..fc551b08ea 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,7 +1,8 @@ import torch from typing import List -from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty -from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.utils.envs_utils import get_env_start_args @@ -14,7 +15,10 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): b_top_ks, b_length_penalty_param, b_mask_eos_reqs, + invalid_token_ids, + cu_invalid_token_num, 
is_all_greedy, + has_invalid_token_ids, ) = _get_post_sample_tensors(reqs) eos_ids = torch.tensor(eos_id, dtype=torch.int32, device="cpu", pin_memory=True).cuda(non_blocking=True) @@ -59,6 +63,14 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): eos_ids=eos_ids, sampling_params_manager=sampling_params_manager, ) + + if has_invalid_token_ids: + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + logits.div_(b_temperatures.view((-1, 1))) probs = torch.softmax(logits, dim=-1) @@ -112,6 +124,12 @@ def _get_post_sample_tensors(reqs: List[InferReq]): mask_eos_reqs: List[bool] = [] is_all_greedy = True + # invalid token ids + invalid_token_ids: List[int] = [] + has_invalid_token_ids = False + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for i, req_obj in enumerate(reqs): sample_param = req_obj.sampling_param shm_param = sample_param.shm_param @@ -127,6 +145,11 @@ def _get_post_sample_tensors(reqs: List[InferReq]): if top_k_val > 1: is_all_greedy = False req_idxes.append(req_obj.req_idx) + invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids) + cu_invalid_token_num.append(invalid_token_num_start) + if len(req_obj.sampling_param.invalid_token_ids) > 0: + has_invalid_token_ids = True + invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids) req_idxes_cpu = torch.tensor(req_idxes, dtype=torch.int32, device="cpu", pin_memory=True) temperatures_cpu = torch.tensor(temperatures, dtype=torch.float, device="cpu", pin_memory=True) @@ -135,6 +158,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]): length_penalty_param_cpu = torch.tensor(length_penalty_param, dtype=torch.int32, device="cpu", pin_memory=True) mask_eos_reqs_cpu = torch.tensor(mask_eos_reqs, dtype=torch.bool, device="cpu", pin_memory=True) + if has_invalid_token_ids: + invalid_token_ids_cpu = torch.tensor(invalid_token_ids, dtype=torch.int32, device="cpu", pin_memory=True) + cu_invalid_token_num_cpu = torch.tensor(cu_invalid_token_num, dtype=torch.int32, device="cpu", pin_memory=True) + return ( req_idxes_cpu.cuda(non_blocking=True), temperatures_cpu.cuda(non_blocking=True), @@ -142,5 +169,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]): top_ks_cpu.cuda(non_blocking=True), length_penalty_param_cpu.cuda(non_blocking=True), mask_eos_reqs_cpu.cuda(non_blocking=True), + invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, + cu_invalid_token_num_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, is_all_greedy, + has_invalid_token_ids, ) From 14132d57eb8221c89c254ffbdef2dcfe3a086e2d Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 2 Feb 2026 14:42:33 +0000 Subject: [PATCH 055/180] add unitest for apply_invalid_tokens --- .../triton_kernel/test_apply_invalid_token.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py diff --git a/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py new file mode 100644 index 0000000000..3b2f159f62 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py @@ -0,0 +1,50 @@ +import pytest +import torch + +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import ( + apply_invalid_token_ids, +) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, 
torch.float32]) +def test_apply_invalid_token_ids(dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for Triton kernels.") + + batch_size = 4 + vocab_size = 32 + logits = torch.randn((batch_size, vocab_size), device="cuda", dtype=dtype) + expected = logits.clone() + + invalid_token_ids_per_batch = [ + [1, 3, 5], + [], + [0, 2, 31], + [7], + ] + + flat_ids = [] + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for ids in invalid_token_ids_per_batch: + flat_ids.extend(ids) + invalid_token_num_start += len(ids) + cu_invalid_token_num.append(invalid_token_num_start) + + invalid_token_ids = torch.tensor(flat_ids, device="cuda", dtype=torch.int32) + cu_invalid_token_num = torch.tensor(cu_invalid_token_num, device="cuda", dtype=torch.int32) + + for batch_idx, ids in enumerate(invalid_token_ids_per_batch): + if ids: + expected[batch_idx, ids] = float("-inf") + + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + assert torch.equal(logits, expected) + + +if __name__ == "__main__": + pytest.main([__file__]) From ed41960f38fdc5239945d3772302b9e3f9509c21 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 3 Feb 2026 07:27:08 +0000 Subject: [PATCH 056/180] add gc collect --- lightllm/common/basemodel/basemodel.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 4f37ebf247..c69ae07bdb 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1067,35 +1067,44 @@ def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]): def release_weight(self): self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + torch.cuda.empty_cache() + gc.collect() def release_kv_cache(self): self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) torch.cuda.empty_cache() + gc.collect() def release_graph(self): self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) torch.cuda.empty_cache() + gc.collect() def release_all(self): self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) torch.cuda.empty_cache() + gc.collect() def resume_weight(self): torch.cuda.empty_cache() + gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) def resume_kv_cache(self): torch.cuda.empty_cache() + gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) def resume_graph(self): torch.cuda.empty_cache() + gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) def resume_all(self): torch.cuda.empty_cache() + gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) From 706ae2e022ade69e0ed255bf0857d7e10cbc2a9e Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 3 Feb 2026 09:04:51 +0000 Subject: [PATCH 057/180] logit_bias --- lightllm/server/core/objs/sampling_params.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 650d15512f..1beece3421 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -418,9 +418,9 @@ def init(self, tokenizer, **kwargs): self.allowed_token_ids.initialize(allowed_token_ids) # Initialize invalid_token_ids - 
invalid_token_ids = kwargs.get("invalid_token_ids", []) + invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys()) self.invalid_token_ids = InvalidTokenIds() - self.invalid_token_ids.initialize(invalid_token_ids) + self.invalid_token_ids.initialize(list(invalid_token_ids)) if self.do_sample is False: self.temperature = 1.0 From f432f5ae30ddc1badf3dbfdfecbbec0c32ad78f9 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 3 Feb 2026 10:16:20 +0000 Subject: [PATCH 058/180] logit_bias --- lightllm/server/core/objs/sampling_params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 1beece3421..93447830bb 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -420,7 +420,7 @@ def init(self, tokenizer, **kwargs): # Initialize invalid_token_ids invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys()) self.invalid_token_ids = InvalidTokenIds() - self.invalid_token_ids.initialize(list(invalid_token_ids)) + self.invalid_token_ids.initialize(list[int](invalid_token_ids)) if self.do_sample is False: self.temperature = 1.0 From 8f8ed44ae035467fee240e774bf7c29eafdfc5d2 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 4 Feb 2026 05:50:51 +0000 Subject: [PATCH 059/180] merge main --- .../neo_chat_moe/layer_infer/transformer_layer_infer.py | 7 ++++--- lightllm/server/api_start.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index c5efe1eef8..3cf5d1ecb6 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -108,12 +108,13 @@ def _get_qkv_mergekv( ): input = input.view(-1, self.embed_dim_) - q = layer_weight.q_proj.mm(input) # [T, Hq*D] + qkv = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv.split( + [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 + ) q_hw = layer_weight.q_hw_proj.mm(input) k_hw = layer_weight.k_hw_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - layer_weight.q_norm_weight_(q, eps=self.eps_) layer_weight.q_norm_hw_weight_(q_hw, eps=self.eps_) layer_weight.k_norm_hw_weight_(k_hw, eps=self.eps_) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 58dac941b0..afe199d04f 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -91,7 +91,6 @@ def _set_envs_and_config(args: StartArgs): def _launch_subprocesses(args: StartArgs): _set_envs_and_config(args) - set_unique_server_name(args) if not args.disable_shm_warning: check_recommended_shm_size(args) @@ -291,6 +290,8 @@ def _launch_subprocesses(args: StartArgs): args.pd_p_allowed_port_min = 20000 args.pd_p_allowed_port_max = 30000 + set_unique_server_name(args) + # p d 分离模式下,decode节点的调度间隙是0 if args.run_mode == "decode": args.router_max_wait_tokens = 0 From cac2edf0a632c589992637e9eff8b767e9013cce Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 6 Feb 2026 03:12:36 +0000 Subject: [PATCH 060/180] neo moe inferece speedup --- lightllm/models/neo_chat_moe/infer_struct.py | 4 ++ .../layer_infer/transformer_layer_infer.py | 1 + .../context_attention_fwd_neo.py | 63 +++++++------------ 
.../triton_kernel/get_neo_position.py | 17 +++++ .../models/neo_chat_moe/vision_process.py | 2 +- 5 files changed, 44 insertions(+), 43 deletions(-) diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py index 13d1ba5fc2..961ed2a61d 100644 --- a/lightllm/models/neo_chat_moe/infer_struct.py +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -20,6 +20,9 @@ def __init__(self): def init_some_extra_state(self, model: LlamaTpPartModel): LlamaInferStateInfo.init_some_extra_state(self, model) if self.is_prefill: + self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda( + non_blocking=True + ) self.position_ids = self.get_neo_position(self.multimodal_params) else: b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] @@ -95,5 +98,6 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: b_ready_cache_len=self.b_ready_cache_len, b_q_seq_len=self.b_q_seq_len, b_start_loc=self.b_q_start_loc, + b_image_token_tag=self.b_image_token_tag, ) return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 3cf5d1ecb6..1518d68748 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -182,6 +182,7 @@ def _context_attention_kernel( infer_state.b_ready_cache_len, infer_state.max_q_seq_len, infer_state.req_manager.req_to_token_indexs, + infer_state.b_image_token_tag, ) o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) o3 = o3[:, :, : self.head_dim_].contiguous() diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py index f5dae493cb..42c3254e27 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -34,8 +34,10 @@ def _fwd_kernel( stride_req_to_tokens_s, kv_group_num, b_prompt_cache_len, + b_image_token_tag, H: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + QK_HEAD_DIM: tl.constexpr, + V_HEAD_DIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -53,16 +55,19 @@ def _fwd_kernel( cur_batch_req_idx = tl.load(B_req_idx + cur_batch) block_start_loc = BLOCK_M * start_m + if block_start_loc >= cur_batch_seq_len: + return offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d_qk = tl.arange(0, QK_HEAD_DIM) + offs_d_v = tl.arange(0, V_HEAD_DIM) offs_m = block_start_loc + tl.arange(0, BLOCK_M) # Q pointers off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh - + offs_d[None, :] * stride_qd + + offs_d_qk[None, :] * stride_qd ) q_valid = offs_m < cur_batch_seq_len @@ -71,24 +76,14 @@ def _fwd_kernel( # online softmax state m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32) block_end_loc = total_len # absolute q positions in the request q_pos = prompt_cache_len + offs_m # [M] + q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False) - # q_gid from packed position_ids (aligned with Q 
rows) - q_gid = tl.load( - position_ids + cur_batch_in_all_start_index + offs_m, - mask=q_valid, - other=-2147483648, - ).to(tl.int32) - - BIG = tl.full([BLOCK_N], 1000000000, tl.int32) # ensure != any normal gid - - for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + for start_n in range(0, block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) k_pos = start_n + offs_n # [N] @@ -102,32 +97,13 @@ def _fwd_kernel( ).to(tl.int64) # load K - off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d_qk[:, None] * stride_kd k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) - - qk = tl.dot(q, k) - - # k_gid: - # - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false - # - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len) - k_in_new = k_pos >= prompt_cache_len - k_new_idx = (k_pos - prompt_cache_len).to(tl.int32) # [N] valid only when k_in_new - k_gid_new = tl.load( - position_ids + cur_batch_in_all_start_index + k_new_idx, - mask=k_valid & k_in_new, - other=-2147483647, - ).to(tl.int32) - - k_gid = tl.where( - k_in_new, - k_gid_new, - (k_pos.to(tl.int32) + BIG), - ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # mask: causal OR same gid (only possible inside NEW part) - mask = (q_pos[:, None] >= k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :]) - mask = mask & q_valid[:, None] & k_valid[None, :] - + mask = (q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None] qk = tl.where(mask, qk * sm_scale, -1.0e8) # online softmax @@ -141,7 +117,7 @@ def _fwd_kernel( acc = acc * alpha[:, None] # load V - off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d_v[None, :] * stride_vd v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) p = p.to(v.dtype) @@ -154,7 +130,7 @@ def _fwd_kernel( off_o = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh - + offs_d[None, :] * stride_od + + offs_d_v[None, :] * stride_od ) tl.store(Out + off_o, acc, mask=q_valid[:, None]) @@ -172,6 +148,7 @@ def context_attention_fwd_neo( b_prompt_cache_len, max_input_len, req_to_token_indexs, + b_image_token_tag, ): # minimal safety: position_ids must cover packed q rows assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) @@ -220,8 +197,10 @@ def context_attention_fwd_neo( req_to_token_indexs.stride(1), kv_group_num=kv_group_num, b_prompt_cache_len=b_prompt_cache_len, + b_image_token_tag=b_image_token_tag, H=head, - BLOCK_DMODEL=Lk, + QK_HEAD_DIM=Lk, + V_HEAD_DIM=Lk // 2, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py index 955f48bd80..1a3d4af73b 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py @@ -16,6 +16,7 @@ def _get_neo_position_triton( b_ready_cache_len: torch.Tensor, b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, + b_image_token_tag: torch.Tensor, BLOCK_SIZE: tl.constexpr, ) -> torch.Tensor: cur_batch = tl.program_id(0) @@ -36,6 +37,13 @@ def _get_neo_position_triton( t_pos = local_image_start_idx + off * 0 h_pos = off // image_w w_pos = 
off % image_w
+    tl.store(
+        b_image_token_tag + off + image_start_idx,
+        True,
+        mask=(off < image_len)
+        & (off + local_image_start_idx - cache_len < q_seq_len)
+        & (local_image_start_idx - cache_len + off >= 0),
+    )
     tl.store(
         position_ids + off + image_start_idx,
         t_pos,
@@ -87,6 +95,7 @@ def get_neo_position_triton(
     b_ready_cache_len: torch.Tensor,
     b_q_seq_len: torch.Tensor,
     b_start_loc: torch.Tensor,
+    b_image_token_tag: torch.Tensor,
 ) -> torch.Tensor:
 
     batch_size = b_q_seq_len.shape[0]
@@ -105,6 +114,7 @@ def get_neo_position_triton(
         b_ready_cache_len=b_ready_cache_len,
         b_q_seq_len=b_q_seq_len,
         b_start_loc=b_start_loc,
+        b_image_token_tag=b_image_token_tag,
         BLOCK_SIZE=BLOCK_SIZE,
     )
 
@@ -121,6 +131,7 @@ def test():
         .expand(3, -1)
         .contiguous()
     )
+    b_image_token_tag = torch.zeros([position_ids.size(1)], dtype=torch.bool, device="cuda")
     position_ids[1:].zero_()
     b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda")
     b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda")
@@ -135,8 +146,10 @@ def test():
         b_ready_cache_len,
         b_q_seq_len,
         b_start_loc,
+        b_image_token_tag,
     )
 
+    print(b_image_token_tag)
     print(position_ids)
 
     # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1)

        [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]],
       device='cuda:0', dtype=torch.int32)
    """
+
+
+if __name__ == "__main__":
+    test()
diff --git a/lightllm/models/neo_chat_moe/vision_process.py b/lightllm/models/neo_chat_moe/vision_process.py
index aa008e18fb..fbd57a5e9c 100644
--- a/lightllm/models/neo_chat_moe/vision_process.py
+++ b/lightllm/models/neo_chat_moe/vision_process.py
@@ -136,6 +136,6 @@ def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=655
     )
     pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size)
 
-    print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})")
+    # print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})")
 
     return pixel_values, grid_hw

From 02078ad1ded3babcd1b3c8d152e83ee1238ba7d9 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 9 Feb 2026 10:44:23 +0000
Subject: [PATCH 061/180] port random generate

---
 lightllm/utils/net_utils.py | 33 ++++++++++++++++++++++-----------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py
index 486414e88e..51ec443d10 100644
--- a/lightllm/utils/net_utils.py
+++ b/lightllm/utils/net_utils.py
@@ -2,7 +2,6 @@
 import subprocess
 import ipaddress
 import random
-import portpicker
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
@@ -13,25 +12,37 @@ def alloc_can_use_network_port(num=3, used_nccl_ports=None, from_port_num=10000)
         used_nccl_ports = []
 
     port_list = []
+    locked_sockets = []
+    used_set = set(used_nccl_ports)
+    max_port = 65535
     max_attempts = num * 50  # Allow more attempts to find ports in range
 
     for _ in range(max_attempts):
         if len(port_list) >= num:
             break
 
-        try:
-            port = portpicker.pick_unused_port()
-
-            if port >= from_port_num and port not in used_nccl_ports:
-                port_list.append(port)
-                logger.debug(f"Allocated port: {port}")
-            else:
-                logger.debug(f"Port {port} is out of range or in used_nccl_ports, skipping")
+        # Pick a random port in [from_port_num, 65535] to avoid handing the same
+        # port to multiple processes that are started at the same time
+        port = random.randint(from_port_num, max_port)
+        if port in used_set:
+            continue
+
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        try:
+            sock.bind(("", port))
+            port_list.append(port)
+            used_set.add(port)
+            locked_sockets.append(sock)
+            logger.debug(f"Allocated and locked port: {port}")
+
+        except OSError as e:
+            sock.close()
+            logger.warning(f"Failed to bind port: {e}")
             continue
 
-        except Exception as e:
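
# A condensed sketch of the bind-probe idiom the rewritten allocator above relies
# on (standalone and illustrative; the real code keeps every bound socket open
# until the whole batch of ports has been gathered, then closes them together):
#     def probe(port: int):
#         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
#         try:
#             s.bind(("", port))
#             return s  # keep it open so a sibling process cannot take the port
#         except OSError:
#             s.close()
#             return None
# Holding the sockets narrows, but does not fully close, the race window between
# allocation here and the eventual bind by the component that uses the port.
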
- logger.warning(f"Failed to allocate port: {e}") + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + sock.bind(("", port)) + port_list.append(port) + used_set.add(port) + locked_sockets.append(sock) + logger.debug(f"Allocated and locked port: {port}") + + except OSError as e: + sock.close() + logger.warning(f"Failed to bind port: {e}") continue + for sock in locked_sockets: + sock.close() + if len(port_list) < num: logger.error(f"Failed to allocate {num} ports, only got {len(port_list)}") return None From 68954b02f49951c0f517a231a95bf5422caf55f2 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 9 Feb 2026 11:11:48 +0000 Subject: [PATCH 062/180] feat: add MoE expert routing capture for R3 rollout replay --- .gitignore | 1 + .../fused_moe/fused_moe_weight.py | 5 + .../fused_moe/gpt_oss_fused_moe_weight_tp.py | 8 + .../meta_weights/fused_moe/impl/base_impl.py | 2 + .../fused_moe/impl/triton_impl.py | 7 + lightllm/common/basemodel/routing_manager.py | 224 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 2 + .../layer_weights/transformer_layer_weight.py | 4 + lightllm/models/deepseek2/model.py | 4 + .../layer_infer/transformer_layer_infer.py | 1 + .../layer_weights/transformer_layer_weight.py | 1 + lightllm/models/gpt_oss/model.py | 7 + lightllm/models/llama/model.py | 16 +- .../models/mixtral/layer_infer/_custom_ops.py | 46 ---- .../layer_infer/transformer_layer_infer.py | 33 +-- .../layer_weights/transformer_layer_weight.py | 1 + lightllm/models/mixtral/model.py | 4 + .../layer_infer/transformer_layer_infer.py | 8 +- .../layer_weights/transformer_layer_weight.py | 6 + lightllm/models/qwen3_moe/model.py | 4 + lightllm/server/api_cli.py | 6 + lightllm/server/api_lightllm.py | 5 + lightllm/server/core/objs/req.py | 65 +++++ lightllm/server/core/objs/shm_array.py | 13 + lightllm/server/core/objs/start_args_type.py | 2 + lightllm/server/httpserver/manager.py | 18 ++ .../server/router/model_infer/infer_batch.py | 19 ++ .../model_infer/mode_backend/base_backend.py | 13 + .../mode_backend/chunked_prefill/impl.py | 4 + .../mode_backend/diverse_backend/impl.py | 2 +- .../mode_backend/dp_backend/impl.py | 13 +- test/test_api/test_r3.py | 99 ++++++++ unit_tests/__init__.py | 0 unit_tests/common/__init__.py | 0 unit_tests/common/basemodel/__init__.py | 0 .../basemodel/test_routing_capture_manager.py | 219 +++++++++++++++++ 36 files changed, 781 insertions(+), 81 deletions(-) create mode 100644 lightllm/common/basemodel/routing_manager.py delete mode 100644 lightllm/models/mixtral/layer_infer/_custom_ops.py create mode 100644 test/test_api/test_r3.py create mode 100644 unit_tests/__init__.py create mode 100644 unit_tests/common/__init__.py create mode 100644 unit_tests/common/basemodel/__init__.py create mode 100644 unit_tests/common/basemodel/test_routing_capture_manager.py diff --git a/.gitignore b/.gitignore index 63408699f4..3fb49db8b1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist .vscode tmp/ requirements-musa.txt +CLAUDE.md diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 77d6d40e9f..3dc888b6ac 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -33,6 +33,7 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 
0, network_config: Dict[str, Any] = None, + moe_layer_index: int = 0, ) -> None: super().__init__(data_type=data_type) self.w1_weight_name = gate_proj_name @@ -50,6 +51,7 @@ def __init__( self.enable_ep_moe = get_env_start_args().enable_ep_moe self.n_routed_experts = n_routed_experts self.num_fused_shared_experts = num_fused_shared_experts + self.moe_layer_index = moe_layer_index self._init_config(network_config) self._init_redundancy_expert_params() self._init_parallel_params() @@ -130,6 +132,7 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -145,6 +148,8 @@ def experts( topk_group=topk_group, num_expert_group=num_expert_group, is_prefill=is_prefill, + moe_layer_index=self.moe_layer_index, + microbatch_index=microbatch_index, ) def low_latency_dispatch( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index 6ed0cef0b4..4ca1605be4 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -8,6 +8,7 @@ from lightllm.common.quantization import Quantcfg from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel import routing_manager as _routing_mgr logger = init_logger(__name__) @@ -46,6 +47,7 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] = None, + moe_layer_index: int = 0, ) -> None: network_config["norm_topk_prob"] = None super().__init__( @@ -62,6 +64,7 @@ def __init__( num_fused_shared_experts=num_fused_shared_experts, layer_num=layer_num, network_config=network_config, + moe_layer_index=moe_layer_index, ) self.hidden_size = network_config["hidden_size"] @@ -144,10 +147,15 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._router(router_logits, top_k) + # Rollout router replay + if _routing_mgr.g_routing_capture_manager is not None: + _routing_mgr.g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index 00587ac185..1c93cb13dc 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -62,5 +62,7 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index 8bcdb4bf90..1e81b226ec 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ 
b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -3,6 +3,7 @@ from lightllm.common.quantization.no_quant import WeightPack from lightllm.common.quantization.quantize_method import QuantizationMethod from .base_impl import FuseMoeBaseImpl +from lightllm.common.basemodel import routing_manager as _routing_mgr class FuseMoeTriton(FuseMoeBaseImpl): @@ -124,6 +125,8 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -136,6 +139,10 @@ def __call__( num_expert_group=num_expert_group, scoring_func=scoring_func, ) + + if _routing_mgr.g_routing_capture_manager is not None and moe_layer_index is not None: + _routing_mgr.g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index) + output = self._fused_experts( input_tensor=input_tensor, w13=w13, diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py new file mode 100644 index 0000000000..9b8c09d8cb --- /dev/null +++ b/lightllm/common/basemodel/routing_manager.py @@ -0,0 +1,224 @@ +import atexit +import torch +import numpy as np +from multiprocessing import shared_memory +from typing import Optional +from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_current_rank_in_dp +from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray +from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.shm_utils import create_or_link_shm + +logger = init_logger(__name__) + + +def routing_dtype_id_to_np(dtype_id: int): + if dtype_id == 1: + return np.int8 + elif dtype_id == 2: + return np.int16 + return np.int32 + + +def get_routing_config_shm() -> SharedArray: + service_name = get_unique_server_name() + return SharedArray(f"{service_name}_routing_config", shape=(4,), dtype=np.int32) + + +class RoutingCaptureManager: + def __init__( + self, + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, + ): + self.num_moe_layers = num_moe_layers + self.topk = topk + self.num_experts = num_experts + self.kv_cache_size = kv_cache_size + + self.dtype = torch.int8 if num_experts <= 127 else torch.int16 + dtype_bytes = 1 if self.dtype == torch.int8 else 2 + + # Shape: (num_moe_layers, kv_cache_size, topk) — on CPU to save GPU memory. + # Written after forward() via flush_to_routing_buffer(), read on request finish. + routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes + self.routing_buffer = torch.zeros( + (num_moe_layers, kv_cache_size, topk), + dtype=self.dtype, + device="cpu", + ) + + # Capture buffers: simple contiguous tensors written to during forward(). 
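
# Lifecycle of the two buffers, as wired up elsewhere in this patch (shapes are
# illustrative; signatures match the methods defined below):
#     mgr.capture(moe_layer_index=i, topk_ids=ids)          # inside forward, per MoE layer
#     mgr.flush_to_routing_buffer(mem_indexes, num_tokens)  # after forward: GPU capture rows
#                                                           # -> CPU routing_buffer slots
#     data = mgr.extract_routing_data(req_mem_indexes)      # on request finish: numpy array of
#                                                           # shape (num_moe_layers, num_tokens, topk)
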
+ capture_buf_size = max_capture_tokens * num_moe_layers * topk * dtype_bytes + self._capture_buffer = [ + torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2) + ] + + dtype_name = "int8" if self.dtype == torch.int8 else "int16" + logger.info( + f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, " + f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, " + f"capture_buffer={capture_buf_size / 1024 / 1024:.2f}MB x2, dtype={dtype_name}" + ) + + @property + def np_dtype(self): + return np.int8 if self.dtype == torch.int8 else np.int16 + + @property + def dtype_id(self) -> int: + return 1 if self.dtype == torch.int8 else 2 + + def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None: + num_tokens = topk_ids.shape[0] + self._capture_buffer[microbatch_index][:num_tokens, moe_layer_index, :] = topk_ids.to(self.dtype) + + def flush_to_routing_buffer(self, mem_indexes: torch.Tensor, num_tokens: int, microbatch_index: int = 0) -> None: + buf = self._capture_buffer[microbatch_index][:num_tokens] # (num_tokens, num_moe_layers, topk) + buf_t = buf.permute(1, 0, 2).cpu() + self.routing_buffer[:, mem_indexes[:num_tokens].cpu(), :] = buf_t + + def extract_routing_data(self, mem_indexes: torch.Tensor) -> np.ndarray: + cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes + return self.routing_buffer[:, cpu_indexes, :].numpy() + + +g_routing_capture_manager: Optional[RoutingCaptureManager] = None + + +def create_routing_capture_manager( + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, +) -> None: + global g_routing_capture_manager + assert g_routing_capture_manager is None, "RoutingCaptureManager already exists" + g_routing_capture_manager = RoutingCaptureManager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=kv_cache_size, + max_capture_tokens=max_capture_tokens, + ) + + +def preallocate_routing_shm_pool(max_req_num: int, num_moe_layers: int, max_tokens: int, topk: int, np_dtype) -> None: + """Pre-allocate POSIX SHM segments for all request slots. + + Each segment is sized for the maximum possible routing data so it can be + reused across requests without create/destroy overhead. + """ + dtype_bytes = np.dtype(np_dtype).itemsize + segment_size = num_moe_layers * max_tokens * topk * dtype_bytes + service_name = get_unique_server_name() + + for i in range(max_req_num): + name = f"{service_name}_shm_routing_{i}" + shm = create_or_link_shm(name, segment_size, auto_cleanup=True) + shm.close() # close handle; SHM persists in /dev/shm + + logger.info( + f"Pre-allocated {max_req_num} routing SHM segments, " + f"each {segment_size / 1024:.1f} KB (total {max_req_num * segment_size / 1024 / 1024:.1f} MB)" + ) + + +def cleanup_routing_shm_pool() -> None: + """Unlink all pre-allocated routing SHM segments. 
Called at server shutdown.""" + try: + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + except Exception: + return + + service_name = get_unique_server_name() + + for i in range(args.running_max_req_size): + name = f"{service_name}_shm_routing_{i}" + try: + shm = shared_memory.SharedMemory(name=name) + shm.close() + shm.unlink() + except Exception: + pass + + config_name = f"{service_name}_routing_config" + try: + shm = shared_memory.SharedMemory(name=config_name) + shm.close() + shm.unlink() + except Exception: + pass + + +def init_routing_capture(model, num_moe_layers: int) -> None: + dp_rank = get_current_rank_in_dp() + logger.info(f"init_routing_capture called: num_moe_layers={num_moe_layers}, dp_rank={dp_rank}") + if dp_rank != 0: + logger.info(f"Skipping routing capture initialization on dp_rank={dp_rank}") + return + + if num_moe_layers == 0: + logger.warning( + "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled." + ) + return + + num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0)) + topk = model.config.get("num_experts_per_tok", 0) + assert num_experts > 0 and topk > 0 + + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + + # Capture buffer must fit the max tokens in any single forward call. + # For prefill that's batch_max_tokens; for decode it's graph_max_batch_size. + batch_max_tokens = args.batch_max_tokens or args.max_req_total_len or 8192 + max_capture_tokens = max(batch_max_tokens, args.graph_max_batch_size) + + logger.info( + f"Initializing routing capture: num_moe_layers={num_moe_layers}, " + f"topk={topk}, num_experts={num_experts}, max_capture_tokens={max_capture_tokens}" + ) + + create_routing_capture_manager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=model.mem_manager.size + 1, + max_capture_tokens=max_capture_tokens, + ) + + mgr = g_routing_capture_manager + np_dtype = mgr.np_dtype + dtype_id = mgr.dtype_id + + max_req_total_len = args.max_req_total_len + + # Write config to cross-process SHM + shm = get_routing_config_shm() + shm.arr[0] = num_moe_layers + shm.arr[1] = topk + shm.arr[2] = dtype_id + shm.arr[3] = max_req_total_len + logger.info( + f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}, " + f"dtype_id={dtype_id}, max_tokens={max_req_total_len}" + ) + + preallocate_routing_shm_pool( + max_req_num=args.running_max_req_size, + num_moe_layers=num_moe_layers, + max_tokens=max_req_total_len, + topk=topk, + np_dtype=np_dtype, + ) + + atexit.register(cleanup_routing_shm_pool) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 98cc7c229e..97015f6b20 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -312,6 +312,7 @@ def _moe_ffn( use_grouped_topk=self.n_group, topk_group=self.topk_group, num_expert_group=self.n_group, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: @@ -339,6 +340,7 @@ def _moe_ffn_edp( topk_group=self.topk_group, num_expert_group=self.n_group, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None: diff --git 
a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 3eb09f9176..bd72035072 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -242,6 +242,9 @@ def _init_moe(self): # == 0 时,说明不存在融合共享专家,共享专家单独加载和进行推理。 if self.num_fused_shared_experts == 0: self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True) + first_moe = self.network_config_["first_k_dense_replace"] + freq = self.network_config_.get("moe_layer_freq", 1) + moe_layer_index = (self.layer_num_ - first_moe) // freq self.experts = FusedMoeWeight( gate_proj_name="gate_proj", down_proj_name="down_proj", @@ -256,6 +259,7 @@ def _init_moe(self): num_fused_shared_experts=self.num_fused_shared_experts, layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=moe_layer_index, ) def _init_ffn(self): diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e596eed97c..79bd327068 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -6,6 +6,7 @@ from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager @@ -49,6 +50,9 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + if self.args.enable_return_routed_experts: + num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe) + init_routing_capture(self, num_moe_layers) def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index d80eefd16e..e5672f8210 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -51,6 +51,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) - use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index 7c8c30940e..7278c62fec 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -55,6 +55,7 @@ def _init_moe(self): num_fused_shared_experts=0, layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=self.layer_num_, ) def _init_weight_names(self): diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index 9e9561eb24..cff748933d 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -2,6 +2,7 @@ from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import 
GptOssTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.models.registry import ModelRegistry +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class @@ -21,6 +22,12 @@ class GptOssTpPartModel(LlamaTpPartModel): def __init__(self, kvargs): super().__init__(kvargs) + def _init_custom(self): + super()._init_custom() + if self.args.enable_return_routed_experts: + num_moe_layers = len(self.trans_layers_weight) + init_routing_capture(self, num_moe_layers) + def _init_att_backend(self): self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])( model=self diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index f86bd5f83d..cc1dc28178 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -74,16 +74,19 @@ def _init_custom(self): rope_scaling = self.config.get("rope_scaling", None) if rope_scaling is None: self._init_to_get_rotary() - if "rope_theta_hw" in self.config: - self._init_to_get_hw_rotary() - return - - if "rope_type" in rope_scaling: + elif "rope_type" in rope_scaling: scaling_type = rope_scaling["rope_type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) elif "type" in rope_scaling: scaling_type = rope_scaling["type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) else: raise ValueError(f"Unknown RoPE scaling format {rope_scaling}") + if "rope_theta_hw" in self.config: + self._init_to_get_hw_rotary() + super()._init_custom() + + def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling): if scaling_type == "default" or "mrope_section" in rope_scaling: self._init_to_get_rotary() elif scaling_type == "yarn": @@ -98,9 +101,6 @@ def _init_custom(self): self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - if "rope_theta_hw" in self.config: - self._init_to_get_hw_rotary() - return def _init_to_get_rotary(self, default_base=10000): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py deleted file mode 100644 index b0e27ac1de..0000000000 --- a/lightllm/models/mixtral/layer_infer/_custom_ops.py +++ /dev/null @@ -1,46 +0,0 @@ -import functools -import json -import os -from typing import Any, Dict, Optional, Tuple - -import torch -import triton -import triton.language as tl -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - -# Pytorch version -# Triton version in progress -def topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output, - topk=2, -): - scores = torch.softmax(gating_output, dim=-1) - topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False) - return topk_weights, topk_ids - - -def fused_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - alloc_tensor_func=torch.empty, -): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - M, _ = hidden_states.shape - - topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device) - topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, 
device=hidden_states.device) - token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk) - del token_expert_indicies # Not used. Will be used in the future. - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index 44e66cff2d..a2968f5ab1 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -1,9 +1,6 @@ -import os import torch -import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight @@ -19,25 +16,15 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(input.view(-1, self.embed_dim_)) - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.num_experts_per_tok, + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, renormalize=self.renormalize, - alloc_tensor_func=self.alloc_tensor, - ) - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl - - return fused_experts_impl( - hidden_states=hidden_states, - w1=layer_weight.experts.w1[0], - w2=layer_weight.experts.w2[0], - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=False, - w1_scale=None, - w2_scale=None, - alloc_tensor_func=self.alloc_tensor, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) + return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index 51c62fd4cb..d93cb5fb58 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -57,4 +57,5 @@ def _init_moe(self): quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=self.layer_num_, ) diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py index 3c2d7b4e87..35bf38de58 100644 --- a/lightllm/models/mixtral/model.py +++ b/lightllm/models/mixtral/model.py @@ -2,6 +2,7 @@ import numpy as np from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.layer_infer.pre_layer_infer import 
LlamaPreLayerInfer @@ -45,6 +46,9 @@ def _verify_params(self): def _init_custom(self): self._init_to_get_rotary() + if self.args.enable_return_routed_experts: + num_moe_layers = len(self.trans_layers_weight) + init_routing_capture(self, num_moe_layers) return def _init_mem_manager(self): diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 9eccddffc1..af035e81b6 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -93,8 +93,10 @@ def _tpsp_get_qkv( input = gather_input[0 : len(infer_state.input_ids), :] input = input.view(-1, self.embed_dim_) - q = layer_weight.q_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input) + qkv = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv.split( + [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 + ) layer_weight.q_norm_weight_(q, eps=self.eps_) layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -130,6 +132,7 @@ def _moe_ffn( use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) return hidden_states.view(num_tokens, hidden_dim) @@ -150,6 +153,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) ep_output = ep_output.view(token_num, hidden_dim) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 13ba6cbe0f..5a857fd093 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -52,6 +52,11 @@ def _init_moe(self): tp_rank=0, tp_world_size=1, ) + mlp_only = set(self.network_config_.get("mlp_only_layers", [])) + step = self.network_config_.get("decoder_sparse_step", 1) + moe_layer_index = sum( + 1 for i in range(self.layer_num_) if self.n_routed_experts > 0 and i not in mlp_only and (i + 1) % step == 0 + ) self.experts = FusedMoeWeight( gate_proj_name="gate_proj", down_proj_name="down_proj", @@ -65,6 +70,7 @@ def _init_moe(self): quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=moe_layer_index, ) def _init_qkv(self): diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index 10a5051276..2926a12b1f 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -4,6 +4,7 @@ from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -26,3 +27,6 @@ def __init__(self, kvargs): def _init_custom(self): super()._init_custom() dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + if self.args.enable_return_routed_experts: + num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe) + 
init_routing_capture(self, num_moe_layers) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 73401f1633..409460feb0 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -644,4 +644,10 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: If the op is not implemented for the platform and the hardware support triton, it will use triton implementation.""", ) + parser.add_argument( + "--enable_return_routed_experts", + action="store_true", + default=False, + help="Enable returning routed expert indices for MoE models (R3 feature).", + ) return parser diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f54..5abd90815b 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -53,6 +53,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt_token_ids = None is_first_metadata = True input_usage = None + routed_experts_data = None async for sub_req_id, request_output, metadata, finish_status in results_generator: # when set "--return_all_prompt_logprobs", the first token metadata will contains # prompt_logprobs and prompt_token_ids @@ -78,6 +79,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status + if "routed_experts" in metadata: + routed_experts_data = metadata["routed_experts"] n = sampling_params.n sub_ids = list(final_output_dict.keys())[:n] final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids] @@ -102,6 +105,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana ret["prompt_logprobs"] = prompt_logprobs if input_usage is not None: ret["input_usage"] = input_usage + if routed_experts_data is not None: + ret["routed_experts"] = routed_experts_data return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 7df5ba74e8..4a33b659b0 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -1,6 +1,7 @@ import os import math import ctypes +import base64 import numpy as np import time from .sampling_params import SamplingParams @@ -13,6 +14,7 @@ from lightllm.utils.kv_cache_utils import compute_token_list_hash from typing import List, Any, Union from lightllm.utils.log_utils import init_logger +from lightllm.utils.shm_utils import create_or_link_shm logger = init_logger(__name__) @@ -125,6 +127,8 @@ class Req(ctypes.Structure): ("cpu_cache_match_page_indexes", CpuCachePageList), # 分块hash的块大小 ("cpu_cache_token_page_size", ctypes.c_int), + # Number of tokens in routing data SHM, written by model worker, read by HTTP server. 
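+        # (Per this patch's own logic: init() resets it to 0, and readers such as
+        # get_routing_metadata() treat a value <= 0 as "no routing data captured".)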
+ ("shm_routing_num_tokens", ctypes.c_int), ] def get_str(self): @@ -182,6 +186,7 @@ def init( self._mtp_step = get_env_start_args().mtp_step self.stop_str_matched = False self.stop_str_matched_token_index = -1 + self.shm_routing_num_tokens = 0 self.post_init() @@ -230,6 +235,66 @@ def link_logprobs_shm_array(self): self.shm_logprobs.link_shm() return + def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.int8): + """Link to a pre-allocated routing SHM and create a numpy view for the actual data shape.""" + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_moe_layers, num_tokens, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) + self.shm_routing_data.link_shm_partial() + self.shm_routing_num_tokens = num_tokens + return + + def link_routing_data_shm_array(self, num_moe_layers: int, topk: int, np_dtype=np.int8): + """Link to the pre-allocated routing SHM from the reader side (HTTP server).""" + if num_moe_layers == 0: + return + num_tokens = self.shm_routing_num_tokens + if num_tokens <= 0: + return + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_moe_layers, num_tokens, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) + self.shm_routing_data.link_shm_partial() + return + + def get_routing_data(self): + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: + return None + return self.shm_routing_data.arr + + def close_routing_data_shm_array(self): + """Detach from pre-allocated SHM without unlinking it.""" + if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None: + self.shm_routing_data.detach_shm() + self.shm_routing_data = None + self.shm_routing_num_tokens = 0 + return + + def get_routing_metadata(self, num_moe_layers: int, topk: int, dtype_id: int = 1): + if num_moe_layers == 0 or topk == 0: + return None + if self.shm_routing_num_tokens <= 0: + return None + try: + from lightllm.common.basemodel.routing_manager import routing_dtype_id_to_np + + np_dtype = routing_dtype_id_to_np(dtype_id) + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: + self.link_routing_data_shm_array(num_moe_layers, topk, np_dtype=np_dtype) + routing_data = self.get_routing_data() + if routing_data is None: + return None + return { + "shape": list(routing_data.shape), + "dtype": str(routing_data.dtype), + "data": base64.b64encode(routing_data.tobytes()).decode("ascii"), + } + except Exception as e: + logger.warning(f"Failed to read routing data for req {self.request_id}: {e}") + return None + def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() diff --git a/lightllm/server/core/objs/shm_array.py b/lightllm/server/core/objs/shm_array.py index c5ad512c6b..1bf20535ad 100644 --- a/lightllm/server/core/objs/shm_array.py +++ b/lightllm/server/core/objs/shm_array.py @@ -26,6 +26,19 @@ def link_shm(self): self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) return + def link_shm_partial(self): + """Link to an existing SHM that may be larger than the needed shape.""" + self.shm = create_or_link_shm(self.name, -1, force_mode="link") + assert self.shm.size >= self.dest_size, f"SHM {self.name} too small: need {self.dest_size}, got {self.shm.size}" + self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) + + def detach_shm(self): + """Close handle 
without unlinking (SHM persists for reuse).""" + if self.shm is not None: + self.shm.close() + self.shm = None + self.arr = None + def close_shm(self): if self.shm is not None: self.shm.close() diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 2cb12ed89d..4ac0a4dd2b 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -184,4 +184,6 @@ class StartArgs: enable_torch_fallback: bool = field(default=False) enable_triton_fallback: bool = field(default=False) + enable_return_routed_experts: bool = field(default=False) + weight_version: str = "default" diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index b6a7b0d127..3ef778ca42 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -29,6 +29,7 @@ from lightllm.server.core.objs.shm_req_manager import ShmReqManager from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.common.basemodel.routing_manager import get_routing_config_shm from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient from lightllm.server.io_struct import ( @@ -139,6 +140,9 @@ def __init__( self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + # Cache routing config for MoE expert routing data extraction + self._routing_shm = get_routing_config_shm() if args.enable_return_routed_experts else None + self.is_pause = False self.is_pause_cond = asyncio.Condition() @@ -769,6 +773,11 @@ async def recycle_resource_loop(self): for req_status in release_req_status: self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) for req in req_status.group_req_objs.shm_req_objs: + if hasattr(req, "shm_routing_data") and req.shm_routing_data is not None: + try: + req.close_routing_data_shm_array() + except Exception as e: + logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}") await self.shm_req_manager.async_put_back_req_obj(req) await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) @@ -868,6 +877,15 @@ async def _handle_recv_generate_request(self, recv_obj: GenerateReqMeta): else: finish_status = FinishStatus(req.finish_status.status) + if self._routing_shm is not None: + _num_moe = int(self._routing_shm.arr[0]) + _topk = int(self._routing_shm.arr[1]) + _dtype_id = int(self._routing_shm.arr[2]) + if _num_moe > 0: + routing_meta = req.get_routing_metadata(_num_moe, _topk, dtype_id=_dtype_id) + if routing_meta is not None: + metadata["routed_experts"] = routing_meta + token_list.append((req_id, text, metadata, finish_status)) else: break diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 2b35fad053..66aeb6e95d 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -19,6 +19,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.common.basemodel import 
routing_manager as _routing_mgr logger = init_logger(__name__) @@ -113,6 +114,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs + def _extract_routing_data(self, req: "InferReq"): + if req.shm_req.shm_routing_num_tokens > 0: + return + mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len] + mgr = _routing_mgr.g_routing_capture_manager + routing_data = mgr.extract_routing_data(mem_indexes) + req.shm_req.create_routing_data_shm_array(mgr.num_moe_layers, req.cur_kv_len, mgr.topk, np_dtype=mgr.np_dtype) + req.shm_req.shm_routing_data.arr[:] = routing_data + req.shm_req.shm_routing_data.detach_shm() + def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) @@ -149,12 +160,18 @@ def _filter(self, finished_request_ids: List[int]): if len(finished_request_ids) == 0: return + need_routing_data = _routing_mgr.g_routing_capture_manager is not None + free_req_index = [] free_token_index = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() + + if need_routing_data: + self._extract_routing_data(req) + self.free_a_req_mem(free_token_index, req) free_req_index.append(req.req_idx) @@ -588,6 +605,8 @@ def handle( shm_req.shm_cur_output_len = self.output_len if finish_status.is_finished(): + if _routing_mgr.g_routing_capture_manager is not None: + g_infer_context._extract_routing_data(req_obj) shm_req.finish_token_index = shm_req.input_len + self.output_len - 1 shm_req.finish_status = req_obj.finish_status diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 2d0e4b14b7..ff2ea8c21e 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -44,6 +44,7 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.common.basemodel import routing_manager as _routing_mgr from lightllm.utils.torch_memory_saver_utils import MemoryTag from .multi_level_kv_cache import MultiLevelKvCacheModule from lightllm.server.io_struct import ( @@ -996,6 +997,18 @@ def _sample_and_scatter_token( ) return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu + def _flush_routing_to_kv_buffer(self, mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None: + """Scatter captured routing data from capture buffer to KV-indexed GPU buffer. + + Must be called AFTER model.forward() completes. mem_indexes should be the + original (unpadded) tensor — either CPU or CUDA. 
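+
+        Call-order sketch (mirrors the backend impls patched below; not a new API):
+
+            model_output = self.model.forward(model_input)  # MoE layers fill the capture buffer
+            self._flush_routing_to_kv_buffer(model_input.mem_indexes)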
+ """ + if _routing_mgr.g_routing_capture_manager is not None and mem_indexes is not None: + if not mem_indexes.is_cuda: + mem_indexes = mem_indexes.cuda(non_blocking=True) + num_tokens = mem_indexes.shape[0] + _routing_mgr.g_routing_capture_manager.flush_to_kv_buffer(mem_indexes, num_tokens, microbatch_index) + def _dp_all_gather_prefill_and_decode_req_num( self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq] ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index a8a5224ebc..9f4443e48e 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -109,6 +109,7 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -148,6 +149,7 @@ def decode_normal( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -186,6 +188,7 @@ def prefill_mtp( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -236,6 +239,7 @@ def decode_mtp( with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 5a179cb620..ebc55b7ef4 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -40,8 +40,8 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq ) with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) logits = model_output.logits batch_idx, run_reqs = self._diverse_copy( diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index bb0e848e76..f01e5fe935 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -145,6 +145,7 @@ def prefill_normal( run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -188,6 +189,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -236,6 +238,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -305,6 +309,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -359,6 +365,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] req_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output: ModelOutput = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) b_has_out_cpu = model_input.b_prefill_has_output_cpu[0:req_num] logits = model_output.logits[0:req_num, :] b_req_idx = model_input.b_req_idx[0:req_num] @@ -421,6 +428,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None if req_num > 0: logits = model_output.logits[0:req_num, :] @@ -629,6 +637,8 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I ) = padded_overlap_prepare_prefill_inputs(prefill_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits req_num0, req_num1 = len(run_reqs0), len(run_reqs1) @@ -726,8 +736,9 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf b_mtp_index_cpu0 = model_input0.b_mtp_index b_mtp_index_cpu1 = 
model_input1.b_mtp_index
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
             model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1)
+            self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0)
+            self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1)
             logits0 = model_output0.logits
             logits1 = model_output1.logits
             run_reqs = run_reqs0 + run_reqs1
diff --git a/test/test_api/test_r3.py b/test/test_api/test_r3.py
new file mode 100644
index 0000000000..00f34c4893
--- /dev/null
+++ b/test/test_api/test_r3.py
@@ -0,0 +1,99 @@
+import sys
+import argparse
+import requests
+import base64
+import numpy as np
+
+
+def test_routing_export(url: str = "http://localhost:8000"):
+    print(f"Testing routing export at {url}")
+    print("-" * 50)
+
+    try:
+        response = requests.post(
+            f"{url}/generate",
+            json={
+                "inputs": "What is the capital of France? What is the capital of France?",
+                "parameters": {
+                    "max_new_tokens": 50,
+                    "return_routed_experts": True,
+                    "repetition_penalty": 1.0,
+                },
+            },
+            timeout=60,
+        )
+    except requests.exceptions.ConnectionError:
+        print(f"ERROR: Cannot connect to server at {url}")
+        print("Make sure the LightLLM server is running with --enable_return_routed_experts")
+        return False
+    except requests.exceptions.Timeout:
+        print("ERROR: Request timed out")
+        return False
+
+    print(f"Status: {response.status_code}")
+
+    if response.status_code != 200:
+        print(f"ERROR: Request failed with status {response.status_code}")
+        print(f"Response: {response.text}")
+        return False
+
+    res = response.json()
+    print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...")
+
+    if "routed_experts" not in res or not res["routed_experts"]:
+        print("\nWARNING: No routed_experts in response.")
+        print("This could mean:")
+        print("  - The model is not a MoE model")
+        print("  - The server was not started with --enable_return_routed_experts")
+        print("  - The routing capture manager was not initialized")
+        return False
+
+    routing_info = res["routed_experts"]
+    shape = routing_info["shape"]
+    dtype_str = routing_info["dtype"]
+    dtype = np.dtype(dtype_str)
+    data = base64.b64decode(routing_info["data"])
+    routing_array = np.frombuffer(data, dtype=dtype).reshape(shape)
+
+    print(f"\n{'=' * 50}")
+    print("ROUTING CAPTURE SUCCESS!")
+    print(f"{'=' * 50}")
+    print(f"Shape: {shape}")
+    print(f"Dtype: {dtype}")
+    print(f"Num MoE layers: {shape[0]}")
+    print(f"Num tokens: {shape[1]}")
+    print(f"Top-K: {shape[2]}")
+
+    # Verify dtype is int8 (for models with ≤127 experts) or int16
+    if dtype_str not in ("int8", "int16"):
+        print(f"\nERROR: Expected dtype int8 or int16, got {dtype_str}")
+        print("This suggests dtype optimization is not working correctly.")
+        return False
+    print(f"\nDtype check PASSED: {dtype_str} (compact representation)")
+
+    # Compute payload size savings
+    int32_size = np.prod(shape) * 4
+    actual_size = len(data)
+    savings = (1 - actual_size / int32_size) * 100
+    print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)")
+
+    print(f"\nSample routing (first layer, first 5 tokens):")
+    num_tokens_to_show = min(5, shape[1])
+    for i in range(num_tokens_to_show):
+        print(f"  Token {i}: experts {routing_array[0, i, :].tolist()}")
+
+    if np.all(routing_array == 0):
+        print("\nWARNING: All routing data is zeros. 
Capture may not be working correctly.") + return False + + print("\nTest PASSED!") + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test R3 routing export feature") + parser.add_argument("--url", default="http://localhost:8000", help="Server URL") + args = parser.parse_args() + + success = test_routing_export(args.url) + sys.exit(0 if success else 1) diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py new file mode 100644 index 0000000000..dcc010b372 --- /dev/null +++ b/unit_tests/common/basemodel/test_routing_capture_manager.py @@ -0,0 +1,219 @@ +import torch +import numpy as np + + +class TestRoutingCaptureManager: + def test_capture_and_extract_basic(self): + """Test the core pipeline: capture → flush_to_kv_buffer → extract_from_gpu.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + kv_cache_size=1024, + max_capture_tokens=64, + ) + + # Simulate a batch of 10 tokens at KV-cache positions [100..109] + mem_indexes = torch.arange(100, 110, device="cuda") + + # Capture routing for each MoE layer (writes to capture buffer) + for layer_idx in range(4): + topk_ids = torch.randint(0, 64, (10, 8), device="cuda") + manager.capture(moe_layer_index=layer_idx, topk_ids=topk_ids, microbatch_index=0) + + # Flush from capture buffer to KV-indexed gpu_kv_buffer + manager.flush_to_kv_buffer(mem_indexes, num_tokens=10, microbatch_index=0) + + # Extract for those same KV-cache positions + result = manager.extract_from_gpu(mem_indexes) + assert result.shape == (4, 10, 8) + assert result.dtype == np.int8 + + def test_capture_writes_to_correct_kv_positions(self): + """Verify that captured data lands in the right KV-cache positions after flush.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=2, + topk=4, + num_experts=32, + kv_cache_size=256, + max_capture_tokens=16, + ) + + # Use non-contiguous mem_indexes to simulate real KV-cache + mem_indexes = torch.tensor([10, 50, 200], device="cuda") + + topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + topk_ids_layer1 = topk_ids + 20 + manager.capture(moe_layer_index=1, topk_ids=topk_ids_layer1, microbatch_index=0) + + # Flush to KV positions + manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0) + + # Extract and verify + result = manager.extract_from_gpu(mem_indexes) + assert result.shape == (2, 3, 4) + np.testing.assert_array_equal(result[0], topk_ids.cpu().numpy()) + np.testing.assert_array_equal(result[1], topk_ids_layer1.cpu().numpy()) + + def test_microbatch_isolation(self): + """Two microbatches writing to different KV positions don't interfere.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=4, + num_experts=32, + 
kv_cache_size=256, + max_capture_tokens=16, + ) + + # Microbatch 0: positions [10, 11] + mem0 = torch.tensor([10, 11], device="cuda") + ids_0 = torch.ones((2, 4), dtype=torch.int64, device="cuda") + manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0) + + # Microbatch 1: positions [20, 21] + mem1 = torch.tensor([20, 21], device="cuda") + ids_1 = torch.ones((2, 4), dtype=torch.int64, device="cuda") * 2 + manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1) + + # Flush each microbatch to different KV positions + manager.flush_to_kv_buffer(mem0, num_tokens=2, microbatch_index=0) + manager.flush_to_kv_buffer(mem1, num_tokens=2, microbatch_index=1) + + # Extract microbatch 0 + result0 = manager.extract_from_gpu(mem0) + assert result0.shape == (1, 2, 4) + assert result0[0, 0, 0] == 1 + + # Extract microbatch 1 + result1 = manager.extract_from_gpu(mem1) + assert result1[0, 0, 0] == 2 + + def test_dtype_selection_int8(self): + """Models with ≤127 experts use int8.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=64, + kv_cache_size=128, + max_capture_tokens=16, + ) + assert manager.dtype == torch.int8 + assert manager.np_dtype == np.int8 + assert manager.dtype_id == 1 + + def test_dtype_selection_int16(self): + """Models with >127 experts use int16.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=256, + kv_cache_size=128, + max_capture_tokens=16, + ) + assert manager.dtype == torch.int16 + assert manager.np_dtype == np.int16 + assert manager.dtype_id == 2 + + def test_extract_preserves_values(self): + """Extracted values exactly match what was captured, no dtype truncation.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=4, + num_experts=64, + kv_cache_size=64, + max_capture_tokens=16, + ) + + mem_indexes = torch.tensor([0, 1, 2], device="cuda") + + topk_ids = torch.tensor([[10, 20, 30, 40], [50, 60, 63, 1], [0, 5, 127, 3]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + # Flush then extract + manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0) + result = manager.extract_from_gpu(mem_indexes) + expected = topk_ids.cpu().numpy().astype(np.int8) + np.testing.assert_array_equal(result[0], expected) + + def test_gpu_kv_buffer_shape(self): + """Buffer shape is (num_moe_layers, kv_cache_size, topk).""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + # 127 experts fits in int8 (max value 127) + manager = RoutingCaptureManager( + num_moe_layers=48, + topk=8, + num_experts=127, + kv_cache_size=2048, + max_capture_tokens=256, + ) + assert manager.gpu_kv_buffer.shape == (48, 2048, 8) + assert manager.gpu_kv_buffer.dtype == torch.int8 + assert manager.gpu_kv_buffer.device.type == "cuda" + + # 128 experts requires int16 + manager2 = RoutingCaptureManager( + num_moe_layers=48, + topk=8, + num_experts=128, + kv_cache_size=2048, + max_capture_tokens=256, + ) + assert manager2.gpu_kv_buffer.dtype == torch.int16 + + def test_partial_token_capture(self): + """capture() only writes num_tokens rows to the buffer.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + 
num_experts=32, + kv_cache_size=128, + max_capture_tokens=16, + ) + + # Capture only 3 tokens, flush to 5 KV positions (first 3 get data) + mem_indexes = torch.tensor([10, 11, 12, 13, 14], device="cuda") + + topk_ids = torch.tensor([[1, 2], [3, 4], [5, 6]], device="cuda") # only 3 tokens + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + # Flush only the 3 captured tokens + manager.flush_to_kv_buffer(mem_indexes[:3], num_tokens=3, microbatch_index=0) + + # Positions 10-12 should have data, 13-14 should be zeros (from init) + result_written = manager.extract_from_gpu(mem_indexes[:3]) + np.testing.assert_array_equal(result_written[0], topk_ids.cpu().numpy().astype(np.int8)) + + result_unwritten = manager.extract_from_gpu(mem_indexes[3:]) + np.testing.assert_array_equal(result_unwritten[0], np.zeros((2, 2), dtype=np.int8)) + + def test_capture_buffer_shape(self): + """Capture buffer has correct shape (max_tokens, num_moe_layers, topk).""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + kv_cache_size=1024, + max_capture_tokens=256, + ) + assert manager._capture_buffer[0].shape == (256, 4, 8) + assert manager._capture_buffer[1].shape == (256, 4, 8) + assert manager._capture_buffer[0].dtype == torch.int8 From 3569d53a6acad7d886638679f56b355f2fc5cddf Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 9 Feb 2026 11:20:43 +0000 Subject: [PATCH 063/180] fix --- lightllm/server/router/model_infer/mode_backend/base_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index ff2ea8c21e..70b0ec9ebf 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -1007,7 +1007,7 @@ def _flush_routing_to_kv_buffer(self, mem_indexes: torch.Tensor, microbatch_inde if not mem_indexes.is_cuda: mem_indexes = mem_indexes.cuda(non_blocking=True) num_tokens = mem_indexes.shape[0] - _routing_mgr.g_routing_capture_manager.flush_to_kv_buffer(mem_indexes, num_tokens, microbatch_index) + _routing_mgr.g_routing_capture_manager.flush_to_routing_buffer(mem_indexes, num_tokens, microbatch_index) def _dp_all_gather_prefill_and_decode_req_num( self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq] From fe54253e8de8aa6b0b4bacc8a9ce2b0240f5d255 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 9 Feb 2026 11:29:05 +0000 Subject: [PATCH 064/180] add node-id for env_utils --- lightllm/utils/envs_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 3a0e28bcb6..59315108a9 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -11,15 +11,18 @@ def set_unique_server_name(args): if args.run_mode == "pd_master": - os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.port) + "_pd_master" + os.environ[f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}"] = str(args.port) + "_pd_master" else: - os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.nccl_port) + "_" + str(args.node_rank) + os.environ[f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}"] = ( + str(args.nccl_port) + "_" + str(args.node_rank) + ) return @lru_cache(maxsize=None) def get_unique_server_name(): - service_uni_name = 
os.getenv("LIGHTLLM_UNIQUE_SERVICE_NAME_ID") + args = get_env_start_args() + service_uni_name = os.getenv(f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}") return service_uni_name @@ -33,7 +36,7 @@ def set_env_start_args(args): set_cuda_arch(args) if not isinstance(args, dict): args = vars(args) - os.environ["LIGHTLLM_START_ARGS"] = json.dumps(args) + os.environ[f"LIGHTLLM_START_ARGS_{args.pd_node_id}"] = json.dumps(args) return @@ -41,7 +44,8 @@ def set_env_start_args(args): def get_env_start_args(): from lightllm.server.core.objs.start_args_type import StartArgs - start_args: StartArgs = json.loads(os.environ["LIGHTLLM_START_ARGS"]) + args = get_env_start_args() + start_args: StartArgs = json.loads(os.environ[f"LIGHTLLM_START_ARGS_{args.pd_node_id}"]) start_args: StartArgs = EasyDict(start_args) return start_args From 8eead2b14c492d928c6488560b7e07765ea2b516 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 9 Feb 2026 11:58:52 +0000 Subject: [PATCH 065/180] Revert "add node-id for env_utils" This reverts commit fe54253e8de8aa6b0b4bacc8a9ce2b0240f5d255. --- lightllm/utils/envs_utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 59315108a9..3a0e28bcb6 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -11,18 +11,15 @@ def set_unique_server_name(args): if args.run_mode == "pd_master": - os.environ[f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}"] = str(args.port) + "_pd_master" + os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.port) + "_pd_master" else: - os.environ[f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}"] = ( - str(args.nccl_port) + "_" + str(args.node_rank) - ) + os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.nccl_port) + "_" + str(args.node_rank) return @lru_cache(maxsize=None) def get_unique_server_name(): - args = get_env_start_args() - service_uni_name = os.getenv(f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}") + service_uni_name = os.getenv("LIGHTLLM_UNIQUE_SERVICE_NAME_ID") return service_uni_name @@ -36,7 +33,7 @@ def set_env_start_args(args): set_cuda_arch(args) if not isinstance(args, dict): args = vars(args) - os.environ[f"LIGHTLLM_START_ARGS_{args.pd_node_id}"] = json.dumps(args) + os.environ["LIGHTLLM_START_ARGS"] = json.dumps(args) return @@ -44,8 +41,7 @@ def set_env_start_args(args): def get_env_start_args(): from lightllm.server.core.objs.start_args_type import StartArgs - args = get_env_start_args() - start_args: StartArgs = json.loads(os.environ[f"LIGHTLLM_START_ARGS_{args.pd_node_id}"]) + start_args: StartArgs = json.loads(os.environ["LIGHTLLM_START_ARGS"]) start_args: StartArgs = EasyDict(start_args) return start_args From 27f9e87d5edf971c4ffe9f7f7b969b672fd07684 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 9 Feb 2026 11:58:58 +0000 Subject: [PATCH 066/180] Revert "port random generate" This reverts commit 02078ad1ded3babcd1b3c8d152e83ee1238ba7d9. 
--- lightllm/utils/net_utils.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index 51ec443d10..486414e88e 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -2,6 +2,7 @@ import subprocess import ipaddress import random +import portpicker from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -12,36 +13,24 @@ def alloc_can_use_network_port(num=3, used_nccl_ports=None, from_port_num=10000) used_nccl_ports = [] port_list = [] - locked_sockets = [] - used_set = set(used_nccl_ports) - max_port = 65535 max_attempts = num * 50 # Allow more attempts to find ports in range for _ in range(max_attempts): if len(port_list) >= num: break - # 在 [from_port_num, 65535] 范围内随机选端口,避免多进程同时启动时分配到相同端口 - port = random.randint(from_port_num, max_port) - if port in used_set: - continue - - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: - sock.bind(("", port)) - port_list.append(port) - used_set.add(port) - locked_sockets.append(sock) - logger.debug(f"Allocated and locked port: {port}") - - except OSError as e: - sock.close() - logger.warning(f"Failed to bind port: {e}") - continue + port = portpicker.pick_unused_port() - for sock in locked_sockets: - sock.close() + if port >= from_port_num and port not in used_nccl_ports: + port_list.append(port) + logger.debug(f"Allocated port: {port}") + else: + logger.debug(f"Port {port} is out of range or in used_nccl_ports, skipping") + + except Exception as e: + logger.warning(f"Failed to allocate port: {e}") + continue if len(port_list) < num: logger.error(f"Failed to allocate {num} ports, only got {len(port_list)}") From 6fa8f74e473632f598da4ec96b1d841a374dd0b2 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 9 Feb 2026 13:03:14 +0000 Subject: [PATCH 067/180] add assert none --- lightllm/utils/envs_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 3a0e28bcb6..03816d3ab0 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -13,6 +13,7 @@ def set_unique_server_name(args): if args.run_mode == "pd_master": os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.port) + "_pd_master" else: + assert str(args.nccl_port) != "None", "nccl_port is not set" os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.nccl_port) + "_" + str(args.node_rank) return From bf83078ae3e29547ccc9d823c6864102f119ed96 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 10 Feb 2026 05:00:10 +0000 Subject: [PATCH 068/180] set_unique_server_name --- lightllm/server/api_start.py | 3 +-- lightllm/utils/envs_utils.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index afe199d04f..58dac941b0 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -91,6 +91,7 @@ def _set_envs_and_config(args: StartArgs): def _launch_subprocesses(args: StartArgs): _set_envs_and_config(args) + set_unique_server_name(args) if not args.disable_shm_warning: check_recommended_shm_size(args) @@ -290,8 +291,6 @@ def _launch_subprocesses(args: StartArgs): args.pd_p_allowed_port_min = 20000 args.pd_p_allowed_port_max = 30000 - set_unique_server_name(args) - # p d 分离模式下,decode节点的调度间隙是0 if args.run_mode == "decode": args.router_max_wait_tokens = 
0 diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 03816d3ab0..a702a465b2 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -21,6 +21,8 @@ def set_unique_server_name(args): @lru_cache(maxsize=None) def get_unique_server_name(): service_uni_name = os.getenv("LIGHTLLM_UNIQUE_SERVICE_NAME_ID") + assert "None" not in service_uni_name, "service_uni_name is not set" + return service_uni_name From 3eab5a746162ab6d50172df61791f614848d309d Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 10 Feb 2026 05:45:48 +0000 Subject: [PATCH 069/180] fix return_routed_experts --- lightllm/common/basemodel/routing_manager.py | 12 +++--- lightllm/server/api_lightllm.py | 6 ++- lightllm/server/core/objs/sampling_params.py | 39 +++++++++++--------- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py index 9b8c09d8cb..01caa36662 100644 --- a/lightllm/common/basemodel/routing_manager.py +++ b/lightllm/common/basemodel/routing_manager.py @@ -14,7 +14,7 @@ def routing_dtype_id_to_np(dtype_id: int): if dtype_id == 1: - return np.int8 + return np.uint8 elif dtype_id == 2: return np.int16 return np.int32 @@ -39,8 +39,8 @@ def __init__( self.num_experts = num_experts self.kv_cache_size = kv_cache_size - self.dtype = torch.int8 if num_experts <= 127 else torch.int16 - dtype_bytes = 1 if self.dtype == torch.int8 else 2 + self.dtype = torch.uint8 if num_experts <= 255 else torch.int16 + dtype_bytes = 1 if self.dtype == torch.uint8 else 2 # Shape: (num_moe_layers, kv_cache_size, topk) — on CPU to save GPU memory. # Written after forward() via flush_to_routing_buffer(), read on request finish. @@ -57,7 +57,7 @@ def __init__( torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2) ] - dtype_name = "int8" if self.dtype == torch.int8 else "int16" + dtype_name = "uint8" if self.dtype == torch.uint8 else "int16" logger.info( f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, " f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, " @@ -66,11 +66,11 @@ def __init__( @property def np_dtype(self): - return np.int8 if self.dtype == torch.int8 else np.int16 + return np.uint8 if self.dtype == torch.uint8 else np.int16 @property def dtype_id(self) -> int: - return 1 if self.dtype == torch.int8 else 2 + return 1 if self.dtype == torch.uint8 else 2 def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None: num_tokens = topk_ids.shape[0] diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index 5abd90815b..d15bec6485 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -35,6 +35,9 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] return_details = sample_params_dict.pop("return_details", False) + return_routed_experts = sample_params_dict.pop( + "return_routed_experts", httpserver_manager.args.enable_return_routed_experts + ) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() @@ -105,7 +108,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana ret["prompt_logprobs"] = prompt_logprobs if input_usage is not None: ret["input_usage"] = input_usage - if 
routed_experts_data is not None: + if return_routed_experts and routed_experts_data is not None: ret["routed_experts"] = routed_experts_data return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) @@ -117,6 +120,7 @@ async def lightllm_generate_stream(request: Request, httpserver_manager: HttpSer prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] _ = sample_params_dict.pop("return_details", False) + _ = sample_params_dict.pop("return_routed_experts", None) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 93447830bb..1c06428622 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -357,23 +357,28 @@ class SamplingParams(ctypes.Structure): def init(self, tokenizer, **kwargs): super().__init__() - self.best_of = kwargs.get("best_of", 1) - self.n = kwargs.get("n", self.best_of) - self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) - self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) - self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) - self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) - self.temperature = kwargs.get("temperature", SamplingParams._temperature) - self.top_p = kwargs.get("top_p", SamplingParams._top_p) - self.top_k = kwargs.get("top_k", SamplingParams._top_k) - self.ignore_eos = kwargs.get("ignore_eos", False) - self.min_pixels = kwargs.get("min_pixels", -1) - self.max_pixels = kwargs.get("max_pixels", -1) - self.max_new_tokens = kwargs.get("max_new_tokens", 16) - self.min_new_tokens = kwargs.get("min_new_tokens", 1) - self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY) - self.group_request_id = kwargs.get("group_request_id", -1) - self.suggested_dp_index = kwargs.get("suggested_dp_index", -1) + + def _get(key, default): + v = kwargs.get(key) + return v if v is not None else default + + self.best_of = _get("best_of", 1) + self.n = _get("n", self.best_of) + self.do_sample = _get("do_sample", SamplingParams._do_sample) + self.presence_penalty = _get("presence_penalty", SamplingParams._presence_penalty) + self.frequency_penalty = _get("frequency_penalty", SamplingParams._frequency_penalty) + self.repetition_penalty = _get("repetition_penalty", SamplingParams._repetition_penalty) + self.temperature = _get("temperature", SamplingParams._temperature) + self.top_p = _get("top_p", SamplingParams._top_p) + self.top_k = _get("top_k", SamplingParams._top_k) + self.ignore_eos = _get("ignore_eos", False) + self.min_pixels = _get("min_pixels", -1) + self.max_pixels = _get("max_pixels", -1) + self.max_new_tokens = _get("max_new_tokens", 16) + self.min_new_tokens = _get("min_new_tokens", 1) + self.input_penalty = _get("input_penalty", DEFAULT_INPUT_PENALTY) + self.group_request_id = _get("group_request_id", -1) + self.suggested_dp_index = _get("suggested_dp_index", -1) self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS) self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False) From 14cfc9511715e6f72b56d17279f8a41e5fb67ec6 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 10 Feb 2026 11:04:06 +0000 Subject: [PATCH 070/180] fix r3 --- lightllm/common/basemodel/routing_manager.py | 41 
++----------------- .../server/core/objs/py_sampling_params.py | 19 +++++---- lightllm/server/core/objs/req.py | 19 +++++---- lightllm/server/core/objs/sampling_params.py | 19 +++++---- test/test_api/test_r3.py | 19 +++------ 5 files changed, 45 insertions(+), 72 deletions(-) diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py index 01caa36662..77b611130f 100644 --- a/lightllm/common/basemodel/routing_manager.py +++ b/lightllm/common/basemodel/routing_manager.py @@ -7,7 +7,6 @@ from lightllm.utils.dist_utils import get_current_rank_in_dp from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.utils.shm_utils import create_or_link_shm logger = init_logger(__name__) @@ -42,11 +41,11 @@ def __init__( self.dtype = torch.uint8 if num_experts <= 255 else torch.int16 dtype_bytes = 1 if self.dtype == torch.uint8 else 2 - # Shape: (num_moe_layers, kv_cache_size, topk) — on CPU to save GPU memory. + # Shape: (kv_cache_size, num_moe_layers, topk) — on CPU to save GPU memory. # Written after forward() via flush_to_routing_buffer(), read on request finish. routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes self.routing_buffer = torch.zeros( - (num_moe_layers, kv_cache_size, topk), + (kv_cache_size, num_moe_layers, topk), dtype=self.dtype, device="cpu", ) @@ -78,12 +77,11 @@ def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index def flush_to_routing_buffer(self, mem_indexes: torch.Tensor, num_tokens: int, microbatch_index: int = 0) -> None: buf = self._capture_buffer[microbatch_index][:num_tokens] # (num_tokens, num_moe_layers, topk) - buf_t = buf.permute(1, 0, 2).cpu() - self.routing_buffer[:, mem_indexes[:num_tokens].cpu(), :] = buf_t + self.routing_buffer[mem_indexes[:num_tokens].cpu(), :, :] = buf.cpu() def extract_routing_data(self, mem_indexes: torch.Tensor) -> np.ndarray: cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes - return self.routing_buffer[:, cpu_indexes, :].numpy() + return self.routing_buffer[cpu_indexes, :, :].numpy() g_routing_capture_manager: Optional[RoutingCaptureManager] = None @@ -107,27 +105,6 @@ def create_routing_capture_manager( ) -def preallocate_routing_shm_pool(max_req_num: int, num_moe_layers: int, max_tokens: int, topk: int, np_dtype) -> None: - """Pre-allocate POSIX SHM segments for all request slots. - - Each segment is sized for the maximum possible routing data so it can be - reused across requests without create/destroy overhead. - """ - dtype_bytes = np.dtype(np_dtype).itemsize - segment_size = num_moe_layers * max_tokens * topk * dtype_bytes - service_name = get_unique_server_name() - - for i in range(max_req_num): - name = f"{service_name}_shm_routing_{i}" - shm = create_or_link_shm(name, segment_size, auto_cleanup=True) - shm.close() # close handle; SHM persists in /dev/shm - - logger.info( - f"Pre-allocated {max_req_num} routing SHM segments, " - f"each {segment_size / 1024:.1f} KB (total {max_req_num * segment_size / 1024 / 1024:.1f} MB)" - ) - - def cleanup_routing_shm_pool() -> None: """Unlink all pre-allocated routing SHM segments. 
Called at server shutdown.""" try: @@ -197,7 +174,6 @@ def init_routing_capture(model, num_moe_layers: int) -> None: ) mgr = g_routing_capture_manager - np_dtype = mgr.np_dtype dtype_id = mgr.dtype_id max_req_total_len = args.max_req_total_len @@ -212,13 +188,4 @@ def init_routing_capture(model, num_moe_layers: int) -> None: f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}, " f"dtype_id={dtype_id}, max_tokens={max_req_total_len}" ) - - preallocate_routing_shm_pool( - max_req_num=args.running_max_req_size, - num_moe_layers=num_moe_layers, - max_tokens=max_req_total_len, - topk=topk, - np_dtype=np_dtype, - ) - atexit.register(cleanup_routing_shm_pool) diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index 9194a235da..08921317e8 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -112,13 +112,18 @@ def __init__( def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() - cls._do_sample = generation_cfg.get("do_sample", False) - cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) - cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) - cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) - cls._temperature = generation_cfg.get("temperature", 1.0) - cls._top_p = generation_cfg.get("top_p", 1.0) - cls._top_k = generation_cfg.get("top_k", -1) + + def _cfg(key, default): + v = generation_cfg.get(key) + return v if v is not None else default + + cls._do_sample = _cfg("do_sample", False) + cls._presence_penalty = _cfg("presence_penalty", 0.0) + cls._frequency_penalty = _cfg("frequency_penalty", 0.0) + cls._repetition_penalty = _cfg("repetition_penalty", 1.0) + cls._temperature = _cfg("temperature", 1.0) + cls._top_p = _cfg("top_p", 1.0) + cls._top_k = _cfg("top_k", -1) cls._stop_sequences = generation_cfg.get("stop", None) except: pass diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 4a33b659b0..5c7e56843b 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -236,17 +236,20 @@ def link_logprobs_shm_array(self): return def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.int8): - """Link to a pre-allocated routing SHM and create a numpy view for the actual data shape.""" + """Create routing SHM at actual size (on-demand, not pre-allocated). + + Uses smart mode: links if same-sized SHM exists, otherwise creates new. 
+ """ service_uni_name = get_unique_server_name() name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" - shape = (num_moe_layers, num_tokens, topk) + shape = (num_tokens, num_moe_layers, topk) self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) - self.shm_routing_data.link_shm_partial() + self.shm_routing_data.create_shm() self.shm_routing_num_tokens = num_tokens return def link_routing_data_shm_array(self, num_moe_layers: int, topk: int, np_dtype=np.int8): - """Link to the pre-allocated routing SHM from the reader side (HTTP server).""" + """Link to routing SHM from the reader side (HTTP server).""" if num_moe_layers == 0: return num_tokens = self.shm_routing_num_tokens @@ -254,9 +257,9 @@ def link_routing_data_shm_array(self, num_moe_layers: int, topk: int, np_dtype=n return service_uni_name = get_unique_server_name() name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" - shape = (num_moe_layers, num_tokens, topk) + shape = (num_tokens, num_moe_layers, topk) self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) - self.shm_routing_data.link_shm_partial() + self.shm_routing_data.link_shm() return def get_routing_data(self): @@ -265,9 +268,9 @@ def get_routing_data(self): return self.shm_routing_data.arr def close_routing_data_shm_array(self): - """Detach from pre-allocated SHM without unlinking it.""" + """Close and unlink routing SHM (on-demand, no longer pooled).""" if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None: - self.shm_routing_data.detach_shm() + self.shm_routing_data.close_shm() self.shm_routing_data = None self.shm_routing_num_tokens = 0 return diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 1c06428622..31e2fbefed 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -443,13 +443,18 @@ def _get(key, default): def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() - cls._do_sample = generation_cfg.get("do_sample", False) - cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) - cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) - cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) - cls._temperature = generation_cfg.get("temperature", 1.0) - cls._top_p = generation_cfg.get("top_p", 1.0) - cls._top_k = generation_cfg.get("top_k", -1) + + def _cfg(key, default): + v = generation_cfg.get(key) + return v if v is not None else default + + cls._do_sample = _cfg("do_sample", False) + cls._presence_penalty = _cfg("presence_penalty", 0.0) + cls._frequency_penalty = _cfg("frequency_penalty", 0.0) + cls._repetition_penalty = _cfg("repetition_penalty", 1.0) + cls._temperature = _cfg("temperature", 1.0) + cls._top_p = _cfg("top_p", 1.0) + cls._top_k = _cfg("top_k", -1) except: pass diff --git a/test/test_api/test_r3.py b/test/test_api/test_r3.py index 00f34c4893..85c4e44ef9 100644 --- a/test/test_api/test_r3.py +++ b/test/test_api/test_r3.py @@ -16,8 +16,8 @@ def test_routing_export(url: str = "http://localhost:8000"): "inputs": "What is the capital of France? 
What is the capital of France?", "parameters": { "max_new_tokens": 50, - "return_routed_experts": True, - "repetition_penalty": 1.0, + # "return_routed_experts": True, + # "repetition_penalty": 1.0, }, }, timeout=60, @@ -60,17 +60,10 @@ def test_routing_export(url: str = "http://localhost:8000"): print(f"{'=' * 50}") print(f"Shape: {shape}") print(f"Dtype: {dtype}") - print(f"Num MoE layers: {shape[0]}") - print(f"Num tokens: {shape[1]}") + print(f"Num tokens: {shape[0]}") + print(f"Num MoE layers: {shape[1]}") print(f"Top-K: {shape[2]}") - # Verify dtype is int8 (for models with ≤127 experts) or int16 - if dtype_str not in ("int8", "int16"): - print(f"\nERROR: Expected dtype int8 or int16, got {dtype_str}") - print("This suggests dtype optimization is not working correctly.") - return False - print(f"\nDtype check PASSED: {dtype_str} (compact representation)") - # Compute payload size savings int32_size = np.prod(shape) * 4 actual_size = len(data) @@ -78,9 +71,9 @@ def test_routing_export(url: str = "http://localhost:8000"): print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)") print(f"\nSample routing (first layer, first 5 tokens):") - num_tokens_to_show = shape[1] + num_tokens_to_show = shape[0] for i in range(num_tokens_to_show): - print(f" Token {i}: experts {routing_array[0, i, :].tolist()}") + print(f" Token {i}: experts {routing_array[i, 0, :].tolist()}") if np.all(routing_array == 0): print("\nWARNING: All routing data is zeros. Capture may not be working correctly.") From e8ed8b5d9459cfe89ed7328cdca54596e71e859c Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 12 Feb 2026 10:55:49 +0000 Subject: [PATCH 071/180] add-neo++ --- .../common/kv_cache_mem_manager/__init__.py | 5 +- .../common/kv_cache_mem_manager/mem_utils.py | 10 +- .../kv_cache_mem_manager/neo_mem_manager.py | 78 +++--- lightllm/models/llama/model.py | 6 +- .../token_attention_nopad_att1.py | 3 +- .../layer_infer/transformer_layer_infer.py | 247 +++++++----------- .../layer_weights/transformer_layer_weight.py | 51 ++-- .../context_attention_fwd_neo.py | 5 +- lightllm/utils/kv_cache_utils.py | 2 +- 9 files changed, 181 insertions(+), 226 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index bcc1292097..8cd77eca16 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -4,7 +4,8 @@ from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager -from .neo_mem_manager import NeoMemoryManager + +# from .neo_mem_manager import NeoMemoryManager __all__ = [ "MemoryManager", @@ -14,5 +15,5 @@ "PPLINT4KVMemoryManager", "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", - "NeoMemoryManager", + # "NeoMemoryManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 686993dde4..9c7466eb5a 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -5,7 +5,7 @@ PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, - NeoMemoryManager, + # NeoMemoryManager, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args @@ -26,10 +26,10 @@ def select_mem_manager_class(): mem_class = Deepseek2MemoryManager logger.info(f"Model kv cache using 
default, mem_manager class: {mem_class}")
         return mem_class
-    # Check whether the model belongs to the neo family
-    elif issubclass(model_class, NeoTpMOEPartModel) or issubclass(model_class, NeoTpPartModel):
-        mem_class = NeoMemoryManager
-        return mem_class
+    # # Check whether the model belongs to the neo family
+    # elif issubclass(model_class, NeoTpMOEPartModel) or issubclass(model_class, NeoTpPartModel):
+    #     mem_class = NeoMemoryManager
+    #     return mem_class
 
     # case normal
     logger.info(f"mode setting params: {get_env_start_args().llm_kv_type}")
diff --git a/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py b/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py
index 0a79aa072b..1101386f6c 100755
--- a/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py
+++ b/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py
@@ -1,46 +1,46 @@
-import torch
-from lightllm.utils.dist_utils import get_current_rank_in_node
-from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
-from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
+# import torch
+# from lightllm.utils.dist_utils import get_current_rank_in_node
+# from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+# from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
 
 
-class NeoMemoryManager(MemoryManager):
-    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
-        self.size = size
-        self.head_num = head_num
-        self.head_dim = head_dim * 2  # neo kv packs [k, k_h, k_w] together
-        self.layer_num = layer_num
-        self.always_copy = always_copy
-        self.dtype = dtype
-        # profile the max total token num if the size is None
-        self.profile_size(mem_fraction)
+# class NeoMemoryManager(MemoryManager):
+#     def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+#         self.size = size
+#         self.head_num = head_num
+#         self.head_dim = head_dim * 2  # neo kv packs [k, k_h, k_w] together
+#         self.layer_num = layer_num
+#         self.always_copy = always_copy
+#         self.dtype = dtype
+#         # profile the max total token num if the size is None
+#         self.profile_size(mem_fraction)
 
-        self.mem_state = torch.arange(
-            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self._mem_state_return = torch.arange(
-            0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self._return_start = 0
-        self.mark_start = 0
-        self.mark_end = self.size
+#         self.mem_state = torch.arange(
+#             0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+#         )
+#         self._mem_state_return = torch.arange(
+#             0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+#         )
+#         self._return_start = 0
+#         self.mark_start = 0
+#         self.mark_end = self.size
 
-        self.can_use_mem_size = self.size
+#         self.can_use_mem_size = self.size
 
-        # Shared via shared memory so the router module can read it for precise scheduling estimates; the nccl port marks a single instance on one machine to avoid conflicts.
-        from lightllm.utils.envs_utils import get_unique_server_name
+#         # Shared via shared memory so the router module can read it for precise scheduling estimates; the nccl port marks a single instance on one machine to avoid conflicts.
+#         from lightllm.utils.envs_utils import get_unique_server_name
 
-        rank_in_node = get_current_rank_in_node()
-        self.shared_can_use_token_num = SharedInt(
-            f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
-        )
+#         rank_in_node = get_current_rank_in_node()
+#         self.shared_can_use_token_num = SharedInt(
+#             f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
+#         )
 
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-        self._init_buffers(
-            self.size,
-            dtype,
-            head_num,
-            self.head_dim,
-            layer_num,
-        )
-        self.HOLD_TOKEN_MEMINDEX = self.size
+#         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+#         self._init_buffers(
+#             self.size,
+#
dtype, - head_num, - self.head_dim, - layer_num, - ) - self.HOLD_TOKEN_MEMINDEX = self.size +# self.shared_can_use_token_num.set_value(self.can_use_mem_size) +# self._init_buffers( +# self.size, +# dtype, +# head_num, +# self.head_dim, +# layer_num, +# ) +# self.HOLD_TOKEN_MEMINDEX = self.size diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index cc1dc28178..20d6cad743 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -129,9 +129,10 @@ def _init_to_get_rotary(self, default_base=10000): except: pass - inv_freq = 1.0 / ( + full_inv_freq = 1.0 / ( base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) ) + inv_freq = full_inv_freq[::2] # for neo t = ( torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) / rope_scaling_factor @@ -169,9 +170,10 @@ def _init_to_get_hw_rotary(self, default_base=10000): except: pass - inv_freq = 1.0 / ( + full_inv_freq = 1.0 / ( base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) ) + inv_freq = full_inv_freq[::2] t = ( torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) / rope_scaling_factor diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py index 223a64ad51..45de83e989 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py @@ -75,8 +75,7 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk assert Lk in {16, 32, 64, 128, 256} - Lk_scale = Lk // 2 - sm_scale = 1.0 / (Lk_scale ** 0.5) + sm_scale = 1.0 / (Lk ** 0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 1518d68748..6a5259cafd 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -21,8 +21,7 @@ def __init__(self, data_type, network_config): def _bind_attention(self): self._context_attention_kernel = self._context_attention_kernel - self._token_attention_kernel = self._token_decode_attention_normal - self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal + self._token_attention_kernel = self._token_attention_kernel return def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): @@ -34,74 +33,75 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC def _get_qkv_not_mergekv( self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight ): - input = input.view(-1, self.embed_dim_) - q = layer_weight.q_proj.mm(input) # [T, Hq*D] - - q_hw = layer_weight.q_hw_proj.mm(input) - q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) - q_h, q_w = q_hw.chunk(2, dim=-1) - - k_hw = layer_weight.k_hw_proj.mm(input) - k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) - k_h, k_w = k_hw.chunk(2, dim=-1) - - cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - - layer_weight.q_norm_weight_(q, eps=self.eps_) - - q_h_2d = q_h.reshape(q.shape[0], -1) - q_w_2d = q_w.reshape(q.shape[0], -1) - 
layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_) - layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_) - q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - - layer_weight.k_norm_weight_( - cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - eps=self.eps_, - ) - - k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] - k_w_2d = k_w.reshape(q.shape[0], -1) - layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_) - layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_) - k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) - k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) - - cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - rotary_emb_fwd( - q_h, - k_h, - infer_state.position_cos_h, - infer_state.position_sin_h, - ) - rotary_emb_fwd( - q_w, - k_w, - infer_state.position_cos_w, - infer_state.position_sin_w, - ) - - q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) - q3 = torch.cat([q3, q_h, q_w], dim=-1) - q = q3.reshape(q3.shape[0], -1) - - k = cache_kv[:, : self.tp_k_head_num_, :] - k = torch.cat([k, k_h, k_w], dim=-1) - - v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] - v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) - v = torch.cat([v, v_pad], dim=-1) - - cache_kv = torch.cat([k, v], dim=1) - return q, cache_kv + pass + # input = input.view(-1, self.embed_dim_) + # q = layer_weight.q_proj.mm(input) # [T, Hq*D] + + # q_hw = layer_weight.q_hw_proj.mm(input) + # q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) + # q_h, q_w = q_hw.chunk(2, dim=-1) + + # k_hw = layer_weight.k_hw_proj.mm(input) + # k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) + # k_h, k_w = k_hw.chunk(2, dim=-1) + + # cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] + + # layer_weight.q_norm_weight_(q, eps=self.eps_) + + # q_h_2d = q_h.reshape(q.shape[0], -1) + # q_w_2d = q_w.reshape(q.shape[0], -1) + # layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_) + # layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_) + # q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + # q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + + # layer_weight.k_norm_weight_( + # cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + # eps=self.eps_, + # ) + + # k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] + # k_w_2d = k_w.reshape(q.shape[0], -1) + # layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_) + # layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_) + # k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + # k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + + # cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + # rotary_emb_fwd( + # q.view(-1, self.tp_q_head_num_, self.head_dim_), + # cache_kv[:, : self.tp_k_head_num_, :], + # infer_state.position_cos, + # infer_state.position_sin, + # ) + # rotary_emb_fwd( + # q_h, + # k_h, + # infer_state.position_cos_h, + # infer_state.position_sin_h, + # ) + # rotary_emb_fwd( + # q_w, + # k_w, + # infer_state.position_cos_w, + # infer_state.position_sin_w, + # ) + + # q3 = q.view(-1, self.tp_q_head_num_, 
self.head_dim_) + # q3 = torch.cat([q3, q_h, q_w], dim=-1) + # q = q3.reshape(q3.shape[0], -1) + + # k = cache_kv[:, : self.tp_k_head_num_, :] + # k = torch.cat([k, k_h, k_w], dim=-1) + + # v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + # v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) + # v = torch.cat([v, v_pad], dim=-1) + + # cache_kv = torch.cat([k, v], dim=1) + # return q, cache_kv def _get_qkv_mergekv( self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight @@ -112,29 +112,34 @@ def _get_qkv_mergekv( q, cache_kv = qkv.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 ) - q_hw = layer_weight.q_hw_proj.mm(input) - k_hw = layer_weight.k_hw_proj.mm(input) + q = q.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) + q_t, q_hw = q.chunk(2, dim=-1) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + k = cache_kv[:, : self.tp_k_head_num_, :] + v = cache_kv[:, self.tp_k_head_num_ :, :] + k_t, k_hw = k.chunk(2, dim=-1) - layer_weight.q_norm_weight_(q, eps=self.eps_) - layer_weight.q_norm_hw_weight_(q_hw, eps=self.eps_) - layer_weight.k_norm_hw_weight_(k_hw, eps=self.eps_) + q_t_2d = q_t.reshape(q.shape[0], -1) + q_hw_2d = q_hw.reshape(q.shape[0], -1) + layer_weight.q_norm_weight_(q_t_2d, eps=self.eps_) + layer_weight.q_norm_hw_weight_(q_hw_2d, eps=self.eps_) - q_hw = q_hw.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) + q_t = q_t_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_hw = q_hw_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) q_h, q_w = q_hw.chunk(2, dim=-1) - layer_weight.k_norm_weight_( - cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - eps=self.eps_, - ) - - k_hw = k_hw.view(q.shape[0], self.tp_k_head_num_, self.head_dim_) + k_t_2d = k_t.reshape(k.shape[0], -1) + k_hw_2d = k_hw.reshape(k.shape[0], -1) + layer_weight.k_norm_weight_(k_t_2d, eps=self.eps_) + layer_weight.k_norm_hw_weight_(k_hw_2d, eps=self.eps_) + k_t = k_t_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_hw = k_hw_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) k_h, k_w = k_hw.chunk(2, dim=-1) - cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, : self.tp_k_head_num_, :], + q_t, + k_t, infer_state.position_cos, infer_state.position_sin, ) @@ -151,17 +156,10 @@ def _get_qkv_mergekv( infer_state.position_sin_w, ) - q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) - q3 = torch.cat([q3, q_h, q_w], dim=-1) - q = q3.reshape(q3.shape[0], -1) - - k = cache_kv[:, : self.tp_k_head_num_, :] - k = torch.cat([k, k_h, k_w], dim=-1) - - v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] - v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) - v = torch.cat([v, v_pad], dim=-1) + q = torch.cat([q_t, q_h, q_w], dim=-1) + q = q.reshape(q.shape[0], -1) + k = torch.cat([k_t, k_h, k_w], dim=-1) cache_kv = torch.cat([k, v], dim=1) return q, cache_kv @@ -171,10 +169,10 @@ def _context_attention_kernel( o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out kv = infer_state.mem_manager.kv_buffer[self.layer_num_] context_attention_fwd_neo( - q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + q.view(-1, 
self.tp_q_head_num_, self.head_dim_), kv[:, 0 : self.tp_k_head_num_, :], kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] infer_state.b_req_idx, infer_state.b_q_start_loc, @@ -184,9 +182,7 @@ def _context_attention_kernel( infer_state.req_manager.req_to_token_indexs, infer_state.b_image_token_tag, ) - o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) - o3 = o3[:, :, : self.head_dim_].contiguous() - return o3.view(o3.shape[0], -1) + return o_tensor def _token_attention_kernel( self, @@ -195,56 +191,11 @@ def _token_attention_kernel( layer_weight: NeoChatMOETransformerLayerWeight, ) -> torch.Tensor: _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) - _q = q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) att_control = AttControl() - if att_control.scale is None: - att_control.scale = 1.0 / (self.head_dim_ ** 0.5) # att_control.mla_decode_dict["softmax_scale"] = 1.0 / (self.head_dim_ ** 0.5) o_tensor = infer_state.decode_att_state.decode_att( q=_q, k=_k, v=_v, att_control=att_control, alloc_func=self.alloc_tensor ) - o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)[:, :, : self.head_dim_].contiguous() + o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_)[:, :, : self.head_dim_].contiguous() return o_tensor - - # def _token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): - # total_token_num = infer_state.total_token_num - # batch_size = infer_state.batch_size - - # q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) - - # att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) - - # k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - # token_att_fwd( - # q_3d, - # k_3d, - # att_m_tensor, - # infer_state.req_manager.req_to_token_indexs, - # infer_state.b_req_idx, - # infer_state.b_kv_start_loc, - # infer_state.b_seq_len, - # infer_state.max_kv_seq_len, - # ) - - # from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( - # token_attention_softmax_and_reducev, - # ) - - # token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd - - # v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ - # :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ - # ] - - # o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) - - # token_softmax_reducev_fwd( - # att_m_tensor, - # v_3d, - # o_3d, - # infer_state.req_manager.req_to_token_indexs, - # infer_state.b_req_idx, - # infer_state.b_kv_start_loc, - # infer_state.b_seq_len, - # ) - # return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py index 26e986cdd7..83ec33060c 100644 --- a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py @@ -2,6 +2,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( QKRMSNORMWeight, ROWMMWeight, + RMSNormWeight, ) @@ -13,10 +14,6 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): def 
_init_weight_names(self): super()._init_weight_names() - self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight" - self._q_bias_hw_name = None - self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" - self._k_bias_hw_name = None if self._is_merge_kv: self._q_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_hw.weight" @@ -30,54 +27,60 @@ def _init_weight_names(self): def _init_qkv(self): super()._init_qkv() - self.q_hw_proj = ROWMMWeight( - in_dim=self.network_config_["hidden_size"], - out_dims=[self.q_head_num_ * self.head_dim], - weight_names=self._q_weight_hw_name, + + def _init_norm(self): + hidden_size = self.network_config_["hidden_size"] + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, data_type=self.data_type_, - bias_names=self._q_bias_hw_name, - quant_method=self.get_quant_method("q_hw_proj"), ) - self.k_hw_proj = ROWMMWeight( - in_dim=self.network_config_["hidden_size"], - out_dims=[self.k_head_num_ * self.head_dim], - weight_names=self._k_weight_hw_name, + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + + self.q_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_name, + data_type=self.data_type_, + ) + self.k_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_name, data_type=self.data_type_, - bias_names=self._k_bias_hw_name, - quant_method=self.get_quant_method("k_hw_proj"), ) - def _init_norm(self): - super()._init_norm() if self._is_merge_kv: self.q_norm_hw_weight_ = QKRMSNORMWeight( - dim=self.head_dim, + dim=self.head_dim // 2, weight_name=self._q_norm_hw_name, data_type=self.data_type_, ) self.k_norm_hw_weight_ = QKRMSNORMWeight( - dim=self.head_dim, + dim=self.head_dim // 2, weight_name=self._k_norm_hw_name, data_type=self.data_type_, ) else: self.q_norm_h_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, + dim=self.head_dim // 4, weight_name=self._q_norm_h_name, data_type=self.data_type_, ) self.q_norm_w_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, + dim=self.head_dim // 4, weight_name=self._q_norm_w_name, data_type=self.data_type_, ) self.k_norm_h_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, + dim=self.head_dim // 4, weight_name=self._k_norm_h_name, data_type=self.data_type_, ) self.k_norm_w_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, + dim=self.head_dim // 4, weight_name=self._k_norm_w_name, data_type=self.data_type_, ) diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py index 42c3254e27..74ff82cae4 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -158,8 +158,7 @@ def context_attention_fwd_neo( Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128, 256} - base_head_dim = Lq // 2 - sm_scale = 1.0 / (base_head_dim ** 0.5) * 1.4426950408889634 + sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634 batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] @@ -200,7 +199,7 @@ def context_attention_fwd_neo( b_image_token_tag=b_image_token_tag, H=head, QK_HEAD_DIM=Lk, - V_HEAD_DIM=Lk // 2, + V_HEAD_DIM=Lk, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, diff 
--git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index f44aad92ac..686b5569d7 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -24,7 +24,7 @@ PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, - NeoMemoryManager, + # NeoMemoryManager, ) from typing import List, Tuple, Optional From 77b73c214148094d4478d69052908b16cb8c353c Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 19 Feb 2026 14:37:20 +0000 Subject: [PATCH 072/180] feat: add Qwen3Next linear attention model support Implement comprehensive support for Qwen3Next model with linear attention mechanism: Model Features: - Implement linear attention with MTP (Multi-Token Prediction) capability - Add custom Triton kernels for gated delta networks (GDN) operations - Support chunked operations for efficient attention computation - Add specialized buffer pool and memory managers for linear attention Triton Kernels: - Add causal_conv1d for efficient convolution operations - Implement chunk-based operations (chunk_o, chunk_delta_h, chunk_scaled_dot_kkt) - Add gated delta network kernels (fused_gdn_gating, gdn_decode_mtp) - Implement fused normalization (gemma_rmsnorm, gated_rmsnorm) Infrastructure: - Add hybrid radix cache for efficient memory management - Implement mamba cache manager for state management - Add allocator utilities for buffer management - Add parameter weight abstraction for flexible weight handling - Update model registration and API endpoints Performance Optimizations: - Add H200 autotune configurations for all Triton kernels - Optimize memory allocation with custom kernels - Support chunked prefill and decode backends This implementation enables efficient inference for models with linear attention mechanisms, providing significant speedup for long sequence lengths. 
--- lightllm/common/allocator_utils.py | 98 ++ lightllm/common/basemodel/basemodel.py | 6 + .../transformer_layer_infer_template.py | 49 +- .../basemodel/layer_weights/hf_load_utils.py | 10 +- .../layer_weights/meta_weights/__init__.py | 1 + .../meta_weights/parameter_weight.py | 83 + .../triton_kernel/alloc_buffer_kernel.py | 80 + .../triton_kernel/mamba_buffer_copy.py | 961 ++++++++++++ .../kv_cache_mem_manager/mem_manager.py | 108 +- .../mamba_cache_mem_manager/cache_manager.py | 188 +++ lightllm/common/req_manager.py | 46 + .../{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json | 14 + .../{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json | 14 + .../{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json | 14 + .../{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json | 7 + .../{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json | 12 + ...ARLEN=true,REVERSE=false}_NVIDIA_H200.json | 38 + ...ARLEN=true,REVERSE=false}_NVIDIA_H200.json | 38 + ...=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json | 7 + ...H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json | 12 + ...6,a_dtype=torch.bfloat16}_NVIDIA_H200.json | 50 + ...8,a_dtype=torch.bfloat16}_NVIDIA_H200.json | 50 + ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 70 + ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 + ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 + ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 + ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ ...um=10,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ ...um=10,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ ...um=10,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ .../{topk_num=10}_NVIDIA_H200.json | 50 + ...rch.bfloat16,topk_num=10}_NVIDIA_H200.json | 74 + ...rch.bfloat16,topk_num=10}_NVIDIA_H200.json | 74 + ...M=4,dtype=torch.bfloat16}_NVIDIA_H200.json | 50 + ...M=4,dtype=torch.bfloat16}_NVIDIA_H200.json | 50 + ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 74 + ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 74 + lightllm/models/__init__.py | 6 + lightllm/models/qwen3next/__init__.py | 3 + lightllm/models/qwen3next/buffer_pool.py | 83 + lightllm/models/qwen3next/infer_struct.py | 62 + .../qwen3next/layer_infer/post_layer_infer.py | 12 + .../layer_infer/shared_expert_mixin.py | 101 ++ .../layer_infer/transformer_layer_infer.py | 1067 +++++++++++++ .../layer_weights/transformer_layer_weight.py | 313 ++++ lightllm/models/qwen3next/mem_manager.py | 72 + lightllm/models/qwen3next/model.py | 157 ++ .../qwen3next/triton_kernel/causal_conv1d.py | 122 ++ .../qwen3next/triton_kernel/fla/__init__.py | 11 + .../triton_kernel/fla/ops/__init__.py | 15 + .../qwen3next/triton_kernel/fla/ops/chunk.py | 224 +++ .../triton_kernel/fla/ops/chunk_delta_h.py | 324 ++++ .../triton_kernel/fla/ops/chunk_o.py | 205 +++ .../fla/ops/chunk_scaled_dot_kkt.py | 180 +++ .../qwen3next/triton_kernel/fla/ops/cumsum.py | 306 ++++ .../triton_kernel/fla/ops/fused_recurrent.py | 492 ++++++ .../qwen3next/triton_kernel/fla/ops/index.py | 30 + .../qwen3next/triton_kernel/fla/ops/l2norm.py | 173 +++ .../qwen3next/triton_kernel/fla/ops/op.py | 65 + .../triton_kernel/fla/ops/solve_tril.py | 462 ++++++ .../qwen3next/triton_kernel/fla/ops/utils.py | 179 +++ .../triton_kernel/fla/ops/wy_fast.py | 145 ++ .../triton_kernel/fused_add_gemma_rmsnorm.py | 186 +++ .../triton_kernel/fused_gdn_gating.py | 87 ++ 
.../triton_kernel/fused_qkv_gating.py | 163 ++ .../triton_kernel/fused_split_copy.py | 400 +++++ .../qwen3next/triton_kernel/gated_rmsnorm.py | 174 +++ .../qwen3next/triton_kernel/gdn_decode_mtp.py | 1333 +++++++++++++++++ .../qwen3next/triton_kernel/gemma_rmsnorm.py | 141 ++ lightllm/models/qwen3next_mtp/__init__.py | 3 + .../qwen3next_mtp/layer_infer/__init__.py | 0 .../layer_infer/post_layer_infer.py | 16 + .../layer_infer/pre_layer_infer.py | 68 + .../layer_infer/transformer_layer_infer.py | 30 + .../qwen3next_mtp/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 47 + .../layer_weights/transformer_layer_weight.py | 141 ++ lightllm/models/qwen3next_mtp/model.py | 101 ++ lightllm/server/api_cli.py | 20 +- lightllm/server/api_openai.py | 20 +- lightllm/server/api_start.py | 3 +- lightllm/server/core/objs/start_args_type.py | 24 +- .../dynamic_prompt/hybrid_radix_cache.py | 206 +++ lightllm/server/tokenizer.py | 8 + lightllm/utils/config_utils.py | 16 + lightllm/utils/envs_utils.py | 2 +- 91 files changed, 10981 insertions(+), 124 deletions(-) create mode 100644 lightllm/common/allocator_utils.py create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py create mode 100644 lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py create mode 100644 lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py create mode 100644 lightllm/common/mamba_cache_mem_manager/cache_manager.py create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/models/qwen3next/__init__.py create mode 100644 lightllm/models/qwen3next/buffer_pool.py create mode 100644 lightllm/models/qwen3next/infer_struct.py create mode 100644 lightllm/models/qwen3next/layer_infer/post_layer_infer.py create mode 100644 lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py create mode 100644 lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/qwen3next/mem_manager.py create mode 100644 lightllm/models/qwen3next/model.py create mode 100644 lightllm/models/qwen3next/triton_kernel/causal_conv1d.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/__init__.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/index.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/op.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fused_split_copy.py create mode 100644 lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py create mode 100644 lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py create mode 100644 lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py create mode 100644 lightllm/models/qwen3next_mtp/__init__.py create mode 100644 lightllm/models/qwen3next_mtp/layer_infer/__init__.py create mode 100644 lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py create mode 100644 lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py create mode 100644 lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/qwen3next_mtp/layer_weights/__init__.py create mode 100644 lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/qwen3next_mtp/model.py create mode 100644 lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py diff --git a/lightllm/common/allocator_utils.py b/lightllm/common/allocator_utils.py new file mode 100644 
index 0000000000..803ed0a715
--- /dev/null
+++ b/lightllm/common/allocator_utils.py
@@ -0,0 +1,98 @@
+from typing import List, Union
+
+import torch
+
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class TokenAllocator:
+    def __init__(self, size, shared_can_use_token_num_name: str):
+        self.size = size
+
+        self.mem_state = torch.arange(
+            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self._mem_state_return = torch.arange(
+            0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self._return_start = 0
+        self.mark_start = 0
+        self.mark_end = self.size
+
+        self.can_use_mem_size = self.size
+
+        # Shared via shared memory so the router module can read it for precise scheduling estimates; the nccl port marks a single instance on one machine to avoid conflicts.
+        self.shared_can_use_token_num = SharedInt(shared_can_use_token_num_name)
+
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.HOLD_TOKEN_MEMINDEX = self.size
+
+    def alloc(self, need_size) -> torch.Tensor:
+        if need_size > self.mark_end - self.mark_start:
+            logger.error(f"not enough cache: need_size {need_size} left_size {self.can_use_mem_size}")
+            assert False, "error alloc state"
+
+        start = self.mark_start
+        end = self.mark_start + need_size
+        self.mark_start += need_size
+
+        self.can_use_mem_size -= need_size
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        # Return through a staging buffer to avoid memory races under asynchronous execution
+        if self._return_start + need_size > self._mem_state_return.shape[0]:
+            self._return_start = 0
+        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
+        ans.copy_(self.mem_state[start:end])
+        self._return_start += need_size
+        return ans
+
+    def free(self, free_index: Union[torch.Tensor, List[int]]):
+        """Return previously allocated token indices to the allocator.
+
+        Args:
+            free_index (Union[torch.Tensor, List[int]]): the token indices to free.
+        """
+        end = self.mark_start
+        start = self.mark_start - len(free_index)
+        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
+
+        if isinstance(free_index, list):
+            free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device)
+            self.mem_state[start:end] = free_index_tensor
+        else:
+            # the copy from GPU to CPU is a stream-blocking operation
+            self.mem_state[start:end] = free_index
+
+        self.mark_start -= len(free_index)
+
+        self.can_use_mem_size += len(free_index)
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        if self.can_use_mem_size == len(self.mem_state):
+            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
+        return
+
+    def free_all(self):
+        self.can_use_mem_size = len(self.mem_state)
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
+        self.mark_start = 0
+        self.mark_end = len(self.mem_state)
+
+    def resize_mem(self, new_size):
+        """
+        just for test code
+        """
+        self.size = new_size
+        self.mem_state = torch.arange(
+            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self.mark_start = 0
+        self.mark_end = self.size
+        self.can_use_mem_size = self.size
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        return
diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 5c1d2b8712..caa90462cc 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -53,6 +53,12 @@ class TpPartBaseModel:
     # infer state class
     infer_state_class =
InferStateInfo + @classmethod + def get_radix_cache_class(cls): + from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache + + return RadixCache + def __init__(self, kvargs): self.args = get_env_start_args() self.run_mode = kvargs["run_mode"] diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 9153349c5d..646f998642 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -62,20 +62,21 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: raise Exception("need to impl") - def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) - input1 = None + def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_wrapper_run( q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight ) - q = None o = self._get_o(o, infer_state, layer_weight) if self.tp_world_size_ > 1: all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.context_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -87,39 +88,42 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) - input1 = None + def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) if self.tp_world_size_ > 1: all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.token_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) - input1 = None if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - input1 = 
self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight) - input1 = None + def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_wrapper_run( q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight ) - q = None o = self._tpsp_get_o(o, infer_state, layer_weight) + return o + + def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -129,14 +133,17 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight) - input1 = None + def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._tpsp_get_o(o, infer_state, layer_weight) + return o + + def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.tpsp_token_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 8cf66a5ad6..304b04ab44 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -18,6 +18,14 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay weights = {k: weights.get_tensor(k) for k in weights.keys()} else: weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") + new_weight = {} + for k, v in weights.items(): + if "language_model." 
in k: + new_weight[k[len("language_model.") :]] = v + else: + new_weight[k] = v + del weights + weights = new_weight if pre_post_layer is not None: pre_post_layer.load_hf_weights(weights) @@ -60,7 +68,7 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye transformer_layer_list=transformer_layer_list, weight_dir=weight_dir, ) # noqa - worker = int(os.environ.get("LOADWORKER", 1)) + worker = int(os.environ.get("LOADWORKER", 18)) with Pool(worker) as p: iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index edf7fe21b9..fe77ca669c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -11,3 +11,4 @@ from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight from .fused_moe.fused_moe_weight import FusedMoeWeight +from .parameter_weight import ParameterWeight, TpParameterWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py new file mode 100644 index 0000000000..0afb0ecab2 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py @@ -0,0 +1,83 @@ +import torch +from typing import Dict, Optional, Tuple +from .base_weight import BaseWeightTpl + + +class ParameterWeight(BaseWeightTpl): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + weight_shape: Optional[Tuple[int, ...]], + bias_name: Optional[str] = None, + bias_shape: Optional[Tuple[int, ...]] = None, + ): + super().__init__() + self.weight_name = weight_name + self.bias_name = bias_name + self.data_type_ = data_type + self.weight_shape = weight_shape + self.bias_shape = bias_shape + self.weight: Optional[torch.Tensor] = None + self.bias: Optional[torch.Tensor] = None + if weight_shape is not None: + self._create_weight() + + def _create_weight(self): + if self.weight_shape is not None: + self.weight = torch.empty(*self.weight_shape, dtype=self.data_type_, device=self.device_id_) + self.weight.load_ok = False + if self.bias_name is not None and self.bias_shape is not None: + self.bias = torch.empty(*self.bias_shape, dtype=self.data_type_, device=self.device_id_) + self.bias.load_ok = False + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + if self.weight_name in weights: + t_weight = weights[self.weight_name] + self.weight.copy_(t_weight.to(self.data_type_)) + self.weight.load_ok = True + if self.bias_name is not None and self.bias_name in weights: + t_bias = weights[self.bias_name] + self.bias.copy_(t_bias.to(self.data_type_)) + self.bias.load_ok = True + + def verify_load(self) -> bool: + if self.weight is not None and not getattr(self.weight, "load_ok", False): + return False + if self.bias is not None and not getattr(self.bias, "load_ok", False): + return False + return True + + +class TpParameterWeight(ParameterWeight): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_shape: Optional[Tuple[int, ...]] = None, + bias_shape: Optional[Tuple[int, ...]] = None, + ): + self.split_n_embed = split_n_embed + # Calculate TP-split shapes 
if full shapes are provided + tp_weight_shape = None + tp_bias_shape = None + if weight_shape is not None: + tp_weight_shape = (split_n_embed,) + weight_shape[1:] + if bias_shape is not None: + tp_bias_shape = (split_n_embed,) + bias_shape[1:] + super().__init__(weight_name, data_type, tp_weight_shape, bias_name, tp_bias_shape) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + start = self.split_n_embed * self.tp_rank_ + end = self.split_n_embed * (self.tp_rank_ + 1) + + if self.weight_name in weights: + t_weight = weights[self.weight_name][start:end] + self.weight.copy_(t_weight.to(self.data_type_)) + self.weight.load_ok = True + if self.bias_name is not None and self.bias_name in weights: + t_bias = weights[self.bias_name][start:end] + self.bias.copy_(t_bias.to(self.data_type_)) + self.bias.load_ok = True diff --git a/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py b/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py new file mode 100644 index 0000000000..b6444449b1 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py @@ -0,0 +1,80 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def alloc_buffer_for_req_kernel( + req_index_ptr, # [num_reqs] - indices of requests to allocate buffers for + buffer_indexes_ptr, # [num_reqs * num_buffers_per_req] - buffer indices to assign (from CPU) + req_to_buffer_index_ptr, # [max_request_num + 1, num_buffers_per_req] - tensor mapping req_idx to buffer_idx + num_reqs, # number of requests to process + stride_buffer, # stride for req_to_buffer_index second dimension + NUM_BUFFERS_PER_REQ: tl.constexpr, # number of buffers per request (mtp_step + 1) + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Mask for valid indices + mask = offsets < num_reqs + + # Load request indices + req_indices = tl.load(req_index_ptr + offsets, mask=mask, other=0) + + # For each request, allocate NUM_BUFFERS_PER_REQ buffers + for buf_idx in tl.static_range(NUM_BUFFERS_PER_REQ): + # Load buffer index for this position + buffer_offset = offsets * NUM_BUFFERS_PER_REQ + buf_idx + buffer_indices = tl.load(buffer_indexes_ptr + buffer_offset, mask=mask, other=0) + + # Update req_to_buffer_index[req_indices, buf_idx] = buffer_indices + output_offset = req_indices * stride_buffer + buf_idx + tl.store(req_to_buffer_index_ptr + output_offset, buffer_indices, mask=mask) + + +def alloc_buffer_for_req_triton( + req_index: torch.Tensor, # [num_reqs] int32/int64 tensor on CUDA + buffer_indexes: torch.Tensor, # [num_reqs * (mtp_step + 1)] int32 tensor (can be CPU or CUDA) + req_to_buffer_index: torch.Tensor, # [max_request_num + 1, mtp_step + 1] int32 tensor on CUDA + mtp_step: int = 0, # number of additional buffers per request (default 0 for non-MTP mode) +): + num_reqs = req_index.shape[0] + num_buffers_per_req = mtp_step + 1 + + # Ensure inputs are on CUDA + if not req_index.is_cuda: + req_index = req_index.cuda() + if not buffer_indexes.is_cuda: + buffer_indexes = buffer_indexes.cuda() + + # Ensure correct dtypes + if req_index.dtype not in [torch.int32, torch.int64]: + req_index = req_index.to(torch.int32) + if buffer_indexes.dtype != torch.int32: + buffer_indexes = buffer_indexes.to(torch.int32) + + # Validate buffer_indexes size + expected_size = num_reqs * num_buffers_per_req + assert buffer_indexes.shape[0] == expected_size, ( + f"Expected {expected_size} buffer indices for 
{num_reqs} requests " + f"with mtp_step={mtp_step}, but got {buffer_indexes.shape[0]}" + ) + + # Get stride for the second dimension of req_to_buffer_index + stride_buffer = req_to_buffer_index.stride(0) + + # Launch kernel + BLOCK_SIZE = 256 + grid = (triton.cdiv(num_reqs, BLOCK_SIZE),) + + alloc_buffer_for_req_kernel[grid]( + req_index, + buffer_indexes, + req_to_buffer_index, + num_reqs, + stride_buffer, + NUM_BUFFERS_PER_REQ=num_buffers_per_req, + BLOCK_SIZE=BLOCK_SIZE, + ) diff --git a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py new file mode 100644 index 0000000000..b4a91f7861 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py @@ -0,0 +1,961 @@ +""" +Optimized Mamba Buffer Copy Kernels with Autotune Support + +This module provides auto-tuned Triton kernels for efficient buffer copying operations +in Mamba-style models, including support for MTP (Multi-Token Prediction) buffer broadcasting. +""" + +import torch +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _copy_buffer_p2p_1d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + pair_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d, + d_size, + BLOCK_D: tl.constexpr, +): + """ + Optimized kernel for 1D buffer copy. + + Grid: (num_pairs, layer_num, num_blocks_d) + Each program copies one block of dimension d for one (pair, layer) combination. + """ + pair_idx = tl.program_id(0) + pair_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_d_idx = tl.program_id(2) + + # Load source and destination indices for this pair + src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) + dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) + + # Calculate offsets for this block + d_start = block_d_idx * BLOCK_D + d_offsets = d_start + tl.arange(0, BLOCK_D) + + # Create mask for valid indices + mask = d_offsets < d_size + + # Calculate source and destination pointers for this layer and pair + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + + src_ptr = base_src + d_offsets * stride_d + dst_ptr = base_dst + d_offsets * stride_d + + # Load and store + data = tl.load(src_ptr, mask=mask, other=0.0) + tl.store(dst_ptr, data, mask=mask) + + +@triton.jit +def _copy_buffer_p2p_2d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + pair_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + d1_size, + d2_size, + num_blocks_d2, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, +): + """ + Kernel to copy 2D buffer from source indices to destination indices. + + Grid: (num_pairs, layer_num, num_blocks_d1 * num_blocks_d2) + Each program copies one 2D block for one (pair, layer) combination. 
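+
+    Reference semantics (an equivalent PyTorch sketch, assuming a contiguous
+    4D buffer and no duplicated destination indices; the kernel performs the
+    same copy without materializing the gather tensor advanced indexing needs):
+
+        # dst_buffer[:, dst_indexes[i]] <- src_buffer[:, src_indexes[i]] for each pair i
+        dst_buffer[:, dst_indexes.long()] = src_buffer[:, src_indexes.long()]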
+ """ + pair_idx = tl.program_id(0) + pair_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx into d1 and d2 block indices + block_d1_idx = block_idx // num_blocks_d2 + block_d2_idx = block_idx % num_blocks_d2 + + # Load source and destination indices + src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) + dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) + + # Calculate offsets for this block + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + + # Create mask for valid indices + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + mask = d1_mask[:, None] & d2_mask[None, :] + + # Calculate base pointers + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + + # Calculate full offsets + offsets = d1_offsets[:, None] * stride_d1 + d2_offsets[None, :] * stride_d2 + + # Load and store + data = tl.load(base_src + offsets, mask=mask, other=0.0) + tl.store(base_dst + offsets, data, mask=mask) + + +@triton.jit +def _copy_buffer_broadcast_1d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + copy_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d, + d_size, + num_dst_per_src, + BLOCK_D: tl.constexpr, +): + """ + Broadcast kernel for 1D buffer copy (one source to multiple destinations). + + Grid: (num_src, layer_num, num_blocks_d) + """ + src_idx_in_batch = tl.program_id(0) + copy_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_d_idx = tl.program_id(2) + + # Load source index + src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) + + # Calculate offsets for this block + d_start = block_d_idx * BLOCK_D + d_offsets = d_start + tl.arange(0, BLOCK_D) + mask = d_offsets < d_size + + # Calculate source pointer + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + src_ptr = base_src + d_offsets * stride_d + + # Load data once + data = tl.load(src_ptr, mask=mask, other=0.0) + + # Broadcast to all destinations for this source + for dst_offset in range(num_dst_per_src): + dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset + dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) + + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + dst_ptr = base_dst + d_offsets * stride_d + + tl.store(dst_ptr, data, mask=mask) + + +@triton.jit +def _copy_buffer_broadcast_2d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + copy_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + d1_size, + d2_size, + num_blocks_d2, + num_dst_per_src, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, +): + """ + Broadcast kernel for 2D buffer copy (one source to multiple destinations). 
+ + Grid: (num_src, layer_num, num_blocks_d1 * num_blocks_d2) + """ + src_idx_in_batch = tl.program_id(0) + copy_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx + block_d1_idx = block_idx // num_blocks_d2 + block_d2_idx = block_idx % num_blocks_d2 + + # Load source index + src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) + + # Calculate offsets + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + mask = d1_mask[:, None] & d2_mask[None, :] + + # Calculate source pointer and load data once + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + offsets = d1_offsets[:, None] * stride_d1 + d2_offsets[None, :] * stride_d2 + data = tl.load(base_src + offsets, mask=mask, other=0.0) + + # Broadcast to all destinations + for dst_offset in range(num_dst_per_src): + dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset + dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) + + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + tl.store(base_dst + offsets, data, mask=mask) + + +@triton.jit +def _copy_buffer_p2p_3d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + pair_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + stride_d3, + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, + BLOCK_D3: tl.constexpr, +): + """ + Optimized kernel for 3D data buffer copy (5D tensor: layer, buffer, d1, d2, d3). + + Grid: (num_pairs, layer_num, num_blocks_d1 * num_blocks_d2 * num_blocks_d3) + Each program copies one 3D block for one (pair, layer) combination. 
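+
+    Illustrative block-index decomposition (mirrors the arithmetic in the
+    kernel body; the numbers are made up for the example):
+
+        # num_blocks_d2 = 3, num_blocks_d3 = 4, block_idx = 10
+        # block_d1_idx = 10 // (3 * 4)   -> 0
+        # block_d2_idx = (10 % 12) // 4  -> 2
+        # block_d3_idx = (10 % 12) % 4   -> 2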
+ """ + pair_idx = tl.program_id(0) + pair_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx into d1, d2, d3 block indices + block_d1_idx = block_idx // (num_blocks_d2 * num_blocks_d3) + temp = block_idx % (num_blocks_d2 * num_blocks_d3) + block_d2_idx = temp // num_blocks_d3 + block_d3_idx = temp % num_blocks_d3 + + # Load source and destination indices for this pair + src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) + dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) + + # Calculate offsets for this block + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + d3_start = block_d3_idx * BLOCK_D3 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + d3_offsets = d3_start + tl.arange(0, BLOCK_D3) + + # Create masks for valid indices + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + d3_mask = d3_offsets < d3_size + + # 3D mask: [BLOCK_D1, BLOCK_D2, BLOCK_D3] + mask = d1_mask[:, None, None] & d2_mask[None, :, None] & d3_mask[None, None, :] + + # Calculate base pointers + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + + # Calculate full 3D offsets + offsets = ( + d1_offsets[:, None, None] * stride_d1 + + d2_offsets[None, :, None] * stride_d2 + + d3_offsets[None, None, :] * stride_d3 + ) + + # Load and store + data = tl.load(base_src + offsets, mask=mask, other=0.0) + tl.store(base_dst + offsets, data, mask=mask) + + +@triton.jit +def _copy_buffer_broadcast_3d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + copy_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + stride_d3, + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + num_dst_per_src, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, + BLOCK_D3: tl.constexpr, +): + """ + Broadcast kernel for 3D data buffer copy (5D tensor: layer, buffer, d1, d2, d3). + + Grid: (num_src, layer_num, num_blocks_d1 * num_blocks_d2 * num_blocks_d3) + Each program loads once from source and broadcasts to all destinations. 
+ """ + src_idx_in_batch = tl.program_id(0) + copy_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx into d1, d2, d3 block indices + block_d1_idx = block_idx // (num_blocks_d2 * num_blocks_d3) + temp = block_idx % (num_blocks_d2 * num_blocks_d3) + block_d2_idx = temp // num_blocks_d3 + block_d3_idx = temp % num_blocks_d3 + + # Load source index + src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) + + # Calculate offsets for this block + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + d3_start = block_d3_idx * BLOCK_D3 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + d3_offsets = d3_start + tl.arange(0, BLOCK_D3) + + # Create masks + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + d3_mask = d3_offsets < d3_size + + mask = d1_mask[:, None, None] & d2_mask[None, :, None] & d3_mask[None, None, :] + + # Calculate source pointer and load data once + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + + offsets = ( + d1_offsets[:, None, None] * stride_d1 + + d2_offsets[None, :, None] * stride_d2 + + d3_offsets[None, None, :] * stride_d3 + ) + + data = tl.load(base_src + offsets, mask=mask, other=0.0) + + # Broadcast to all destinations for this source + for dst_offset in range(num_dst_per_src): + dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset + dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) + + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + tl.store(base_dst + offsets, data, mask=mask) + + +# ==================== Config Generation Functions ==================== + + +def _get_buffer_copy_1d_configs(): + """Generate candidate configurations for 1D buffer copy.""" + configs = [] + for block_d in [32, 64, 128, 256, 512, 1024]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3, 4]: + configs.append( + { + "BLOCK_D": block_d, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_buffer_copy_2d_configs(): + """Generate candidate configurations for 2D buffer copy.""" + configs = [] + for block_d1 in [16, 32, 64, 128]: + for block_d2 in [16, 32, 64, 128, 256]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3, 4]: + configs.append( + { + "BLOCK_D1": block_d1, + "BLOCK_D2": block_d2, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_buffer_copy_3d_configs(): + """Generate candidate configurations for 3D buffer copy (5D tensor).""" + configs = [] + for block_d1 in [8, 16, 32]: + for block_d2 in [8, 16, 32, 64]: + for block_d3 in [8, 16, 32, 64, 128]: + for num_warps in [4, 8]: + for num_stages in [2, 3]: + # Skip configs that are too large for shared memory + if block_d1 * block_d2 * block_d3 > 32768: + continue + configs.append( + { + "BLOCK_D1": block_d1, + "BLOCK_D2": block_d2, + "BLOCK_D3": block_d3, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +# ==================== Static and Run Key Functions ==================== + + +def _get_buffer_copy_static_key(src_buffer: torch.Tensor): + """Static key based on buffer shape and dtype.""" + shape = src_buffer.shape + return { + "ndim": len(shape), + "layer_num": shape[0], + "d_sizes": str(shape[2:]), # Dimension sizes + "dtype": str(src_buffer.dtype), + } + + +def _get_buffer_copy_run_key(src_indexes: torch.Tensor): + """Run key based on number of copy 
pairs.""" + return src_indexes.shape[0] + + +# ==================== Auto-tuned Buffer Copy Functions ==================== + + +@autotune( + kernel_name="mamba_buffer_copy_p2p_1d:v1", + configs_gen_func=_get_buffer_copy_1d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_p2p_1d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 1D buffer copy.""" + num_pairs = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d_size = src_buffer.shape[2] + + if run_config is None: + # Default config if autotune is disabled + BLOCK_D = triton.next_power_of_2(min(d_size, 256)) + num_warps = 4 if BLOCK_D > 256 else 2 + num_stages = 2 + else: + BLOCK_D = run_config["BLOCK_D"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d = triton.cdiv(d_size, BLOCK_D) + + MAX_GRID_SIZE = 65535 + + for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): + pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) + pair_chunk_size = pair_chunk_end - pair_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (pair_chunk_size, layer_chunk_size, num_blocks_d) + + _copy_buffer_p2p_1d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + pair_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + d_size, + BLOCK_D=BLOCK_D, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_copy_p2p_2d:v1", + configs_gen_func=_get_buffer_copy_2d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_p2p_2d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 2D buffer copy.""" + num_pairs = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + + if run_config is None: + # Default config if autotune is disabled + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 64)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 128)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + BLOCK_D2 = run_config["BLOCK_D2"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_total = num_blocks_d1 * num_blocks_d2 + + MAX_GRID_SIZE = 65535 + + for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): + pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) + pair_chunk_size = pair_chunk_end - pair_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (pair_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_p2p_2d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + pair_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + 
src_buffer.stride(3), + d1_size, + d2_size, + num_blocks_d2, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_broadcast_1d:v1", + configs_gen_func=_get_buffer_copy_1d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_broadcast_1d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 1D buffer broadcast (one src to multiple dst).""" + num_src = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d_size = src_buffer.shape[2] + num_dst_per_src = dst_indexes.shape[0] // num_src + + if run_config is None: + BLOCK_D = triton.next_power_of_2(min(d_size, 256)) + num_warps = 4 if BLOCK_D > 256 else 2 + num_stages = 2 + else: + BLOCK_D = run_config["BLOCK_D"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d = triton.cdiv(d_size, BLOCK_D) + + MAX_GRID_SIZE = 65535 + + for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): + src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) + src_chunk_size = src_chunk_end - src_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (src_chunk_size, layer_chunk_size, num_blocks_d) + + _copy_buffer_broadcast_1d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + src_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + d_size, + num_dst_per_src, + BLOCK_D=BLOCK_D, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_broadcast_2d:v1", + configs_gen_func=_get_buffer_copy_2d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_broadcast_2d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 2D buffer broadcast (one src to multiple dst).""" + num_src = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + num_dst_per_src = dst_indexes.shape[0] // num_src + + if run_config is None: + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 64)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 128)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + BLOCK_D2 = run_config["BLOCK_D2"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_total = num_blocks_d1 * num_blocks_d2 + + MAX_GRID_SIZE = 65535 + + for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): + src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) + src_chunk_size = src_chunk_end - src_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (src_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_broadcast_2d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, 
+ dst_indexes, + src_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + src_buffer.stride(3), + d1_size, + d2_size, + num_blocks_d2, + num_dst_per_src, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_copy_p2p_3d:v1", + configs_gen_func=_get_buffer_copy_3d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_p2p_3d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 3D data buffer copy (5D tensor).""" + num_pairs = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + d3_size = src_buffer.shape[4] + + if run_config is None: + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 16)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 32)) + BLOCK_D3 = triton.next_power_of_2(min(d3_size, 64)) + num_warps = 4 if BLOCK_D1 * BLOCK_D2 * BLOCK_D3 > 4096 else 8 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + BLOCK_D2 = run_config["BLOCK_D2"] + BLOCK_D3 = run_config["BLOCK_D3"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_d3 = triton.cdiv(d3_size, BLOCK_D3) + num_blocks_total = num_blocks_d1 * num_blocks_d2 * num_blocks_d3 + + MAX_GRID_SIZE = 65535 + + for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): + pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) + pair_chunk_size = pair_chunk_end - pair_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (pair_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_p2p_3d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + pair_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + src_buffer.stride(3), + src_buffer.stride(4), + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + BLOCK_D3=BLOCK_D3, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_broadcast_3d:v1", + configs_gen_func=_get_buffer_copy_3d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_broadcast_3d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 3D data buffer broadcast (5D tensor, one src to multiple dst).""" + num_src = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + d3_size = src_buffer.shape[4] + num_dst_per_src = dst_indexes.shape[0] // num_src + + if run_config is None: + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 16)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 32)) + BLOCK_D3 = triton.next_power_of_2(min(d3_size, 64)) + num_warps = 4 if BLOCK_D1 * BLOCK_D2 * BLOCK_D3 > 4096 else 8 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + 
BLOCK_D2 = run_config["BLOCK_D2"] + BLOCK_D3 = run_config["BLOCK_D3"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_d3 = triton.cdiv(d3_size, BLOCK_D3) + num_blocks_total = num_blocks_d1 * num_blocks_d2 * num_blocks_d3 + + MAX_GRID_SIZE = 65535 + + for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): + src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) + src_chunk_size = src_chunk_end - src_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (src_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_broadcast_3d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + src_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + src_buffer.stride(3), + src_buffer.stride(4), + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + num_dst_per_src, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + BLOCK_D3=BLOCK_D3, + num_warps=num_warps, + num_stages=num_stages, + ) + + +# ==================== Unified Interface ==================== + + +def copy_buffer_p2p( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, +): + """ + Copy buffers from source indices to destination indices with auto-tuning. + + Supports 3D (conv states), 4D (standard buffers), and 5D (SSM states) buffers. + + Args: + src_buffer: Source buffer tensor [layer_num, buffer_size, ...] + dst_buffer: Destination buffer tensor [layer_num, buffer_size, ...] + src_indexes: Source buffer indices [num_pairs] + dst_indexes: Destination buffer indices [num_pairs] + """ + assert src_buffer.shape == dst_buffer.shape + assert src_indexes.shape == dst_indexes.shape + assert len(src_indexes.shape) == 1 + + if len(src_buffer.shape) == 3: + # 1D case: (layer_num, buffer_size, d) + _copy_buffer_p2p_1d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) + + elif len(src_buffer.shape) == 4: + # 2D case: (layer_num, buffer_size, d1, d2) + _copy_buffer_p2p_2d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) + + elif len(src_buffer.shape) == 5: + # 5D case: (layer_num, buffer_size, d1, d2, d3) - Use Triton kernel for zero extra memory + _copy_buffer_p2p_3d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) + + else: + raise ValueError(f"Unsupported buffer shape: {src_buffer.shape}") + + +def copy_buffer_broadcast( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, +): + """ + Broadcast buffers from source indices to multiple destination indices (MTP use case). + + Each source buffer is copied to multiple destination buffers. + + Args: + src_buffer: Source buffer tensor [layer_num, buffer_size, ...] + dst_buffer: Destination buffer tensor [layer_num, buffer_size, ...] 
+ src_indexes: Source buffer indices [num_src] + dst_indexes: Destination buffer indices [num_src, num_dst_per_src] (2D tensor) + """ + assert src_buffer.shape == dst_buffer.shape + assert len(src_indexes.shape) == 1 + assert len(dst_indexes.shape) == 2, f"dst_indexes must be 2D, got shape {dst_indexes.shape}" + + num_src = src_indexes.shape[0] + + assert num_src == dst_indexes.shape[0], f"Mismatch: src_indexes {num_src} vs dst_indexes {dst_indexes.shape[0]}" + + # Flatten dst_indexes for kernel + dst_indexes_flat = dst_indexes.reshape(-1).contiguous() + + if len(src_buffer.shape) == 3: + # 1D case + _copy_buffer_broadcast_1d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) + + elif len(src_buffer.shape) == 4: + # 2D case + _copy_buffer_broadcast_2d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) + + elif len(src_buffer.shape) == 5: + # 5D case: (layer_num, buffer_size, d1, d2, d3) - Use Triton kernel for zero extra memory + _copy_buffer_broadcast_3d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) + + else: + raise ValueError(f"Unsupported buffer shape: {src_buffer.shape}") diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 1203cbdec7..8d6fb48c28 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -18,14 +18,17 @@ from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.shm_utils import create_or_link_shm +from lightllm.common.allocator_utils import TokenAllocator from multiprocessing.reduction import ForkingPickler from filelock import FileLock logger = init_logger(__name__) +KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_kv_cache_token_can_use_num" -class MemoryManager: + +class MemoryManager(TokenAllocator): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): self.size = size self.head_num = head_num @@ -36,27 +39,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False # profile the max total token num if the size is None self.profile_size(mem_fraction) - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._mem_state_return = torch.arange( - 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._return_start = 0 - self.mark_start = 0 - self.mark_end = self.size - - self.can_use_mem_size = self.size + super().__init__(self.size, f"{KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from lightllm.utils.envs_utils import get_unique_server_name - - rank_in_node = get_current_rank_in_node() - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) - - self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, @@ -64,7 +48,6 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False head_dim, layer_num, ) - self.HOLD_TOKEN_MEMINDEX = self.size def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): """ @@ -341,59 +324,13 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to def _free_buffers(self): self.kv_buffer = 
None - def alloc(self, need_size) -> torch.Tensor: - if need_size > self.mark_end - self.mark_start: - logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") - assert False, "error alloc state" - - start = self.mark_start - end = self.mark_start + need_size - self.mark_start += need_size - - self.can_use_mem_size -= need_size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - # 利用缓冲区返回,避免异步情况下的内存竞争 - if self._return_start + need_size > self._mem_state_return.shape[0]: - self._return_start = 0 - ans = self._mem_state_return[self._return_start : self._return_start + need_size] - ans.copy_(self.mem_state[start:end]) - self._return_start += need_size - return ans - - def free(self, free_index: Union[torch.Tensor, List[int]]): - """_summary_ - - Args: - free_index (torch.Tensor): _description_ - """ - - end = self.mark_start - start = self.mark_start - len(free_index) - assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" - - if isinstance(free_index, list): - self.mem_state.numpy()[start:end] = free_index - else: - # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 - self.mem_state[start:end] = free_index - - self.mark_start -= len(free_index) - - self.can_use_mem_size += len(free_index) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - if self.can_use_mem_size == len(self.mem_state): - logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") - return + def get_index_kv_buffer(self, index): + return {"kv_buffer": self.kv_buffer[:, index]} - def free_all(self): - self.can_use_mem_size = len(self.mem_state) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) - self.mark_start = 0 - self.mark_end = len(self.mem_state) + def load_index_kv_buffer(self, index, load_tensor_dict): + self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) + # 重写resize_mem方法,添加_free_buffers和_init_buffers调用 def resize_mem(self, new_size): """ just for test code @@ -404,24 +341,13 @@ def resize_mem(self, new_size): head_dim = self.head_dim layer_num = self.layer_num - self.size = new_size - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self.mark_start = 0 - self.mark_end = self.size - self.can_use_mem_size = self.size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) + # 调用父类的resize_mem + super().resize_mem(new_size) + self._free_buffers() self._init_buffers(size, dtype, head_num, head_dim, layer_num) return - def get_index_kv_buffer(self, index): - return {"kv_buffer": self.kv_buffer[:, index]} - - def load_index_kv_buffer(self, index, load_tensor_dict): - self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) - def copy_kv_from_other_dp_ranks( self, mem_managers: List["MemoryManager"], @@ -513,12 +439,12 @@ def __init__(self) -> None: self.dp_world_size = self.global_world_size // args.dp # 兼容多机 dp size=1 纯 tp 模式的情况 self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 - self.shared_tp_infos = [ - SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}") + self.shared_tp_can_use_token_nums = [ + SharedInt(f"{KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME}_{rank_in_node}") for rank_in_node in range(0, self.node_world_size, self.dp_world_size) ] def get_unrefed_token_num(self, dp_rank_in_node: int): if self.is_multinode_tp: - return self.shared_tp_infos[0].get_value() - return self.shared_tp_infos[dp_rank_in_node].get_value() + 
return self.shared_tp_can_use_token_nums[0].get_value() + return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value() diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py new file mode 100644 index 0000000000..348b14192c --- /dev/null +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -0,0 +1,188 @@ +from typing import List, Tuple, Union + +import torch +import numpy as np + +from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.common.allocator_utils import TokenAllocator +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt + +logger = init_logger(__name__) + +MAMBA_CACHE_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_mamba_cache_can_use_num" + + +class LayerCache: + def __init__(self, size: int, dtype: torch.dtype, shape: Tuple[int, ...], layer_num: int): + self.size = size + self.dtype = dtype + self.shape = shape + self.layer_num = layer_num + + self.buffer = torch.zeros((self.layer_num, size + 1, *shape), dtype=dtype, device="cuda") + + def get_cell_size(self): + return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype) + + +class MambaCacheManager(TokenAllocator): + def __init__( + self, + size: int, + layer_num: int, + conv_state_dtype: torch.dtype, + conv_state_shape: Tuple[int, ...], + ssm_state_dtype: torch.dtype, + ssm_state_shape: Tuple[int, ...], + ): + super().__init__(size, f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") + self.conv_state_cache = LayerCache(size, conv_state_dtype, conv_state_shape, layer_num) + self.ssm_state_cache = LayerCache(size, ssm_state_dtype, ssm_state_shape, layer_num) + self.HOLD_BUFFER_INDEX = size + + logger.warning( + f"Linear attention state cache size: {size}\n" + f"Conv state use : " + f"{self.conv_state_cache.get_cell_size() * size / 1024 ** 3} GB Memory.\n" + f"Ssm state use : " + f"{self.ssm_state_cache.get_cell_size() * size / 1024 ** 3} GB Memory.\n" + ) + + def get_mamba_cache(self, layer_idx: int): + conv_state = self.conv_state_cache.buffer[layer_idx] + ssm_state = self.ssm_state_cache.buffer[layer_idx] + return conv_state, ssm_state + + def copy_buffer_p2p(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor): + """ + Copy buffers from source indices to destination indices using optimized Triton kernel. 
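+
+        The body below currently performs the copy with PyTorch advanced
+        indexing (see the inline comment); all indices must lie in [0, size],
+        where index ``size`` is the HOLD buffer. A minimal usage sketch with
+        illustrative indices:
+
+            # move the state of buffer 3 into buffer 7 across all layers
+            manager.copy_buffer_p2p(
+                torch.tensor([3], device="cuda"),
+                torch.tensor([7], device="cuda"),
+            )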
+ + Args: + src_buffer_indexes: Source buffer indices (1D tensor) + dst_buffer_indexes: Destination buffer indices (1D tensor) + """ + assert src_buffer_indexes.dim() == 1 + assert dst_buffer_indexes.dim() == 1 + assert src_buffer_indexes.shape[0] == dst_buffer_indexes.shape[0] + + # Validate indices are within valid range [0, size] (size+1 is the buffer dim) + max_valid_idx = self.size # HOLD_BUFFER_INDEX = size is valid + src_max = src_buffer_indexes.max().item() if src_buffer_indexes.numel() > 0 else -1 + src_min = src_buffer_indexes.min().item() if src_buffer_indexes.numel() > 0 else -1 + dst_max = dst_buffer_indexes.max().item() if dst_buffer_indexes.numel() > 0 else -1 + dst_min = dst_buffer_indexes.min().item() if dst_buffer_indexes.numel() > 0 else -1 + + if src_min < 0 or src_max > max_valid_idx or dst_min < 0 or dst_max > max_valid_idx: + logger.error( + f"Invalid buffer indices: src=[{src_min}, {src_max}], dst=[{dst_min}, {dst_max}], " + f"valid range=[0, {max_valid_idx}], conv shape={self.conv_state_cache.buffer.shape}, " + f"ssm shape={self.ssm_state_cache.buffer.shape}" + ) + raise ValueError("Invalid buffer indices for copy_buffer_p2p") + + # Use PyTorch advanced indexing for buffer copy (safer than Triton for complex shapes) + # The buffer shape is [layer_num, buffer_size, *shape] + # We need to copy all layers for the given buffer indices + src_idx = src_buffer_indexes.long() + dst_idx = dst_buffer_indexes.long() + + # Copy conv_state: [layer_num, buffer_size, d1, d2] + self.conv_state_cache.buffer[:, dst_idx, ...] = self.conv_state_cache.buffer[:, src_idx, ...] + + # Copy ssm_state: [layer_num, buffer_size, d1, d2, d3] + self.ssm_state_cache.buffer[:, dst_idx, ...] = self.ssm_state_cache.buffer[:, src_idx, ...] + return + + def copy_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): + assert src_buffer_index.dim() == 1 + assert dst_buffer_indexes.dim() == 2 + assert src_buffer_index.shape[0] == dst_buffer_indexes.shape[0] + + # Use PyTorch advanced indexing for broadcast copy + # src_buffer_index: [num_src] + # dst_buffer_indexes: [num_src, num_dst_per_src] + src_idx = src_buffer_index.long() + dst_idx = dst_buffer_indexes.long() + + # Broadcast each source to all its destinations + # For each (src, dst_group), copy buffer[src] to buffer[dst1], buffer[dst2], ... + num_src, num_dst_per_src = dst_idx.shape + for i in range(num_src): + src = src_idx[i : i + 1] # Keep as 1D tensor with 1 element + dsts = dst_idx[i, :] # 1D tensor with num_dst_per_src elements + # Copy conv_state + self.conv_state_cache.buffer[:, dsts, ...] = self.conv_state_cache.buffer[:, src, ...] + # Copy ssm_state + self.ssm_state_cache.buffer[:, dsts, ...] = self.ssm_state_cache.buffer[:, src, ...] + return + + def copy_ssm_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): + """ + Broadcast ONLY SSM states (not conv states) from source indices to destination indices. + + This is used for MTP mode where each buffer maintains its own independent conv state, + but SSM states need to be synchronized. 
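+
+        Shape contract and a minimal sketch (illustrative indices, assuming
+        mtp_step = 2 so each source fans out to 3 per-step buffers):
+
+            src = torch.tensor([5], device="cuda")          # [num_src]
+            dst = torch.tensor([[5, 6, 7]], device="cuda")  # [num_src, num_dst_per_src]
+            manager.copy_ssm_buffer_broadcast(src, dst)     # conv states stay untouched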
+ """ + assert src_buffer_index.dim() == 1 + assert dst_buffer_indexes.dim() == 2 + assert src_buffer_index.shape[0] == dst_buffer_indexes.shape[0] + + # Use PyTorch advanced indexing for SSM-only broadcast copy + src_idx = src_buffer_index.long() + dst_idx = dst_buffer_indexes.long() + + # Broadcast each source to all its destinations (SSM only) + num_src = dst_idx.shape[0] + for i in range(num_src): + src = src_idx[i : i + 1] + dsts = dst_idx[i, :] + # Only copy ssm_state, NOT conv_state + self.ssm_state_cache.buffer[:, dsts, ...] = self.ssm_state_cache.buffer[:, src, ...] + return + + def free(self, free_index: Union[torch.Tensor, List[int]]): + """ + Free the allocated cache buffers and clear them. + + Args: + free_index: Buffer indices to free (tensor or list of ints) + """ + # Convert to tensor if needed for indexing + if isinstance(free_index, list): + free_index_tensor = torch.tensor(free_index, dtype=torch.long, device="cuda") + else: + free_index_tensor = free_index.to(device="cuda", dtype=torch.long) + + # Clear the buffers for the freed indices + # Shape: [layer_num, buffer_index, *shape] + self.conv_state_cache.buffer[:, free_index_tensor, ...] = 0 + self.ssm_state_cache.buffer[:, free_index_tensor, ...] = 0 + + # Call parent's free method to update allocator state + super().free(free_index) + return + + +class ReadOnlyStaticsMambaCacheManager: + """ + 读取一些统计信息 + """ + + def __init__(self) -> None: + args = get_env_start_args() + self.global_world_size = args.tp + self.node_world_size = args.tp // args.nnodes + self.dp_world_size = self.global_world_size // args.dp + # 兼容多机 dp size=1 纯 tp 模式的情况 + self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 + self.shared_tp_can_use_token_nums = [ + SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{rank_in_node}") + for rank_in_node in range(0, self.node_world_size, self.dp_world_size) + ] + + def get_unrefed_token_num(self, dp_rank_in_node: int): + if self.is_multinode_tp: + return self.shared_tp_can_use_token_nums[0].get_value() + return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value() diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 33bdca4475..573fe50842 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,8 +1,10 @@ import torch import collections +from lightllm.common.basemodel.triton_kernel.alloc_buffer_kernel import alloc_buffer_for_req_triton from lightllm.utils.log_utils import init_logger from .kv_cache_mem_manager import MemoryManager from typing import List, Optional + from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args @@ -93,6 +95,18 @@ def free_all(self): self.req_list = _ReqLinkedList(self.max_request_num) return + def alloc_buffer_for_req(self, req_index: torch.Tensor): + """Allocate buffers for requests. No-op for standard models without linear attention.""" + pass + + def free_buffer(self, free_buffer_indexes): + """Free buffer memory. No-op for standard models without linear attention.""" + pass + + def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor): + """Copy buffer state between requests. 
No-op for standard models without linear attention.""" + pass + class ReqSamplingParamsManager: """ @@ -232,3 +246,35 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List): p_token_counts_tensor.cuda(non_blocking=True), p_cumsum_seq_len_tensor.cuda(non_blocking=True), ) + + +class ReqManagerForMamba(ReqManager): + def __init__(self, max_request_num, max_sequence_length, mem_manager): + from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager + + super().__init__(max_request_num, max_sequence_length, mem_manager) + self.mtp_step = get_env_start_args().mtp_step + self.buffer_mem_manager: MambaCacheManager = self.mem_manager.mamba_cache_mem_manager + self.req_to_buffer_index = torch.zeros( + (self.max_request_num + 1, self.mtp_step + 1), dtype=torch.int32, device="cuda" + ) + self.req_to_buffer_index[self.HOLD_REQUEST_ID, :] = self.buffer_mem_manager.HOLD_BUFFER_INDEX + + def free_buffer(self, free_buffer_indexes: List[int]): + self.buffer_mem_manager.free(free_buffer_indexes) + return + + def alloc_buffer_for_req(self, req_index: torch.Tensor): + num_reqs = req_index.shape[0] + num_buffers_per_req = self.mtp_step + 1 + buffer_indexes = self.buffer_mem_manager.alloc(num_reqs * num_buffers_per_req) + alloc_buffer_for_req_triton(req_index, buffer_indexes, self.req_to_buffer_index, self.mtp_step) + + def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor): + # 获取目标请求的所有 MTP buffer (从 buffer[0] 到 buffer[mtp_step]) + mtp_range = torch.arange(0, self.mtp_step + 1, dtype=torch.int32, device="cuda") + all_mtp_buffers = self.req_to_buffer_index[tgt_req_index[:, None], mtp_range[None, :]] + + # 将 shared buffer 广播到所有 MTP step + self.buffer_mem_manager.copy_buffer_broadcast(src_buffer_index, all_mtp_buffers) + return diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5d9216c2ea --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 128, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..338af08a1d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 2 + }, + "4": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..131da59770 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..4bc06d07d9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..f1159e4357 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..c8fa422e0c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..2af1b86e90 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..40cdc996b9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,12 @@ +{ + "2": { + "BV": 32, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at 
end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..a40eda35d4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 8 + }, + "100": { + "num_warps": 1 + }, + "1024": { + "num_warps": 8 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 1 + }, + "256": { + "num_warps": 1 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 1 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..5b08208be2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 1 + }, + "100": { + "num_warps": 8 + }, + "1024": { + "num_warps": 1 + }, + "128": { + "num_warps": 1 + }, + "16": { + "num_warps": 1 + }, + "16384": { + "num_warps": 2 + }, + "2048": { + "num_warps": 8 + }, + "256": { + "num_warps": 8 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 8 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..27e4804a61 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BK": 64, + "num_stages": 2, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..7749b3601f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,12 @@ +{ + "2": { + "BK": 64, + "num_stages": 2, + "num_warps": 4 + }, + "4": { + "BK": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..49c4dc63d1 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "128": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "16": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..907575d960 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "32": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 64, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..f525d11257 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,70 @@ +{ + "1024": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "128": { + "BLOCK_N": 512, + "num_warps": 2 + }, + "131072": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "16": { + "BLOCK_N": 256, + "num_warps": 4 + }, + "1600": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "16384": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "2048": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "256": { + "BLOCK_N": 512, + "num_warps": 2 + }, + "262144": { + "BLOCK_N": 128, + 
"num_warps": 1 + }, + "32768": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "4096": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "512": { + "BLOCK_N": 256, + "num_warps": 4 + }, + "64": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "65536": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "8": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "800": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "8192": { + "BLOCK_N": 128, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..198a196dfb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..537c7a90eb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 512, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..9a6dcb6fbf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "4096": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..e5a383f23f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "163840": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "320": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "40960": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..56c79e3a43 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "163840": { + 
"BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "320": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "40960": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..4843ed8ccf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + 
"num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..3c0e605b00 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..d82ca44a21 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "163840": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "320": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "40960": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..96eabffc42 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json new file mode 100644 index 0000000000..07e5e6875f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16384": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 512, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json new file mode 100644 index 0000000000..ff4632955f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + 
"1": { + "BLOCK_DIM": 256, + "BLOCK_M": 4, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 16 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 64, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json new file mode 100644 index 0000000000..89ab51ff8c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 4, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "2048": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "32": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..f4d29554da --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 1, + "num_warps": 8 + }, + "100": { + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "num_stages": 3, + "num_warps": 4 + 
}, + "128": { + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "num_stages": 5, + "num_warps": 2 + }, + "16384": { + "num_stages": 1, + "num_warps": 2 + }, + "2048": { + "num_stages": 2, + "num_warps": 2 + }, + "256": { + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "num_stages": 5, + "num_warps": 2 + }, + "4096": { + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "num_stages": 5, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..8605a91680 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 2, + "num_warps": 1 + }, + "100": { + "num_stages": 4, + "num_warps": 4 + }, + "1024": { + "num_stages": 3, + "num_warps": 2 + }, + "128": { + "num_stages": 2, + "num_warps": 1 + }, + "16": { + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "num_stages": 3, + "num_warps": 2 + }, + "2048": { + "num_stages": 3, + "num_warps": 2 + }, + "256": { + "num_stages": 5, + "num_warps": 2 + }, + "32": { + "num_stages": 2, + "num_warps": 1 + }, + "4096": { + "num_stages": 3, + "num_warps": 2 + }, + "64": { + "num_stages": 2, + "num_warps": 1 + }, + "8": { + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..12993b0231 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "10": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "1000": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "10240": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1280": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "163840": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "20480": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2560": { + "BLOCK_M": 8, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "320": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "40960": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "80": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..e08a58baf5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "10": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1000": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "10240": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1280": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "160": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "163840": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "20480": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2560": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "320": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "40960": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "80": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 32ccbe8337..af13e34cd9 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -7,6 +7,8 @@ from lightllm.models.qwen2.model import Qwen2TpPartModel from lightllm.models.qwen3.model import Qwen3TpPartModel from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel +from lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel from lightllm.models.internlm.model import InternlmTpPartModel from lightllm.models.stablelm.model import StablelmTpPartModel from lightllm.models.internlm2.model import Internlm2TpPartModel @@ -38,4 +40,8 @@ ) from lightllm.models.gpt_oss.model import GptOssTpPartModel from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel +from lightllm.models.qwen3_5.model import ( + Qwen3_5TpPartModel, + Qwen3_5MOETpPartModel, +) from .registry import get_model, get_model_class diff --git a/lightllm/models/qwen3next/__init__.py b/lightllm/models/qwen3next/__init__.py new file mode 100644 index 0000000000..a9d22c6643 --- /dev/null +++ b/lightllm/models/qwen3next/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel + +__all__ = ["Qwen3NextTpPartModel"] diff --git a/lightllm/models/qwen3next/buffer_pool.py b/lightllm/models/qwen3next/buffer_pool.py new file mode 100644 index 0000000000..42c4bcafc7 --- /dev/null +++ b/lightllm/models/qwen3next/buffer_pool.py @@ -0,0 +1,83 @@ +# lightllm/models/qwen3next/buffer_pool.py +import torch +from typing import Dict, Tuple + + +class Qwen3NextBufferPool: + """ + Buffer pool for Qwen3Next inference to reduce allocations. + + NOT thread-safe. Each GPU worker process should have its own pool instance. 
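+
+    Illustrative usage (a sketch only; the shape, dtype, and device below are
+    examples, not values taken from the model code):
+
+        pool = Qwen3NextBufferPool(max_buffers=8)
+        buf = pool.get_buffer((4, 2048), torch.bfloat16, torch.device("cuda"))
+        # ... kernels write into buf during the forward pass ...
+        pool.release_all()  # afterwards every pooled buffer is reusable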
+ + Manages reusable buffers for: + - Attention norm outputs + - FFN norm outputs + - FFN intermediate activations + - GDN intermediate tensors + """ + + def __init__(self, enable_stats: bool = False, max_buffers: int = 64): + self._buffers: Dict[Tuple[tuple, torch.dtype, torch.device], torch.Tensor] = {} + self._in_use: set = set() + self._max_buffers = max_buffers + self._access_order: list = [] # Track LRU order + self._enable_stats = enable_stats + self._stats = {"hits": 0, "misses": 0, "peak_buffers": 0, "evictions": 0} if enable_stats else None + + def get_buffer( + self, + shape: Tuple[int, ...], + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + """Get a buffer from the pool or allocate a new one.""" + key = (shape, dtype, device) + + # Check if we have a matching buffer not in use + if key in self._buffers and key not in self._in_use: + self._in_use.add(key) + # Update LRU order + if key in self._access_order: + self._access_order.remove(key) + self._access_order.append(key) + if self._enable_stats: + self._stats["hits"] += 1 + return self._buffers[key] + + # Evict oldest unused buffer if at capacity + if len(self._buffers) >= self._max_buffers: + self._evict_one() + + # Allocate new buffer + buffer = torch.empty(shape, dtype=dtype, device=device) + self._buffers[key] = buffer + self._in_use.add(key) + self._access_order.append(key) + if self._enable_stats: + self._stats["misses"] += 1 + self._stats["peak_buffers"] = max(self._stats["peak_buffers"], len(self._buffers)) + return buffer + + def _evict_one(self): + """Evict oldest unused buffer (LRU).""" + for key in self._access_order: + if key not in self._in_use and key in self._buffers: + del self._buffers[key] + self._access_order.remove(key) + if self._enable_stats: + self._stats["evictions"] += 1 + return + + def release_all(self): + """Release all buffers back to the pool (call after forward pass).""" + self._in_use.clear() + + def clear(self): + """Clear all buffers (call when changing batch size significantly).""" + self._buffers.clear() + self._in_use.clear() + self._access_order.clear() + + def get_stats(self): + """Return buffer pool statistics (if enabled).""" + return self._stats.copy() if self._stats else None diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py new file mode 100644 index 0000000000..2883534a93 --- /dev/null +++ b/lightllm/models/qwen3next/infer_struct.py @@ -0,0 +1,62 @@ +import torch +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen3NextInferStateInfo(LlamaInferStateInfo): + """ + Inference state for Qwen3Next with: + - gate_value attribute for output gating in full attention layers + - MTP-aware batching for multi-token prediction + - Custom buffer management for hybrid attention (full + linear) + """ + + def __init__(self): + super().__init__() + # For output gating in full attention layers + self.gate_value = None + # MTP-aware attributes + self.b_att_seq_len = None + self.att_batch_size = None + self.real_req_idx = None + self.mtp_buffer_idx_list = None + self.b_buffer_idx = None + + def init_some_extra_state(self, model): + """Initialize Qwen3Next-specific state""" + super().init_some_extra_state(model) + + args_mtp_step = get_env_start_args().mtp_step + mtp_size = args_mtp_step + 1 + + if self.is_prefill: + # Prefill: Standard initialization + self.b_att_seq_len = self.b_seq_len + self.b_buffer_idx = 
model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() + else: + # Decode: MTP-aware handling + # In MTP mode, each request has (mtp_step + 1) tokens + # att_batch_size is the number of unique requests + self.att_batch_size = self.batch_size // mtp_size + + # Use only the sequence lengths for the last token of each MTP group + if args_mtp_step > 0: + self.b_att_seq_len = self.b_seq_len[args_mtp_step::mtp_size].contiguous() + self.real_req_idx = self.b_req_idx[args_mtp_step::mtp_size] + else: + self.b_att_seq_len = self.b_seq_len + self.real_req_idx = self.b_req_idx + + # Buffer indices for Mamba cache (conv and SSM states) + self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.real_req_idx, :].flatten().contiguous() + + # Create per-step buffer indices for MTP + if args_mtp_step > 0: + buffer_idx_list = [] + for step_id in range(mtp_size): + buffer_idx_list.append(self.b_buffer_idx[step_id::mtp_size].tolist()) + self.mtp_buffer_idx_list = torch.tensor( + buffer_idx_list, dtype=torch.int32, device=self.b_buffer_idx.device + ) + + return diff --git a/lightllm/models/qwen3next/layer_infer/post_layer_infer.py b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py new file mode 100644 index 0000000000..9dcab4e6fc --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py @@ -0,0 +1,12 @@ +import torch + +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + +class Qwen3NextPostLayerInfer(LlamaPostLayerInfer): + def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.final_norm_weight_.weight, self.eps_, out=out) + return out diff --git a/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py b/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py new file mode 100644 index 0000000000..2da106dbb2 --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py @@ -0,0 +1,101 @@ +# lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py +import torch.nn.functional as F +from functools import partial +from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd +import os + + +class SharedExpertFFNMixin: + """ + Mixin providing shared expert + MoE FFN implementations. + + Used by both full attention and GDN layers in Qwen3Next. 
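+
+    For MoE layers the computation is, schematically (W_* are pseudo-names for
+    the shared_expert_* weight attributes used below):
+
+        shared = sigmoid(x @ W_shared_gate) * down(silu_and_mul(x @ W_gate_up))
+        out = shared + moe(x)   # routed experts via _moe_ffn / _moe_ffn_edp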
+ + Requirements: + - Class must have: embed_dim_, tp_world_size_, alloc_tensor() + - Class must have MoE config: is_moe, n_routed_experts, num_experts_per_tok, norm_topk_prob + """ + + def _bind_ffn(self): + """Bind FFN implementation based on MoE configuration.""" + if self.is_moe: + moe_mode = os.environ.get("MOE_MODE", "TP") + if moe_mode == "EP": + self._ffn = partial(SharedExpertFFNMixin._ffn_with_shared_expert_ep, self) + else: + self._ffn = partial(SharedExpertFFNMixin._ffn_with_shared_expert_tp, self) + else: + self._ffn = partial(SharedExpertFFNMixin._standard_ffn, self) + return + + def _ffn_core(self, input, layer_weight): + """Core FFN computation: gate_up -> silu_and_mul -> down.""" + input = input.view(-1, self.embed_dim_) + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) + + if hasattr(self, "buffer_pool") and self.buffer_pool: + ffn1_out = self.buffer_pool.get_buffer((input.size(0), up_gate_out.size(1) // 2), input.dtype, input.device) + else: + ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) + + silu_and_mul_fwd(up_gate_out, ffn1_out) + ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) + return ffn2_out, input + + def _standard_ffn(self, input, infer_state, layer_weight): + """Standard FFN using shared expert weights (non-MoE layers).""" + ffn2_out, _ = self._ffn_core(input, layer_weight) + return ffn2_out + + def _compute_shared_expert(self, input, layer_weight): + """Compute shared expert FFN output with gating.""" + ffn2_out, input_view = self._ffn_core(input, layer_weight) + return F.sigmoid(layer_weight.shared_expert_gate.mm(input_view)) * ffn2_out, input_view + + def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): + """FFN with shared expert + MoE (tensor parallelism mode).""" + shared_expert_out, input = self._compute_shared_expert(input, layer_weight) + moe_out = self._moe_ffn(input, infer_state, layer_weight) + return shared_expert_out + moe_out + + def _ffn_with_shared_expert_ep(self, input, infer_state, layer_weight): + """FFN with shared expert + MoE (expert parallelism mode).""" + shared_expert_out, input = self._compute_shared_expert(input, layer_weight) + moe_out = self._moe_ffn_edp(input, infer_state, layer_weight) + return shared_expert_out + moe_out + + def _moe_ffn(self, input, infer_state, layer_weight): + """MoE FFN with tensor parallelism.""" + hidden_states = input.view(-1, self.embed_dim_) + num_tokens, hidden_dim = hidden_states.shape + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + ) + return hidden_states.view(num_tokens, hidden_dim) + + def _moe_ffn_edp(self, input, infer_state, layer_weight): + """MoE FFN with expert parallelism.""" + hidden_states = input + token_num, hidden_dim = hidden_states.shape + + router_logits = layer_weight.moe_gate.mm(hidden_states) + ep_output = layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + is_prefill=infer_state.is_prefill, + ) + + ep_output = ep_output.view(token_num, hidden_dim) + return ep_output diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py 
b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..cd5fd67d53 --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -0,0 +1,1067 @@ +import os +import torch + +import torch.distributed as dist +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextFullAttentionTransformerLayerWeight, + Qwen3NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo +from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl +from lightllm.utils.log_utils import init_logger +from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager +from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd +from typing import Tuple +from lightllm.models.qwen3next.triton_kernel.gated_rmsnorm import gated_rmsnorm_forward +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.gdn_decode_mtp import ( + copy_conv_states, + copy_ssm_states, + copy_states_fused, +) +from lightllm.distributed import all_reduce +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward +from lightllm.models.qwen3next.triton_kernel.fused_add_gemma_rmsnorm import fused_add_gemma_rmsnorm +from lightllm.models.qwen3next.triton_kernel.fused_split_copy import fused_split_copy_qkvzba, fused_split_copy_qkv +from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type +from functools import partial + +logger = init_logger(__name__) + + +class GemmaRMSNormMixin: + """ + Mixin providing Gemma-style RMSNorm implementations. + + Requirements: + - Class must have: eps_, alloc_tensor() + """ + + def _gemma_norm_with_pool(self, input, norm_weight): + """Apply Gemma RMSNorm.""" + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, norm_weight, self.eps_, out=out) + return out + + +class Qwen3NextFullAttentionBaseLayerInfer(GemmaRMSNormMixin, LlamaTransformerLayerInfer): + """ + Base class for Qwen3Next full attention layers. + Contains shared logic for both standard full attention and MTP layers. 
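+
+    Shared pieces: Gemma-style RMSNorm, a fused Q + output-gate projection,
+    the shared-expert + MoE FFN, and decode buffers pre-allocated at
+    graph_max_batch_size so CUDA graph replay sees stable tensor addresses.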
+ """ + + def __init__(self, layer_num, network_config): + # Store Qwen3Next specific configs before calling super().__init__ + self.partial_rotary_factor = network_config.get("partial_rotary_factor", 1.0) + self.n_routed_experts = network_config.get("num_experts", 0) + self.is_moe = ( + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 + ) + self.num_experts_per_tok = network_config.get("num_experts_per_tok", 1) + self.norm_topk_prob = network_config.get("norm_topk_prob", False) + + super().__init__(layer_num, network_config) + # Override head_dim which may be different in Qwen3Next + self.head_dim_ = network_config.get( + "head_dim", network_config["hidden_size"] // network_config["num_attention_heads"] + ) + + # Pre-allocated decode buffers (mirrors GDN layer pattern) + start_args = get_env_start_args() + self._decode_buffers = {} + self._graph_max_batch_size = start_args.graph_max_batch_size + + # Pre-compute dims for decode buffer pre-allocation + self.shared_inter_size = network_config.get("shared_expert_intermediate_size", 0) + self.tp_gate_up_dim = 2 * self.shared_inter_size // self.tp_world_size_ if self.shared_inter_size > 0 else 0 + self.tp_q_gate_dim = (self.tp_q_head_num_ + self.tp_o_head_num_) * self.head_dim_ + self.tp_kv_dim = (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_ + + return + + def _get_decode_buffer(self, name, max_shape, dtype, device): + """Get or create a pre-allocated buffer for the decode path.""" + key = (name, dtype, device if isinstance(device, str) else str(device)) + if key not in self._decode_buffers: + self._decode_buffers[key] = torch.empty(max_shape, dtype=dtype, device=device) + return self._decode_buffers[key] + + def _bind_func(self): + super()._bind_func() + self._bind_ffn() + return + + def _bind_norm(self): + """Use Gemma-style RMSNorm""" + self._att_norm = partial(Qwen3NextFullAttentionBaseLayerInfer._att_norm_impl, self) + self._ffn_norm = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_norm_impl, self) + return + + def _bind_ffn(self): + """Bind FFN implementation based on MoE configuration.""" + if self.is_moe: + moe_mode = os.environ.get("MOE_MODE", "TP") + if moe_mode == "EP": + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_ep, self) + else: + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_tp, self) + else: + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._standard_ffn, self) + return + + def _ffn_core(self, input, layer_weight, is_decode=False): + """Core FFN computation: gate_up -> silu_and_mul -> down.""" + input = input.view(-1, self.embed_dim_) + if is_decode and self.tp_gate_up_dim > 0: + up_gate_buf = self._get_decode_buffer( + "up_gate_out", + (self._graph_max_batch_size, self.tp_gate_up_dim), + input.dtype, + input.device, + )[: input.size(0)] + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input, out=up_gate_buf) + else: + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) + inter_dim = up_gate_out.size(1) // 2 + if is_decode: + ffn1_out = self._get_decode_buffer( + "ffn1_out", (self._graph_max_batch_size, inter_dim), input.dtype, input.device + )[: input.size(0)] + else: + ffn1_out = self.alloc_tensor((input.size(0), inter_dim), input.dtype) + silu_and_mul_fwd(up_gate_out, ffn1_out) + ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) + return ffn2_out, input + + def 
_standard_ffn(self, input, infer_state, layer_weight): + """Standard FFN using shared expert weights (non-MoE layers).""" + ffn2_out, _ = self._ffn_core(input, layer_weight, is_decode=not infer_state.is_prefill) + return ffn2_out + + def _compute_shared_expert(self, input, layer_weight, is_decode=False): + """Compute shared expert FFN output with gating.""" + ffn2_out, input_view = self._ffn_core(input, layer_weight, is_decode=is_decode) + gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() + ffn2_out.mul_(gate) + return ffn2_out, input_view + + def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): + """FFN with shared expert + MoE (tensor parallelism mode).""" + shared_expert_out, input = self._compute_shared_expert( + input, layer_weight, is_decode=not infer_state.is_prefill + ) + moe_out = self._moe_ffn(input, infer_state, layer_weight) + moe_out.add_(shared_expert_out) + return moe_out + + def _ffn_with_shared_expert_ep(self, input, infer_state, layer_weight): + """FFN with shared expert + MoE (expert parallelism mode).""" + shared_expert_out, input = self._compute_shared_expert( + input, layer_weight, is_decode=not infer_state.is_prefill + ) + moe_out = self._moe_ffn_edp(input, infer_state, layer_weight) + moe_out.add_(shared_expert_out) + return moe_out + + def _moe_ffn(self, input, infer_state, layer_weight): + """MoE FFN with tensor parallelism.""" + hidden_states = input.view(-1, self.embed_dim_) + num_tokens, hidden_dim = hidden_states.shape + if not infer_state.is_prefill: + router_buf = self._get_decode_buffer( + "router_logits", + (self._graph_max_batch_size, self.n_routed_experts), + hidden_states.dtype, + hidden_states.device, + )[:num_tokens] + router_logits = layer_weight.moe_gate.mm(hidden_states, out=router_buf) + else: + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + ) + return hidden_states.view(num_tokens, hidden_dim) + + def _moe_ffn_edp(self, input, infer_state, layer_weight): + """MoE FFN with expert parallelism.""" + hidden_states = input + token_num, hidden_dim = hidden_states.shape + router_logits = layer_weight.moe_gate.mm(hidden_states) + ep_output = layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + is_prefill=infer_state.is_prefill, + ) + ep_output = ep_output.view(token_num, hidden_dim) + return ep_output + + def _att_norm_impl( + self, + input, + _infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> torch.Tensor: + return self._gemma_norm_with_pool(input, layer_weight.att_norm_weight_.weight) + + def _ffn_norm_impl( + self, + input, + _infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> torch.Tensor: + return self._gemma_norm_with_pool(input, layer_weight.ffn_norm_weight_.weight) + + def _get_qkv( + self, + input: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + QKV projection with output gating, Q/K normalization, and partial rotary embedding. 
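+
+        Layout (per token): q_gate_proj yields [Q | gate]; Q takes the first
+        tp_q_head_num_ * head_dim_ columns, and the remainder is squashed by
+        an in-place sigmoid and stashed as infer_state.gate_value for _get_o.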
+ """ + input = input.view(-1, self.embed_dim_) + # Single fused GEMM for both Q and output gate projections + if not infer_state.is_prefill: + q_gate_buf = self._get_decode_buffer( + "q_gate_out", + (self._graph_max_batch_size, self.tp_q_gate_dim), + input.dtype, + input.device, + )[: input.size(0)] + q_gate = layer_weight.q_gate_proj.mm(input, out=q_gate_buf) + kv_buf = self._get_decode_buffer( + "kv_out", + (self._graph_max_batch_size, self.tp_kv_dim), + input.dtype, + input.device, + )[: input.size(0)] + kv_out = layer_weight.kv_proj.mm(input, out=kv_buf) + else: + q_gate = layer_weight.q_gate_proj.mm(input) + kv_out = layer_weight.kv_proj.mm(input) + q_dim = self.tp_q_head_num_ * self.head_dim_ + q = q_gate[:, :q_dim].contiguous() + # In-place sigmoid saves one allocation (gate_value is consumed once in _get_o) + infer_state.gate_value = q_gate[:, q_dim:].sigmoid_() + cache_kv = kv_out.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + # Q normalization (in-place via out=input) + gemma_rmsnorm_forward( + q.view(-1, self.head_dim_), + layer_weight.q_norm_weight_.weight, + eps=self.eps_, + out=q.view(-1, self.head_dim_), + ) + + # K normalization + k_input = cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]) + if not infer_state.is_prefill: + k_normed = self._get_decode_buffer( + "k_norm_out", + (self._graph_max_batch_size * self.tp_k_head_num_, cache_kv.shape[-1]), + k_input.dtype, + k_input.device, + )[: k_input.shape[0]] + gemma_rmsnorm_forward(k_input, layer_weight.k_norm_weight_.weight, eps=self.eps_, out=k_normed) + else: + k_normed = gemma_rmsnorm_forward(k_input, layer_weight.k_norm_weight_.weight, eps=self.eps_) + cache_kv[:, : self.tp_k_head_num_, :] = k_normed.view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) + + # Rotary embedding with partial rotation support + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + partial_rotary_factor=self.partial_rotary_factor, + ) + return q, cache_kv + + def _get_o( + self, + input, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> torch.Tensor: + """Output projection with gating (in-place multiply to save one allocation).""" + input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) + input.mul_(infer_state.gate_value) + infer_state.gate_value = None + o_tensor = layer_weight.o_proj.mm(input) + return o_tensor + + def token_forward(self, input_embdings, infer_state, layer_weight): + """Override token_forward to use pre-allocated decode buffers and fused kernels.""" + max_tokens = self._graph_max_batch_size + input1 = self._get_decode_buffer( + "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device + )[: input_embdings.shape[0]] + gemma_rmsnorm_forward(input_embdings, layer_weight.att_norm_weight_.weight, self.eps_, out=input1) + + o = self.token_attention_forward(input1, infer_state, layer_weight) + + # Fused residual add + FFN norm: saves 1 kernel launch + 1 read of input_embdings + input1 = self._get_decode_buffer( + "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device + )[: input_embdings.shape[0]] + fused_add_gemma_rmsnorm( + input_embdings, + o.view(-1, self.embed_dim_), + layer_weight.ffn_norm_weight_.weight, + self.eps_, + out=input1, + ) + o = None + + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ 
> 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + +class Qwen3NextFullAttentionTransformerLayerInfer(Qwen3NextFullAttentionBaseLayerInfer): + """ + Full attention layer for Qwen3Next that uses the abstracted attention backend. + Inherits from Qwen3NextFullAttentionBaseLayerInfer to get shared Qwen3Next logic. + """ + + pass + + +class Qwen3NextGatedDeltaNetTransformerLayerInfer(GemmaRMSNormMixin, TransformerLayerInferTpl): + """ + Linear attention (Gated Delta Networks) layer for Qwen3Next. + """ + + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + self.network_config_ = network_config + + # MoE configuration + self.n_routed_experts = network_config.get("num_experts", 0) + self.is_moe = ( + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 + ) + self.num_experts_per_tok = network_config.get("num_experts_per_tok", 1) + self.norm_topk_prob = network_config.get("norm_topk_prob", False) + self.shared_inter_size = network_config.get("shared_expert_intermediate_size", 0) + + # Standard layer dimensions + self.eps_ = network_config["rms_norm_eps"] + self.embed_dim_ = network_config["hidden_size"] + + # Linear attention specific dimensions + self.num_v_heads = network_config["linear_num_value_heads"] + self.num_k_heads = network_config["linear_num_key_heads"] + self.head_k_dim = network_config["linear_key_head_dim"] + self.head_v_dim = network_config["linear_value_head_dim"] + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_kernel_dim = network_config["linear_conv_kernel_dim"] + self.activation = network_config["hidden_act"] + + # Tensor parallelism dimensions + self.tp_qkvz_dim = (self.key_dim * 2 + self.value_dim * 2) // self.tp_world_size_ + self.tp_ba_dim = (self.num_v_heads * 2) // self.tp_world_size_ + self.tp_num_k_heads = self.num_k_heads // self.tp_world_size_ + self.tp_num_v_heads = self.num_v_heads // self.tp_world_size_ + self.tp_key_dim = self.key_dim // self.tp_world_size_ + self.tp_value_dim = self.value_dim // self.tp_world_size_ + + # Template required dimensions (not used for GDN but required by interface) + self.tp_q_head_num_ = self.tp_num_k_heads + self.tp_k_head_num_ = self.tp_num_k_heads + self.tp_v_head_num_ = self.tp_num_v_heads + self.tp_o_head_num_ = self.tp_num_v_heads + self.head_dim_ = self.head_v_dim + + assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads" + self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads + + # MTP configuration + self.mtp_step = get_env_start_args().mtp_step + self.mtp_size = self.mtp_step + 1 + + # SSM state dtype optimization + ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + start_args = get_env_start_args() + self.ssm_state_dtype = ssm_dtype_dict.get(start_args.mamba_ssm_data_type, torch.bfloat16) + + # Pre-compute whether dtype conversion is needed + # GDN kernel output dtype is self.data_type + # Conversion needed only if SSM state uses different dtype + self.needs_ssm_dtype_conversion = get_llm_data_type() != self.ssm_state_dtype + + # Pre-allocated decode buffers to avoid repeated allocation during CUDA graph replay. 
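+ # (Editor's note, illustrative: CUDA graph replay reruns the captured kernels
+ # against the pointers recorded at capture time, so decode outputs must live
+ # in storage that never moves. The pattern used by the helper below is roughly:
+ #     buf = self._get_decode_buffer("name", (max_bs, dim), dtype, dev)  # one allocation
+ #     out = buf[:cur_batch_size]  # per-step view into the same storage
+ # so every replay writes into the recorded addresses.)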
+ # Buffers are lazily allocated on first decode call, sized to graph_max_batch_size.
+ self._decode_buffers = {}
+ self._graph_max_batch_size = start_args.graph_max_batch_size
+
+ # Pre-compute FFN dims for decode buffer pre-allocation
+ self.tp_gate_up_dim = 2 * self.shared_inter_size // self.tp_world_size_ if self.shared_inter_size > 0 else 0
+
+ self._bind_func()
+ return
+
+ def _get_decode_buffer(self, name, max_shape, dtype, device):
+ """Get or create a pre-allocated buffer for the decode path.
+
+ On first call, allocates a buffer at max_shape. On subsequent calls,
+ returns the same buffer (caller should slice to actual batch size).
+ """
+ key = (name, dtype, device if isinstance(device, str) else str(device))
+ if key not in self._decode_buffers:
+ self._decode_buffers[key] = torch.empty(max_shape, dtype=dtype, device=device)
+ return self._decode_buffers[key]
+
+ def _bind_func(self):
+ """Bind layer-specific implementations"""
+ self._bind_norm()
+ self._bind_ffn()
+ return
+
+ def _bind_norm(self):
+ """Use Gemma-style RMSNorm"""
+ self._att_norm = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._att_norm_impl, self)
+ self._ffn_norm = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._ffn_norm_impl, self)
+ return
+
+ def _bind_ffn(self):
+ """Bind FFN implementation based on MoE configuration."""
+ if self.is_moe:
+ moe_mode = os.environ.get("MOE_MODE", "TP")
+ if moe_mode == "EP":
+ self._ffn = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._ffn_with_shared_expert_ep, self)
+ else:
+ self._ffn = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._ffn_with_shared_expert_tp, self)
+ else:
+ self._ffn = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._standard_ffn, self)
+ return
+
+ def _ffn_core(self, input, layer_weight, is_decode=False):
+ """Core FFN computation: gate_up -> silu_and_mul -> down."""
+ input = input.view(-1, self.embed_dim_)
+ if is_decode and self.tp_gate_up_dim > 0:
+ up_gate_buf = self._get_decode_buffer(
+ "up_gate_out",
+ (self._graph_max_batch_size * self.mtp_size, self.tp_gate_up_dim),
+ input.dtype,
+ input.device,
+ )[: input.size(0)]
+ up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input, out=up_gate_buf)
+ else:
+ up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input)
+ inter_dim = up_gate_out.size(1) // 2
+ if is_decode:
+ # Sized for MTP decode (batch * mtp_size tokens), consistent with up_gate_out above;
+ # sizing by batch alone would under-allocate when mtp_step > 0.
+ ffn1_out = self._get_decode_buffer(
+ "ffn1_out", (self._graph_max_batch_size * self.mtp_size, inter_dim), input.dtype, input.device
+ )[: input.size(0)]
+ else:
+ ffn1_out = self.alloc_tensor((input.size(0), inter_dim), input.dtype)
+ silu_and_mul_fwd(up_gate_out, ffn1_out)
+ ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out)
+ return ffn2_out, input
+
+ def _standard_ffn(self, input, infer_state, layer_weight):
+ """Standard FFN using shared expert weights (non-MoE layers)."""
+ ffn2_out, _ = self._ffn_core(input, layer_weight, is_decode=not infer_state.is_prefill)
+ return ffn2_out
+
+ def _compute_shared_expert(self, input, layer_weight, is_decode=False):
+ """Compute shared expert FFN output with gating."""
+ ffn2_out, input_view = self._ffn_core(input, layer_weight, is_decode=is_decode)
+ gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_()
+ ffn2_out.mul_(gate)
+ return ffn2_out, input_view
+
+ def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight):
+ """FFN with shared expert + MoE (tensor parallelism mode)."""
+ shared_expert_out, input = self._compute_shared_expert(
+ input, layer_weight, is_decode=not infer_state.is_prefill
+ )
+ moe_out = 
self._moe_ffn(input, infer_state, layer_weight) + moe_out.add_(shared_expert_out) + return moe_out + + def _ffn_with_shared_expert_ep(self, input, infer_state, layer_weight): + """FFN with shared expert + MoE (expert parallelism mode).""" + shared_expert_out, input = self._compute_shared_expert( + input, layer_weight, is_decode=not infer_state.is_prefill + ) + moe_out = self._moe_ffn_edp(input, infer_state, layer_weight) + moe_out.add_(shared_expert_out) + return moe_out + + def _moe_ffn(self, input, infer_state, layer_weight): + """MoE FFN with tensor parallelism.""" + hidden_states = input.view(-1, self.embed_dim_) + num_tokens, hidden_dim = hidden_states.shape + if not infer_state.is_prefill: + router_buf = self._get_decode_buffer( + "router_logits", + (self._graph_max_batch_size * self.mtp_size, self.n_routed_experts), + hidden_states.dtype, + hidden_states.device, + )[:num_tokens] + router_logits = layer_weight.moe_gate.mm(hidden_states, out=router_buf) + else: + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + ) + return hidden_states.view(num_tokens, hidden_dim) + + def _moe_ffn_edp(self, input, infer_state, layer_weight): + """MoE FFN with expert parallelism.""" + hidden_states = input + token_num, hidden_dim = hidden_states.shape + router_logits = layer_weight.moe_gate.mm(hidden_states) + ep_output = layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + is_prefill=infer_state.is_prefill, + ) + ep_output = ep_output.view(token_num, hidden_dim) + return ep_output + + def _att_norm_impl( + self, + input, + _infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + return self._gemma_norm_with_pool(input, layer_weight.att_norm_weight_.weight) + + def _ffn_norm_impl( + self, + input, + _infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + return self._gemma_norm_with_pool(input, layer_weight.ffn_norm_weight_.weight) + + def _get_qkv( + self, + _input: torch.Tensor, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Not used by GDN - QKV projection handled in gdn_forward. + + GDN uses a fused projection that includes z, b, a parameters + in addition to q, k, v, so the standard template flow doesn't apply. + This method exists to satisfy the template interface. + """ + pass # Implementation in gdn_forward + + def _tpsp_get_qkv( + self, + _input: torch.Tensor, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """TPSP mode not implemented for GDN layers.""" + pass # No TPSP support planned + + def _get_o( + self, + _input, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """ + Not used by GDN - output projection handled in gdn_forward. + + Output computation is fused with GDN recurrence in gdn_forward. 
+ """ + pass # Implementation in gdn_forward + + def _tpsp_get_o( + self, + _input, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """TPSP mode not implemented for GDN layers.""" + pass # No TPSP support planned + + def _context_attention_kernel( + self, + _q: torch.Tensor, + _kv: torch.Tensor, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """Not used by GDN - attention computed in gdn_forward.""" + pass # Implementation in gdn_forward + + def _token_attention_kernel( + self, + _q: torch.Tensor, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """Not used by GDN - attention computed in gdn_forward.""" + pass # Implementation in gdn_forward + + def _gdn_layer_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + is_prefill: bool, + ): + """Unified forward for both prefill and decode in GDN layers.""" + # Attention + GDN processing + if is_prefill: + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + else: + # Decode: use pre-allocated buffer to avoid alloc_tensor overhead + max_tokens = self._graph_max_batch_size * self.mtp_size + input1 = self._get_decode_buffer( + "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device + )[: input_embdings.shape[0]] + gemma_rmsnorm_forward(input_embdings, layer_weight.att_norm_weight_.weight, self.eps_, out=input1) + + gdn_out = self.gdn_forward(input1, infer_state, layer_weight, is_prefill=is_prefill) + if self.tp_world_size_ > 1: + all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + + # FFN + if is_prefill: + input_embdings.add_(gdn_out.view(-1, self.embed_dim_)) + gdn_out = None + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + else: + # Decode: fused residual add + FFN norm saves 1 kernel + 1 read of input_embdings + input1 = self._get_decode_buffer( + "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device + )[: input_embdings.shape[0]] + fused_add_gemma_rmsnorm( + input_embdings, + gdn_out.view(-1, self.embed_dim_), + layer_weight.ffn_norm_weight_.weight, + self.eps_, + out=input1, + ) + gdn_out = None + + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + def context_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Override context_forward to use GDN logic instead of standard attention flow.""" + return self._gdn_layer_forward(input_embdings, infer_state, layer_weight, is_prefill=True) + + def token_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Override token_forward to use GDN logic instead of standard attention flow.""" + return self._gdn_layer_forward(input_embdings, infer_state, layer_weight, is_prefill=False) + + def overlap_tpsp_token_forward( + self, + input_embdings, + input_embdings1, + infer_state: Qwen3NextInferStateInfo, + infer_state1: 
Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Microbatch overlap for decode: process two half-batches sequentially. + Enables --enable_decode_microbatch_overlap for GDN layers.""" + input_embdings = self.token_forward(input_embdings, infer_state, layer_weight) + input_embdings1 = self.token_forward(input_embdings1, infer_state1, layer_weight) + return input_embdings, input_embdings1 + + def overlap_tpsp_context_forward( + self, + input_embdings, + input_embdings1, + infer_state: Qwen3NextInferStateInfo, + infer_state1: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Microbatch overlap for context: process two half-batches sequentially.""" + input_embdings = self.context_forward(input_embdings, infer_state, layer_weight) + input_embdings1 = self.context_forward(input_embdings1, infer_state1, layer_weight) + return input_embdings, input_embdings1 + + # ==================== GDN Helper Methods ==================== + + def _fix_query_key_value_ba_ordering(self, mixed_qkvzba, is_decode=False): + """ + Extract q, k, v, z, b, a from the MM output. + + After weight rearrangement at load time, the MM output is already in grouped layout: + [all_q | all_k | all_v | all_z | all_b | all_a] + so this is just simple slicing — no split+reshape+cat needed. + + Note: + Decode fast-path fused split-copy kernels are intentionally avoided here. + The explicit contiguous slicing path is slower but is more robust and + matches the reference behavior used in vLLM. + """ + qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim + z_end = qkv_dim + self.tp_value_dim + b_end = z_end + self.tp_num_v_heads + + if is_decode: + mixed_qkv = mixed_qkvzba[:, :qkv_dim].contiguous() + z = mixed_qkvzba[:, qkv_dim:z_end].contiguous().view(-1, self.tp_num_v_heads, self.head_v_dim) + b = mixed_qkvzba[:, z_end:b_end].contiguous() + a = mixed_qkvzba[:, b_end:].contiguous() + else: + mixed_qkv = mixed_qkvzba[:, :qkv_dim] + # .reshape() handles non-contiguous slices by copying when needed (unlike .view()) + z = mixed_qkvzba[:, qkv_dim:z_end].reshape(-1, self.tp_num_v_heads, self.head_v_dim) + # b and a must be contiguous: fused_gdn_gating_kernel uses raw pointer arithmetic + # (off = i_b * NUM_HEADS + head_off) that assumes contiguous layout. + # Non-contiguous slices have stride[0]=total_dim, causing wrong reads for i_b > 0. 
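+ # (Editor's illustration of the stride hazard, hypothetical sizes:
+ #     t = torch.zeros(4, 16)
+ #     s = t[:, 4:8]
+ #     s.stride()               # (16, 1): row stride is the parent's 16
+ #     s.contiguous().stride()  # (4, 1): safe for flat pointer arithmetic
+ # a kernel computing off = i_b * 4 + j on s would read garbage for i_b > 0.)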
+ b = mixed_qkvzba[:, z_end:b_end].contiguous() + a = mixed_qkvzba[:, b_end:].contiguous() + + return mixed_qkv, z, b, a + + def _rearrange_mixed_qkv(self, mixed_qkv, decode=False): + if mixed_qkv is None: + return None, None, None + if decode: + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + batch_size = mixed_qkv.shape[0] + query = query.contiguous().view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) + key = key.contiguous().view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) + value = value.contiguous().view(batch_size, 1, self.tp_num_v_heads, self.head_v_dim) + return query, key, value + else: + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + seq_len = query.shape[0] + query = query.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) + key = key.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) + value = value.view(1, seq_len, self.tp_num_v_heads, self.head_v_dim) + return query, key, value + + def context_attention_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True) + return gdn_out + + def token_attention_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=False) + return gdn_out + + def _gdn_prefill_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Prefill kernel for GDN forward pass.""" + # Conv1D processing + mixed_qkv = mixed_qkv.transpose(0, 1) + out_tensor = causal_conv1d_fn( + mixed_qkv, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.bias, + query_start_loc=infer_state.b1_cu_q_seq_len, + cache_indices=infer_state.b_buffer_idx, + has_initial_state=infer_state.b_ready_cache_len > 0, + conv_states=conv_states, + activation=self.activation, + ) + mixed_qkv = out_tensor.transpose(0, 1) + + # Recurrent processing + query, key, value = self._rearrange_mixed_qkv(mixed_qkv) + initial_state = ssm_states[infer_state.b_buffer_idx] + # g and beta have shape (total_tokens, num_heads), need to unsqueeze to get (1, total_tokens, num_heads) + core_attn_out, last_recurrent_state = chunk_gated_delta_rule( + q=query, + k=key, + v=value, + g=g.unsqueeze(0), + beta=beta.unsqueeze(0), + initial_state=initial_state, + output_final_state=True, + cu_seqlens=infer_state.b1_cu_q_seq_len, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + # Use pre-computed dtype conversion flag to avoid runtime check + if self.needs_ssm_dtype_conversion: + ssm_states[infer_state.b_buffer_idx] = last_recurrent_state.to(self.ssm_state_dtype, copy=False) + else: + ssm_states[infer_state.b_buffer_idx] = last_recurrent_state + return core_attn_out + + def _gdn_decode_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Decode kernel for GDN forward pass (single-token, non-MTP mode). 
+ Uses fused gating: g/beta computed inline in the recurrent kernel.""" + # Conv1D processing — mixed_qkv is pre-copied to contiguous buffer + # by _fix_query_key_value_ba_ordering (causal_conv1d_update requires contiguous input) + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.bias, + activation=self.activation, + conv_state_indices=infer_state.b_buffer_idx, + ) + + # Recurrent processing with fused gating + # FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally + query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) + core_attn_out, _ = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + initial_state=ssm_states, + inplace_final_state=True, + ssm_state_indices=infer_state.b_buffer_idx, + use_qk_l2norm_in_kernel=True, + A_log=layer_weight.linear_A_log.weight, + dt_bias=layer_weight.linear_dt_bias.weight, + a_raw=a, + b_raw=b, + ) + return core_attn_out + + def _gdn_decode_mtp_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """ + Optimized decode kernel for GDN forward pass (MTP mode with multiple steps). + + Key optimizations: + 1. Uses pre-allocated work buffer to avoid per-step .contiguous() allocations + 2. Uses optimized flat Triton kernels for state copying + 3. Direct slice assignment for output instead of .copy_() + + Note: Sequential processing is required because each MTP step depends on + the previous step's final state (both conv and SSM states). + """ + total_tokens = mixed_qkv.shape[0] + batch_size = total_tokens // self.mtp_size + + # Pre-allocate output tensor + core_attn_out = torch.empty( + (total_tokens, 1, self.tp_num_v_heads, self.head_v_dim), + dtype=mixed_qkv.dtype, + device=mixed_qkv.device, + ) + + # Pre-allocate work buffer for conv1d input (avoids per-step .contiguous()) + qkv_work_buffer = torch.empty( + (batch_size, mixed_qkv.shape[-1]), + dtype=mixed_qkv.dtype, + device=mixed_qkv.device, + ) + + # Process each MTP step sequentially (required due to state dependencies) + for step_idx in range(self.mtp_size): + cur_buffer_idx = infer_state.mtp_buffer_idx_list[step_idx] + + # ========== Conv1D processing ========== + # Copy strided data to contiguous work buffer + qkv_work_buffer.copy_(mixed_qkv[step_idx :: self.mtp_size]) + + # causal_conv1d_update operates in-place on contiguous input + causal_conv1d_update( + qkv_work_buffer, + conv_states, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.bias, + activation=self.activation, + conv_state_indices=cur_buffer_idx, + ) + + # ========== Recurrent processing ========== + query_i, key_i, value_i = self._rearrange_mixed_qkv(qkv_work_buffer, decode=True) + g_i = g[step_idx :: self.mtp_size].unsqueeze(1) + beta_i = beta[step_idx :: self.mtp_size].unsqueeze(1) + + core_attn_out_i, _ = fused_recurrent_gated_delta_rule( + q=query_i, + k=key_i, + v=value_i, + g=g_i, + beta=beta_i, + initial_state=ssm_states, + inplace_final_state=True, + ssm_state_indices=cur_buffer_idx, + use_qk_l2norm_in_kernel=True, + ) + + # Direct slice assignment (no .copy_() needed) + core_attn_out[step_idx :: self.mtp_size] = core_attn_out_i + + # ========== State propagation to next step ========== + if step_idx < self.mtp_step: + next_buffer_idx = 
infer_state.mtp_buffer_idx_list[step_idx + 1] + if conv_states.is_contiguous() and ssm_states.is_contiguous(): + copy_states_fused(conv_states, ssm_states, cur_buffer_idx, next_buffer_idx) + else: + copy_conv_states(conv_states, cur_buffer_idx, next_buffer_idx) + copy_ssm_states(ssm_states, cur_buffer_idx, next_buffer_idx) + + return core_attn_out + + def gdn_forward( + self, + input: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + is_prefill: bool, + ): + assert isinstance(infer_state.mem_manager, Qwen3NextHybridMemManager) + + # Common preprocessing + input = input.view(-1, self.embed_dim_) + conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_) + + if not is_prefill: + # Decode: pre-allocate GEMM output to avoid cache tensor manager overhead + in_proj_out_dim = self.tp_qkvz_dim + self.tp_ba_dim + in_proj_out = self._get_decode_buffer( + "in_proj_out", + (self._graph_max_batch_size * self.mtp_size, in_proj_out_dim), + input.dtype, + input.device, + )[: input.shape[0]] + mixed_qkvzba = layer_weight.linear_in_proj.mm(input, out=in_proj_out) + else: + mixed_qkvzba = layer_weight.linear_in_proj.mm(input) + # mixed_qkv is now returned pre-concatenated (no torch.cat needed) + mixed_qkv, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba, is_decode=not is_prefill) + + # Dispatch to appropriate kernel + if is_prefill: + # Prefill: compute g/beta upfront (chunk kernel doesn't support fused gating) + g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) + core_attn_out = self._gdn_prefill_kernel( + mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight + ) + elif self.mtp_step == 0: + # Decode (non-MTP): fuse gating into recurrent kernel to save 2 kernel launches + core_attn_out = self._gdn_decode_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight) + else: + # Decode (MTP): compute g/beta upfront (multiple recurrent calls per step) + g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) + core_attn_out = self._gdn_decode_mtp_kernel( + mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight + ) + + # Common postprocessing + num_tokens = z.shape[0] # batch (decode) or total_tokens (prefill/MTP) + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + if not is_prefill: + # Decode: use pre-allocated buffer for norm output to avoid alloc_tensor + max_decode_tokens = self._graph_max_batch_size * self.mtp_size + flat_size = max_decode_tokens * self.tp_num_v_heads + norm_out = self._get_decode_buffer( + "gdn_norm_out", + (flat_size, self.head_v_dim), + core_attn_out.dtype, + core_attn_out.device, + )[: core_attn_out.shape[0]] + else: + norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device) + gated_rmsnorm_forward( + core_attn_out, + layer_weight.linear_norm.weight, + None, # RMSNormWeight has no bias + self.eps_, + z, + out=norm_out, + ) + # Merge head and value dims in a single view: (num_tokens * HV, V) → (num_tokens, HV * V) + core_attn_out = norm_out.view(num_tokens, -1) + + output = layer_weight.linear_out_proj.mm(core_attn_out) + # Note: all_reduce is handled by context_forward/token_forward callers + return output diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py 
b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..d4e16555d9 --- /dev/null +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -0,0 +1,313 @@ +import torch +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + RMSNormWeight, + TpParameterWeight, + KVROWNMMWeight, +) + + +class Qwen3NextFullAttentionTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _init_qkv(self): + # Override parent's QKVROWNMMWeight which requires kv_head_num % tp == 0. + # Qwen3-Next has very few KV heads (e.g., 2) so we use separate q + kv weights. + # KVROWNMMWeight handles the kv_head_num < tp_world_size case via repeating. + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + # Define o_gate weight name here (used by _split_q_with_gate during load) + self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" + # Fused Q + gate projection: single GEMM outputs [q, gate] concatenated + self.q_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim, q_out_dim], + weight_names=[self._q_weight_name, self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("q_proj"), + ) + self.kv_proj = KVROWNMMWeight( + in_dim=in_dim, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("kv_proj"), + ) + + def _init_weight(self): + super()._init_weight() + # Additional architecture (o_gate is now fused into q_gate_proj in _init_qkv) + self._init_gate_shared_expert_weight() + return + + def load_hf_weights(self, weights): + self._split_q_with_gate(weights) + super().load_hf_weights(weights) + + def _split_q_with_gate(self, weights): + if self._q_weight_name in weights: + weight = weights[self._q_weight_name] + num_heads = self.q_head_num_ + weight = weight.view(num_heads * 2, self.head_dim, -1) + _q_proj = weight[0::2].reshape(-1, weight.shape[-1]) + _gate_proj = weight[1::2].reshape(-1, weight.shape[-1]) + weights[self._q_weight_name] = _q_proj + weights[self._o_gate_weight_name] = _gate_proj + + def _init_gate_shared_expert_weight(self): + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + hidden_size = self.network_config_["hidden_size"] + shared_inter = self.network_config_["shared_expert_intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[shared_inter, shared_inter], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=shared_inter, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + self.shared_expert_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + 
quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + + +class Qwen3NextGatedDeltaNetTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + self.is_moe = ( + network_config["num_experts"] > 0 + and layer_num not in network_config["mlp_only_layers"] + and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 + ) + super().__init__(layer_num, data_type, network_config, quant_cfg) + + def _parse_config(self): + super()._parse_config() + self.linear_num_v_heads = self.network_config_["linear_num_value_heads"] + self.linear_num_k_heads = self.network_config_["linear_num_key_heads"] + self.linear_k_head_dim = self.network_config_["linear_key_head_dim"] + self.linear_v_head_dim = self.network_config_["linear_value_head_dim"] + + def _init_weight(self): + hidden_size = self.network_config_["hidden_size"] + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self._init_gdn_weight() + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + if self.is_moe: + self._init_moe() + else: + self._init_ffn() + self._init_gate_shared_expert_weight() + + def _init_gdn_weight(self): + prefix = f"model.layers.{self.layer_num_}.linear_attn" + hidden_size = self.network_config_["hidden_size"] + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + conv1d_channels = qk_dim + qk_dim + v_dim # q + k + v concatenated + kernel_size = self.network_config_.get("linear_conv_kernel_dim", 4) + + # Conv1d weight: after _preprocess_weight, shape is [channels, kernel_size]. + # ROWMMWeight row-slices out_dims (rows), matching TP split of channels dim. + # causal_conv1d_fn expects weight shape (dim, width) = (channels_per_tp, kernel_size). 
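+ # (Editor's illustration, hypothetical sizes: with qk_dim=4, v_dim=2 and tp=2,
+ # _parse_linear_conv1d reorders the channel rows to
+ #     [q0 q1 k0 k1 v0 | q2 q3 k2 k3 v1]
+ # so each rank's contiguous row slice already holds its own q/k/v channels
+ # in the order the conv kernel consumes them.)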
+ self.linear_conv1d = ROWMMWeight( + in_dim=kernel_size, + out_dims=[conv1d_channels], + weight_names=f"{prefix}.conv1d.weight", + data_type=self.data_type_, + quant_method=None, + ) + + # in_proj_qkvz: q(qk_dim) + k(qk_dim) + v(v_dim) + z(v_dim) + # in_proj_ba: beta(num_v_heads) + alpha(num_v_heads) — per-head scalars + qkvz_dim = qk_dim + qk_dim + v_dim + v_dim + ba_dim = self.linear_num_v_heads + self.linear_num_v_heads + self.linear_in_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[qkvz_dim, ba_dim], + weight_names=[f"{prefix}.in_proj_qkvz.weight", f"{prefix}.in_proj_ba.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("in_proj_weight"), + ) + + self.linear_out_proj = COLMMWeight( + in_dim=v_dim, + out_dims=[hidden_size], + weight_names=f"{prefix}.out_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("out_proj_weight"), + ) + + split_n_embed = self.linear_num_v_heads // self.tp_world_size_ + self.linear_dt_bias = TpParameterWeight( + weight_name=f"{prefix}.dt_bias", + data_type=torch.float32, + split_n_embed=split_n_embed, + bias_name=None, + weight_shape=(self.linear_num_v_heads,), # Full shape before TP split + bias_shape=None, + ) + + self.linear_A_log = TpParameterWeight( + weight_name=f"{prefix}.A_log", + data_type=torch.float32, + split_n_embed=split_n_embed, + bias_name=None, + weight_shape=(self.linear_num_v_heads,), # Full shape before TP split + bias_shape=None, + ) + + # Norm is applied per-head across head_dim, not across all heads + linear_norm_dim = self.linear_v_head_dim + self.linear_norm = RMSNormWeight( + dim=linear_norm_dim, + weight_name=f"{prefix}.norm.weight", + data_type=self.data_type_, + ) + + def load_hf_weights(self, weights): + self._preprocess_weight(weights) + return super().load_hf_weights(weights) + + def _preprocess_weight(self, weights): + linear_conv1d_weight_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.weight" + linear_conv1d_bias_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.bias" + if linear_conv1d_weight_name in weights: + # squeeze [channels, 1, kernel] -> [channels, kernel], then rearrange for TP + # Result shape: [channels, kernel_size] — matches causal_conv1d_fn's (dim, width) + weights[linear_conv1d_weight_name] = self._parse_linear_conv1d( + weights[linear_conv1d_weight_name].squeeze(1) + ) + if linear_conv1d_bias_name in weights: + weights[linear_conv1d_bias_name] = self._parse_linear_conv1d(weights[linear_conv1d_bias_name]) + self._rearrange_gdn_in_proj_weights(weights) + + def _rearrange_gdn_in_proj_weights(self, weights): + """Rearrange in_proj_qkvz and in_proj_ba weight rows from interleaved per-k-head layout + to TP-aware grouped layout so that after ROWMMWeight's row-slicing, each rank's + MM output is already [q_chunk, k_chunk, v_chunk, z_chunk, b_chunk, a_chunk]. + + This eliminates the expensive split+reshape+cat in _fix_query_key_value_ba_ordering + at inference time, replacing it with simple slicing. + + The key challenge is that ROWMMWeight slices each weight as a contiguous row chunk + (rows [start:end]). So we arrange the rows such that each TP chunk contains + the grouped layout for that rank: + 1. Deinterleave from per-k-head groups into per-component tensors + 2. Chunk each component by TP + 3. Reassemble as [q_tp0, k_tp0, v_tp0, z_tp0, q_tp1, k_tp1, ...] so row-slicing + gives each rank [q_chunk, k_chunk, v_chunk, z_chunk]. + Same pattern as _parse_linear_conv1d uses for conv1d weights. 
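+
+ Illustrative example (editor's sketch, hypothetical tiny sizes): with
+ num_k_heads=4, num_v_per_k=1, tp=2 and one row per component, the rows
+ arrive interleaved per k-head group as
+ [q0 k0 v0 z0 q1 k1 v1 z1 q2 k2 v2 z2 q3 k3 v3 z3]
+ and are rewritten to
+ [q0 q1 k0 k1 v0 v1 z0 z1 | q2 q3 k2 k3 v2 v3 z2 z3]
+ so each rank's contiguous row slice is already grouped as [q, k, v, z].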
+ """ + num_k = self.linear_num_k_heads + k_dim = self.linear_k_head_dim + v_dim = self.linear_v_head_dim + num_v_per_k = self.linear_num_v_heads // num_k + tp = self.tp_world_size_ + + # Rearrange in_proj_qkvz + qkvz_name = f"model.layers.{self.layer_num_}.linear_attn.in_proj_qkvz.weight" + if qkvz_name in weights: + w = weights[qkvz_name] + hidden = w.shape[-1] + # Each k-head group: q(k_dim) + k(k_dim) + v(num_v_per_k * v_dim) + z(num_v_per_k * v_dim) rows + group_size = k_dim + k_dim + num_v_per_k * v_dim + num_v_per_k * v_dim + w = w.view(num_k, group_size, hidden) + v_block = num_v_per_k * v_dim + all_q = w[:, :k_dim, :].reshape(-1, hidden) # [total_q_dim, H] + all_k = w[:, k_dim : 2 * k_dim, :].reshape(-1, hidden) # [total_k_dim, H] + all_v = w[:, 2 * k_dim : 2 * k_dim + v_block, :].reshape(-1, hidden) # [total_v_dim, H] + all_z = w[:, 2 * k_dim + v_block :, :].reshape(-1, hidden) # [total_v_dim, H] + # Chunk each component by TP, interleave so row-slicing gives grouped layout per rank + q_chunks = all_q.chunk(tp, dim=0) + k_chunks = all_k.chunk(tp, dim=0) + v_chunks = all_v.chunk(tp, dim=0) + z_chunks = all_z.chunk(tp, dim=0) + weights[qkvz_name] = torch.cat( + [torch.cat([q_chunks[i], k_chunks[i], v_chunks[i], z_chunks[i]], dim=0) for i in range(tp)], + dim=0, + ) + + # Rearrange in_proj_ba + ba_name = f"model.layers.{self.layer_num_}.linear_attn.in_proj_ba.weight" + if ba_name in weights: + w = weights[ba_name] + hidden = w.shape[-1] + group_size = 2 * num_v_per_k + w = w.view(num_k, group_size, hidden) + all_b = w[:, :num_v_per_k, :].reshape(-1, hidden) # [total_num_v, H] + all_a = w[:, num_v_per_k:, :].reshape(-1, hidden) # [total_num_v, H] + b_chunks = all_b.chunk(tp, dim=0) + a_chunks = all_a.chunk(tp, dim=0) + weights[ba_name] = torch.cat( + [torch.cat([b_chunks[i], a_chunks[i]], dim=0) for i in range(tp)], + dim=0, + ) + + def _parse_linear_conv1d(self, weight): + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + q_bias, k_bias, v_bias = torch.split(weight, [qk_dim, qk_dim, v_dim], dim=0) + q_splits = q_bias.chunk(self.tp_world_size_, dim=0) + k_splits = k_bias.chunk(self.tp_world_size_, dim=0) + v_splits = v_bias.chunk(self.tp_world_size_, dim=0) + new_weight = torch.cat( + [torch.cat([q_splits[i], k_splits[i], v_splits[i]], dim=0) for i in range(self.tp_world_size_)], dim=0 + ) + return new_weight + + def _init_gate_shared_expert_weight(self): + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + hidden_size = self.network_config_["hidden_size"] + shared_inter = self.network_config_["shared_expert_intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[shared_inter, shared_inter], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=shared_inter, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + self.shared_expert_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py new 
file mode 100644 index 0000000000..7ac7149a06 --- /dev/null +++ b/lightllm/models/qwen3next/mem_manager.py @@ -0,0 +1,72 @@ +import torch +from typing import Tuple +from lightllm.utils.log_utils import init_logger +from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager +from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager + +logger = init_logger(__name__) + + +class Qwen3NextHybridMemManager(MemoryManager): + def __init__( + self, + full_attn_cache_size, + linear_attn_cache_size, + dtype, + num_kv_heads, + head_dim, + layer_num, + mtp_layer_num, + full_attention_interval: int, + conv_state_dtype: torch.dtype, + conv_state_shape: Tuple[int, ...], + ssm_state_dtype: torch.dtype, + ssm_state_shape: Tuple[int, ...], + max_req_num: int, + always_copy=False, + mem_fraction=0.9, + ): + + self.full_attention_interval = full_attention_interval + assert layer_num % full_attention_interval == 0 + self.layer_num = layer_num + self.mtp_layer_num = mtp_layer_num + self.full_attn_layer_num = layer_num // full_attention_interval + self.linear_attn_layer_num = layer_num - self.full_attn_layer_num + + self.mamba_cache_mem_manager = MambaCacheManager( + linear_attn_cache_size, + self.linear_attn_layer_num, + conv_state_dtype, + conv_state_shape, + ssm_state_dtype, + ssm_state_shape, + ) + + super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction) + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + # KV buffer layout: [None, None, None, kv_cache, None, None, None, kv_cache, ..., + # None, kv_cache, mtp_kv_cache, mtp_kv_cache] + # Only full attention layers and MTP layers have KV cache. + self.kv_buffer = [None for _ in range(self.layer_num)] + for layer_id in range(self.full_attn_layer_num): + self.kv_buffer[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty( + (size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda" + ) + for _ in range(self.mtp_layer_num): + self.kv_buffer.append(torch.empty((size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda")) + + def free_all(self): + super().free_all() + self.mamba_cache_mem_manager.free_all() + return + + def get_cell_size(self): + # Only full attention layers and MTP layers have KV cache + kv_cache_layer_num = self.full_attn_layer_num + self.mtp_layer_num + return 2 * self.head_num * self.head_dim * kv_cache_layer_num * torch._utils._element_size(self.dtype) + + def get_mamba_cache(self, layer_idx: int): + layer_idx_in_linear = layer_idx - (layer_idx // self.full_attention_interval) + return self.mamba_cache_mem_manager.get_mamba_cache(layer_idx_in_linear) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py new file mode 100644 index 0000000000..1234a659ed --- /dev/null +++ b/lightllm/models/qwen3next/model.py @@ -0,0 +1,157 @@ +import torch +from typing import Optional +import triton +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextFullAttentionTransformerLayerWeight, + Qwen3NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( + Qwen3NextFullAttentionTransformerLayerInfer, + Qwen3NextGatedDeltaNetTransformerLayerInfer, +) +from lightllm.models.qwen3next.layer_infer.post_layer_infer import Qwen3NextPostLayerInfer +from lightllm.models.qwen3next.infer_struct import 
 Qwen3NextInferStateInfo
+from lightllm.utils.log_utils import init_logger
+from lightllm.distributed.communication_op import dist_group_manager
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager
+from lightllm.server.core.objs.start_args_type import StartArgs
+from lightllm.common.req_manager import ReqManagerForMamba
+from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
+
+logger = init_logger(__name__)
+
+
+@ModelRegistry("qwen3_next")
+class Qwen3NextTpPartModel(Qwen3MOEModel):
+
+ post_layer_infer_class = Qwen3NextPostLayerInfer
+ infer_state_class = Qwen3NextInferStateInfo
+
+ is_hybrid_attention = True # Indicates model uses hybrid (full + linear) attention
+ use_buffer_manager = True # Indicates model needs per-request buffer management for linear attention states
+
+ @classmethod
+ def get_radix_cache_class(cls):
+ from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache
+
+ return HybridRadixCache
+
+ def __init__(self, kvargs) -> None:
+ self.mem_manager: Qwen3NextHybridMemManager = None
+
+ def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch.Tensor:
+ return torch.empty(size, device="cuda", dtype=torch.int8)
+
+ # Set Triton allocator for TMA descriptors
+ # This is required for kernels in qwen3next/triton_kernel/fla/ops/solve_tril.py
+ triton.set_allocator(_triton_allocator)
+ logger.info("Triton allocator set for Qwen3Next model")
+ super().__init__(kvargs)
+
+ def autotune_layers(self):
+ return self.config["full_attention_interval"]
+
+ def _init_config(self):
+ super()._init_config()
+ self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
+
+ def _init_custom(self):
+ super()._init_custom()
+ dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])
+
+ def _init_mem_manager(self):
+ assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
+
+ start_args: StartArgs = get_env_start_args()
+ mamba_cache_size = start_args.mamba_cache_size
+ if mamba_cache_size is not None:
+ assert (
+ mamba_cache_size >= start_args.running_max_req_size
+ ), "mamba_cache_size must be at least running_max_req_size"
+
+ self.num_linear_k_heads = self.config["linear_num_key_heads"]
+ self.num_linear_v_heads = self.config["linear_num_value_heads"]
+ self.head_linear_k_dim = self.config["linear_key_head_dim"]
+ self.head_linear_v_dim = self.config["linear_value_head_dim"]
+
+ conv_kernel_size = self.config["linear_conv_kernel_dim"]
+ conv_dim = (
+ self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads
+ )
+
+ ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32}
+ if start_args.mamba_ssm_data_type not in ssm_dtype_dict:
+ raise ValueError(
+ f"Invalid mamba_ssm_data_type: {start_args.mamba_ssm_data_type}." 
+ f" Must be one of {list(ssm_dtype_dict.keys())}" + ) + + self.mem_manager = Qwen3NextHybridMemManager( + full_attn_cache_size=self.max_total_token_num, + linear_attn_cache_size=mamba_cache_size, + dtype=self.data_type, + num_kv_heads=self.num_kv_heads, + head_dim=self.config["head_dim"], + layer_num=self.config["n_layer"], + mtp_layer_num=start_args.mtp_step, + full_attention_interval=self.config["full_attention_interval"], + conv_state_dtype=self.data_type, + conv_state_shape=(conv_dim // self.tp_world_size_, conv_kernel_size - 1), + ssm_state_dtype=ssm_dtype_dict[start_args.mamba_ssm_data_type], + ssm_state_shape=( + self.num_linear_v_heads // self.tp_world_size_, + self.head_linear_k_dim, + self.head_linear_v_dim, + ), + max_req_num=self.max_req_num, + mem_fraction=self.mem_fraction, + ) + + def _init_req_manager(self): + create_max_seq_len = 0 + + if self.batch_max_tokens is not None: + create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens) + if self.max_seq_length is not None: + create_max_seq_len = max(create_max_seq_len, self.max_seq_length) + + self.req_manager = ReqManagerForMamba(self.max_req_num, create_max_seq_len, self.mem_manager) + + def _init_weights(self): + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) + num_full_attention_layers = self.config["full_attention_interval"] + self.trans_layers_weight = [ + ( + Qwen3NextFullAttentionTransformerLayerWeight( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + if (i + 1) % num_full_attention_layers == 0 + else Qwen3NextGatedDeltaNetTransformerLayerWeight( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + ) + for i in range(self.config["n_layer"]) + ] + + def _init_infer_layer(self): + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + num_full_attention_layers = self.config["full_attention_interval"] + + self.layers_infer = [ + ( + Qwen3NextFullAttentionTransformerLayerInfer(i, network_config=self.config) + if (i + 1) % num_full_attention_layers == 0 + else Qwen3NextGatedDeltaNetTransformerLayerInfer(i, network_config=self.config) + ) + for i in range(self.config["n_layer"]) + ] diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py new file mode 100644 index 0000000000..c6d099a2d8 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py @@ -0,0 +1,122 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/causal_conv1d.py + +from typing import Optional + +import torch + +from sgl_kernel import causal_conv1d_fwd +from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = -1, + **kwargs, +): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. 
prepended by 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + causal_conv1d_fwd( + x, + weight, + bias, + conv_states, + query_start_loc, + cache_indices, + has_initial_state, + activation in ["silu", "swish"], + pad_slot_id, + ) + return x + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = -1, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError(f"activation must be None, silu, or swish, actual: {activation}") + activation_val = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + causal_conv1d_update_kernel( + x, + conv_state, + weight, + bias, + activation_val, + cache_seqlens, + conv_state_indices, + pad_slot_id, + ) + if unsqueeze: + x = x.squeeze(-1) + return x diff --git a/lightllm/models/qwen3next/triton_kernel/fla/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py new file mode 100644 index 0000000000..2bde70bb99 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# Adapted from +# https://github.com/vllm-project/vllm diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py new file mode 100644 index 0000000000..cd3b0962a3 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from .chunk import chunk_gated_delta_rule +from .fused_recurrent import fused_recurrent_gated_delta_rule + +__all__ = [ + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", +] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py new file mode 100644 index 0000000000..7b3067bbfb --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch +from einops import rearrange + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .l2norm import l2norm_fwd +from .solve_tril import solve_tril +from .utils import SUPPRESS_LEVEL, input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: torch.LongTensor | None = None, +): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. 
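+ # (Editor's note, informal: within each 64-token chunk the sequential product of
+ # rank-1 delta updates, prod_t (I - beta_t k_t k_t^T), is collapsed into a WY-style
+ # factorization: A below collects the beta-scaled, gate-weighted intra-chunk
+ # k·k^T terms, solve_tril inverts the resulting unit lower-triangular system,
+ # and recompute_w_u_fwd turns that into w and the pseudo-values u, so the chunk
+ # is applied with a few dense matmuls instead of a token-by-token scan.)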
+ A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, chunk_size=64, output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=64, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=64, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @input_guard + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: torch.LongTensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + ): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: torch.LongTensor | None = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
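+
+ Note:
+ As an editor's reference sketch (not this kernel's actual code path), each
+ head carries a state `S` of shape `[K, V]`, updated per token as
+ S = exp(g_t) * S
+ S = S + beta_t * outer(k_t, v_t - S.T @ k_t)
+ o_t = S.T @ (scale * q_t)
+ with `g` in log space; the chunked kernel computes the same recurrence
+ blockwise, up to numerics.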
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
+        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
+        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
+        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
+        >>> o, ht = chunk_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = chunk_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True,
+            cu_seqlens=cu_seqlens
+        )
+    """
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
+    assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
+
+    if head_first:
+        raise DeprecationWarning(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False instead."
+        )
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+            )
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    o, final_state = ChunkGatedDeltaRuleFunction.apply(
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        output_final_state,
+        cu_seqlens,
+        use_qk_l2norm_in_kernel,
+    )
+    return o, final_state
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py
new file mode 100644
index 0000000000..97933b2ac2
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py
@@ -0,0 +1,324 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices, prepare_chunk_offsets +from .op import exp, safe_exp +from lightllm.common.triton_utils.autotuner import autotune + +NUM_WARPS = [2, 4, 8, 16] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += ((boh * H + i_h) * K * V).to(tl.int64) + v += ((bos * H + i_h) * V).to(tl.int64) + k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + if SAVE_NEW_VALUE: + v_new += ((bos * H + i_h) * V).to(tl.int64) + stride_v = H * V + stride_h = H * K * V + stride_k = Hg * K + stride_w = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, 
b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v + + if SAVE_NEW_VALUE: + p_v = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_v = b_v * safe_exp(b_g_last - b_g)[:, None] + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k1, + mask=(o_k1 < K), + other=0.0, + ) + b_h1 *= exp(b_gk_last1)[:, None] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k2, + mask=(o_k2 < K), + other=0.0, + ) + b_h2 *= exp(b_gk_last2)[:, None] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k3, + mask=(o_k3 < K), + other=0.0, + ) + b_h3 *= exp(b_gk_last3)[:, None] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k4, + mask=(o_k4 < K), + other=0.0, + ) + b_h4 *= exp(b_gk_last4)[:, None] + b_v = b_v.to(k.dtype.element_ty) + + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v) + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = 
tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_delta_h_configs(): + return [ + {"BV": BV, "num_warps": num_warps, "num_stages": num_stages} + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ] + + +def _get_chunk_delta_h_static_key(k, u, chunk_size): + B, T, Hg, K = k.shape + V = u.shape[-1] + H = u.shape[-2] + return {"H": H, "K": K, "V": V, "BT": chunk_size} + + +def _get_chunk_delta_h_run_key(k, u): + # Return batch * heads as run key + return k.shape[0] * k.shape[2] + + +@autotune( + kernel_name="chunk_gated_delta_rule_fwd_h", + configs_gen_func=_get_chunk_delta_h_configs, + static_key_func=_get_chunk_delta_h_static_key, + run_key_func=_get_chunk_delta_h_run_key, +) +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: torch.LongTensor | None = None, + run_config=None, +) -> tuple[torch.Tensor, torch.Tensor]: + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, + len(chunk_indices), + prepare_chunk_offsets(cu_seqlens, BT), + ) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + # Extract config parameters + if run_config is None: + run_config = {"BV": 64, "num_warps": 2, "num_stages": 2} + + BV = run_config.get("BV", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + grid = (triton.cdiv(V, BV), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + gk=gk, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BV=BV, + num_warps=num_warps, + num_stages=num_stages, + ) + return h, v_new, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py new file mode 100644 index 0000000000..fc49763ecd --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import exp, safe_exp +from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper +from lightllm.common.triton_utils.autotuner import autotune + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K * V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_o_configs(): + return [ + {"BK": BK, "BV": BV, "num_warps": num_warps, "num_stages": num_stages} + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ] + + +def _get_chunk_o_static_key(q, v, chunk_size): + B, T, Hg, K = q.shape + V = 
v.shape[-1] + H = v.shape[-2] + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) + return {"H": H, "K": K, "V": V, "BT": BT} + + +def _get_chunk_o_run_key(q, v): + # Return batch * heads as run key + return q.shape[0] * q.shape[2] + + +@autotune( + kernel_name="chunk_fwd_o", + configs_gen_func=_get_chunk_o_configs, + static_key_func=_get_chunk_o_static_key, + run_key_func=_get_chunk_o_run_key, +) +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor | None = None, # cumsum of log decay + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + run_config=None, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = torch.empty_like(v) + + # Extract config parameters + if run_config is None: + run_config = {"BK": 64, "BV": 64, "num_warps": 2, "num_stages": 2} + + BK = run_config.get("BK", 64) + BV = run_config.get("BV", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + grid = (triton.cdiv(V, BV), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + num_warps=num_warps, + num_stages=num_stages, + ) + return o diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py new file mode 100644 index 0000000000..60a594c078 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import safe_exp +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A = b_A * safe_exp(b_g_diff) + + b_A *= b_beta[:, None] + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_scaled_dot_kkt_configs(): + return [ + {"BK": BK, "num_warps": num_warps, "num_stages": num_stages} + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ] + + +def _get_chunk_scaled_dot_kkt_static_key(k, beta, chunk_size=64, cu_seqlens=None): + B, T, Hg, K = k.shape + H = beta.shape[-1] + IS_VARLEN = cu_seqlens is not None + return {"H": H, "K": K, "BT": chunk_size, "IS_VARLEN": IS_VARLEN} + + +def _get_chunk_scaled_dot_kkt_run_key(k, beta): + # Return batch * heads as run key + return k.shape[0] * k.shape[2] + + +@autotune( + kernel_name="chunk_scaled_dot_kkt_fwd", + configs_gen_func=_get_chunk_scaled_dot_kkt_configs, + static_key_func=_get_chunk_scaled_dot_kkt_static_key, + run_key_func=_get_chunk_scaled_dot_kkt_run_key, +) +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + g: torch.Tensor | None = None, + beta: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, + run_config=None, +) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. 
+ cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. + B, T, Hg, K = k.shape + H = beta.shape[-1] + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + # Extract config parameters + if run_config is None: + run_config = {"BK": 64, "num_warps": 2, "num_stages": 2} + + BK = run_config.get("BK", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + g=g, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + BK=BK, + num_warps=num_warps, + num_stages=num_stages, + ) + return A diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py new file mode 100644 index 0000000000..6331e1602d --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .utils import check_shared_mem, input_guard +from lightllm.common.triton_utils.autotuner import autotune + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) + + if HEAD_FIRST: + p_s = tl.make_block_ptr( + s + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + else: + p_s = tl.make_block_ptr( + s + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_cumsum_scalar_configs(): + return [{"num_warps": num_warps} for 
num_warps in [1, 2, 4, 8]] + + +def _get_cumsum_scalar_static_key(g, chunk_size, reverse, cu_seqlens, head_first): + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + IS_VARLEN = cu_seqlens is not None + return {"B": B, "H": H, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} + + +def _get_cumsum_scalar_run_key(g): + # Return total number of elements as run key + return g.shape[0] * g.shape[1] + + +@autotune( + kernel_name="chunk_local_cumsum_scalar", + configs_gen_func=_get_cumsum_scalar_configs, + static_key_func=_get_cumsum_scalar_static_key, + run_key_func=_get_cumsum_scalar_run_key, +) +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + run_config=None, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + # Extract config parameters + if run_config is None: + run_config = {"num_warps": 2} + + num_warps = run_config.get("num_warps", 2) + + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + num_warps=num_warps, + ) + return g + + +def _get_cumsum_vector_configs(): + return [{"BS": BS, "num_warps": num_warps} for BS in BS_LIST for num_warps in [2, 4, 8]] + + +def _get_cumsum_vector_static_key(g, chunk_size, reverse, cu_seqlens, head_first): + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + IS_VARLEN = cu_seqlens is not None + return {"B": B, "H": H, "S": S, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} + + +def _get_cumsum_vector_run_key(g): + # Return batch * heads as run key + return g.shape[0] * g.shape[2] if len(g.shape) == 4 else g.shape[0] + + +@autotune( + kernel_name="chunk_local_cumsum_vector", + configs_gen_func=_get_cumsum_vector_configs, + static_key_func=_get_cumsum_vector_static_key, + run_key_func=_get_cumsum_vector_run_key, +) +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + run_config=None, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + # Extract config parameters + if run_config is None: + run_config = {"BS": 32, "num_warps": 2} + + BS = run_config.get("BS", 32) + num_warps = run_config.get("num_warps", 2) + + grid = (triton.cdiv(S, BS), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + 
H=H, + S=S, + BT=BT, + BS=BS, + HEAD_FIRST=head_first, + REVERSE=reverse, + num_warps=num_warps, + ) + return g + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + **kwargs, +) -> torch.Tensor: + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + else: + raise ValueError( + f"Unsupported input shape {g.shape}. " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise" + ) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py new file mode 100644 index 0000000000..22a93a2c99 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -0,0 +1,492 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .op import exp + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + "HAS_SEPARATE_WRITE_INDICES": lambda args: args["ssm_state_write_indices"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + ssm_state_write_indices, # NEW: separate write indices for state propagation optimization + num_accepted_tokens, + # Fused gating parameters (only used when FUSE_GATING=True) + A_log, # [HV] per-head log decay + dt_bias, # [HV] per-head dt bias + a_raw, # [B*T, HV] raw alpha values (before softplus) + b_raw, # [B*T, HV] raw beta values (before sigmoid) + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + stride_write_indices_seq: tl.constexpr, # NEW: stride for write indices + stride_write_indices_tok: tl.constexpr, # NEW: stride for write indices + SOFTPLUS_BETA: tl.constexpr, # softplus beta parameter (default 1.0) + SOFTPLUS_THRESHOLD: tl.constexpr, # softplus threshold (default 20.0) + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + 
IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + IS_KDA: tl.constexpr, + HAS_SEPARATE_WRITE_INDICES: tl.constexpr, # NEW: whether to use separate write indices + FUSE_GATING: tl.constexpr, # whether to compute g/beta inline from raw values +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if FUSE_GATING: + # Fused gating: load per-head constants once, compute g/beta inline per token + b_A_log = tl.load(A_log + i_hv).to(tl.float32) + b_dt_bias = tl.load(dt_bias + i_hv).to(tl.float32) + p_a_raw = a_raw + bos * HV + i_hv + p_b_raw = b_raw + bos * HV + i_hv + else: + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + + if not IS_KDA: + p_g = g + bos * HV + i_hv + else: + p_gk = g + (bos * HV + i_hv) * K + o_k + + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = ( + h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token + ) + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + if FUSE_GATING: + # Compute g = -exp(A_log) * softplus(a_raw + dt_bias) inline + b_a = tl.load(p_a_raw).to(tl.float32) + x = b_a + b_dt_bias + softplus_x = tl.where( + SOFTPLUS_BETA * x <= SOFTPLUS_THRESHOLD, + (1.0 / SOFTPLUS_BETA) * tl.log(1.0 + tl.exp(SOFTPLUS_BETA * x)), + x, + ) + b_g = -tl.exp(b_A_log) * softplus_x + b_h *= exp(b_g) + # Compute beta = sigmoid(b_raw) inline + b_b = tl.load(p_b_raw).to(tl.float32) + b_beta = tl.sigmoid(b_b) + else: + if not IS_KDA: + b_g = tl.load(p_g).to(tl.float32) + b_h *= exp(b_g) + else: + b_gk = tl.load(p_gk).to(tl.float32) + b_h *= exp(b_gk[:, None]) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + # Use separate write indices if provided (for state propagation optimization) + # Otherwise fall back to read indices + if HAS_SEPARATE_WRITE_INDICES: 
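+                # writing through a separate index decouples the destination slot from
+                # the slot the initial state was read from (state propagation optimization)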
+ write_idx = tl.load(ssm_state_write_indices + i_n * stride_write_indices_seq + i_t).to(tl.int64) + else: + write_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) + p_ht = ht + write_idx * stride_final_state_token + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + if FUSE_GATING: + p_a_raw += HV + p_b_raw += HV + else: + if not IS_KDA: + p_g += HV + else: + p_gk += HV * K + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + ssm_state_write_indices: torch.Tensor | None = None, # NEW: separate write indices + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + # Fused gating parameters + A_log: torch.Tensor | None = None, + dt_bias: torch.Tensor | None = None, + a_raw: torch.Tensor | None = None, + b_raw: torch.Tensor | None = None, + out: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK = triton.next_power_of_2(K) + if T == 1: + # Decode path: use larger BV to reduce kernel instances (4 blocks instead of 16) + # and more warps for better SM utilization at T=1 where there's no pipelining benefit + BV = min(triton.next_power_of_2(V), 32) + num_warps = 4 + num_stages = 1 + else: + # Prefill path: small BV for better pipelining across sequence length + BV = min(triton.next_power_of_2(V), 8) + num_warps = 1 + num_stages = 3 + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + + fuse_gating = A_log is not None + + if out is not None: + o = out.unsqueeze(0) if out.ndim == v.ndim else out + else: + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + # Strides for read indices + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + # Strides for write indices (if provided) + if ssm_state_write_indices is None: + stride_write_indices_seq, stride_write_indices_tok = 1, 1 + elif ssm_state_write_indices.ndim == 1: + stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1 + else: + stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + ssm_state_write_indices=ssm_state_write_indices, + num_accepted_tokens=num_accepted_tokens, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw, + b_raw=b_raw, + scale=scale, + N=N, + T=T, + B=B, 
+ H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + stride_write_indices_seq=stride_write_indices_seq, + stride_write_indices_tok=stride_write_indices_tok, + SOFTPLUS_BETA=1.0, + SOFTPLUS_THRESHOLD=20.0, + IS_BETA_HEADWISE=False if fuse_gating else (beta.ndim == v.ndim), + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + IS_KDA=False, + FUSE_GATING=fuse_gating, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + ssm_state_write_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + A_log: torch.Tensor | None = None, + dt_bias: torch.Tensor | None = None, + a_raw: torch.Tensor | None = None, + b_raw: torch.Tensor | None = None, + out: torch.Tensor | None = None, + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous() if g is not None else None, + beta=beta.contiguous() if beta is not None else None, + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + ssm_state_write_indices=ssm_state_write_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw.contiguous() if a_raw is not None else None, + b_raw=b_raw.contiguous() if b_raw is not None else None, + out=out, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor = None, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + ssm_state_write_indices: torch.Tensor | None = None, # NEW: separate write indices for state propagation + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + # Fused gating: pass raw values to compute g/beta inline in the kernel + A_log: torch.Tensor | None = None, + dt_bias: torch.Tensor | None = None, + a_raw: torch.Tensor | None = None, + b_raw: torch.Tensor | None = None, + out: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. 
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        inplace_final_state (bool):
+            Whether to store the final state in-place to save memory.
+            Default: `True`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        ssm_state_indices (Optional[torch.Tensor]):
+            Indices to map the input sequences to the initial/final states.
+        ssm_state_write_indices (Optional[torch.Tensor]):
+            Separate indices for writing final states; when not provided, the read
+            indices in `ssm_state_indices` are reused.
+        num_accepted_tokens (Optional[torch.Tensor]):
+            Number of accepted tokens for each sequence during decoding.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, HV, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, HV, K, V]`.
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
+        >>> q = torch.randn(B, T, H, K, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, HV, V, device='cuda')
+        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
+        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
+        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
+        >>> o, ht = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            cu_seqlens=cu_seqlens
+        )
+    """
+    if cu_seqlens is not None and q.shape[0] != 1:
+        raise ValueError(
+            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+            f"Please flatten variable-length inputs before processing."
+        )
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    else:
+        assert scale > 0, "scale must be positive"
+    fuse_gating = A_log is not None
+    if not fuse_gating and beta is None:
+        beta = torch.ones_like(q[..., 0])
+    o, final_state = FusedRecurrentFunction.apply(
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        inplace_final_state,
+        cu_seqlens,
+        ssm_state_indices,
+        ssm_state_write_indices,
+        num_accepted_tokens,
+        use_qk_l2norm_in_kernel,
+        A_log,
+        dt_bias,
+        a_raw,
+        b_raw,
+        out,
+    )
+    return o, final_state
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py
new file mode 100644
index 0000000000..8b1d59fc63
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py
@@ -0,0 +1,30 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +import triton + +from .utils import tensor_cache + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +@tensor_cache +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py new file mode 100644 index 0000000000..29f892ef26 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os + +import torch + +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + +BT_LIST = [8, 16, 32, 64, 128] + +USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) + + +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # tl.store(Rstd + i_t, rstd) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +@triton.jit(do_not_specialize=["NB"]) +def l2norm_fwd_kernel( + x, + y, + eps, + NB, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=1) + b_y = b_x / tl.sqrt(b_var + eps)[:, None] + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * MBLOCK + row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] + xmask = row_idx < M + rindex = tl.arange(0, N)[None, :] + xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) + square = tl.broadcast_to(xs * xs, [MBLOCK, N]) + square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] + rsqrt = tl.rsqrt(square_sum + eps) + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + + +def _get_l2norm_kernel1_configs(): + return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16, 32]] + + +def _get_l2norm_kernel1_static_key(x): + D = x.shape[-1] + return 
{"D": D} + + +def _get_l2norm_kernel1_run_key(x): + return x.shape[0] # T + + +@autotune( + kernel_name="l2norm_fwd_kernel1", + configs_gen_func=_get_l2norm_kernel1_configs, + static_key_func=_get_l2norm_kernel1_static_key, + run_key_func=_get_l2norm_kernel1_run_key, +) +def _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD, run_config=None): + if run_config is None: + run_config = {"num_warps": 4} + + num_warps = run_config.get("num_warps", 4) + T = x.shape[0] + + l2norm_fwd_kernel1[(T,)](x, y, eps=eps, D=D, BD=BD, num_warps=num_warps) + + +def _get_l2norm_kernel_configs(): + return [{"BT": BT, "num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST] + + +def _get_l2norm_kernel_static_key(x): + D = x.shape[-1] + return {"D": D} + + +def _get_l2norm_kernel_run_key(x): + return x.shape[0] # T + + +@autotune( + kernel_name="l2norm_fwd_kernel", + configs_gen_func=_get_l2norm_kernel_configs, + static_key_func=_get_l2norm_kernel_static_key, + run_key_func=_get_l2norm_kernel_run_key, +) +def _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB, run_config=None): + if run_config is None: + run_config = {"BT": 32, "num_warps": 4} + + BT = run_config.get("BT", 32) + num_warps = run_config.get("num_warps", 4) + + grid = (triton.cdiv(T, BT),) + l2norm_fwd_kernel[grid](x, y, eps, NB=NB, T=T, D=D, BT=BT, BD=BD, num_warps=num_warps) + + +def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + if not USE_DEFAULT_FLA_NORM: + MBLOCK = 32 + # M, N = x.shape + l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)]( + x, + y, + eps, + T, + D, + MBLOCK, + ) + else: + if D <= 512: + NB = triton.cdiv(T, 2048) + _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB) + else: + _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD) + + return y.view(x_shape_og) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py new file mode 100644 index 0000000000..2f69aa981d --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import triton +import triton.language as tl + +from .utils import is_gather_supported + +exp = tl.exp +log = tl.log +log2 = tl.log2 + + +@triton.jit +def safe_exp(x): + """ + Numerically stable exponential function. + Only applies exp to non-positive values, returns 0 for positive values. + This prevents numerical overflow and improves stability. 
+ """ + return exp(tl.where(x <= 0, x, float("-inf"))) + + +if not is_gather_supported: + + @triton.jit + def gather(src, index, axis, _builder=None): + """ + Gather operation that works when tl.gather is not supported. + This is a fallback implementation that returns None. + Just to make triton compiler happy. + """ + return None + +else: + gather = tl.gather + +if hasattr(triton.language, "_experimental_make_tensor_descriptor"): + # For Triton 3.3.x + make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor +elif hasattr(triton.language, "make_tensor_descriptor"): + # For Triton 3.4.x and later + make_tensor_descriptor = triton.language.make_tensor_descriptor +else: + """ + Fallback implementation when TMA is not supported. + Returns None to indicate TMA descriptors are unavailable. + Just make triton compiler happy. + """ + + @triton.jit + def make_tensor_descriptor( + base, + shape, + strides, + block_shape, + _builder=None, + ): + return None diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py new file mode 100644 index 0000000000..b5b6cfc369 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py @@ -0,0 +1,462 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import os +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import make_tensor_descriptor +from .utils import input_guard, is_amd, is_tma_supported + + +def _ensure_triton_allocator(): + """Ensure Triton has an allocator set for kernels requiring scratch memory.""" + + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + +FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee") +ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32"] if is_amd else ["ieee", "tf32", "tf32x3"] +assert ( + FLA_TRIL_PRECISION in ALLOWED_TRIL_PRECISIONS +), f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}" + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_16x16_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + A = A + (bos * H + i_h) * BT + Ai = Ai + (bos * H + i_h) * 16 + + offset = (i_t * 16) % BT + if not USE_TMA: + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) + # [16, 16] + 
b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16]) + b_A = desc.load([i_t * 16, offset]).to(tl.float32) + b_A = -tl.where(m_A, b_A, 0) + + for i in range(2, min(16, T - i_t * 16)): + # [16] + b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + b_A = tl.where((o_i == i)[:, None], b_a, b_A) + b_A += m_I + if not USE_TMA: + p_Ai = tl.make_block_ptr(Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + tl.store( + p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + + b_Ai_11 += m_I + b_Ai_22 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, + ) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + tl.store( + p_Ai_11, + 
b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_A_33 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_A_44 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32) + b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32) + b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + b_Ai_33 = -tl.where(m_A, b_Ai_33, 0) + b_Ai_44 = -tl.where(m_A, b_Ai_44, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + for i in range(32 + 2, min(48, T - i_t * BT)): + b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32) + b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0) + b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33) + for i in range(48 + 2, min(64, T - i_t * BT)): + b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48) + b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0) + b_Ai_44 = tl.where((o_i == i - 48)[:, None], 
b_a_44, b_Ai_44) + b_Ai_11 += m_I + b_Ai_22 += m_I + b_Ai_33 += m_I + b_Ai_44 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32) + b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32) + b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32) + b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32) + b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32) + + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, + ) + b_Ai_32 = -tl.dot( + tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION), + b_Ai_22, + input_precision=DOT_PRECISION, + ) + b_Ai_43 = -tl.dot( + tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION), + b_Ai_33, + input_precision=DOT_PRECISION, + ) + + b_Ai_31 = -tl.dot( + b_Ai_33, + tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION) + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_42 = -tl.dot( + b_Ai_44, + tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION) + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_41 = -tl.dot( + b_Ai_44, + tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION) + + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION) + + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + tl.store( + p_Ai_11, + b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, 
fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_33, + b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_44, + b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_31, + b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_32, + b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_41, + b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_42, + b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_43, + b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@input_guard +def solve_tril( + A: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + """ + Compute the inverse of the matrix I + A + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, BT], where BT should only be 16, 32, or 64. + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. Default: `None`. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float`. + If `None`, the output dtype will be the same as the input dtype. 
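+
+    Example (illustrative sketch, assuming a CUDA device and inputs shaped as above):
+
+        # A: [B, T, H, BT] holding strictly lower-triangular BT x BT chunk blocks
+        Ai = solve_tril(A, cu_seqlens=None, output_dtype=torch.float)  # same shape as A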
+
+    Returns:
+        (I + A)^-1 with the same shape as A
+    """
+    assert A.shape[-1] in [16, 32, 64]
+    output_dtype = A.dtype if output_dtype is None else output_dtype
+
+    B, T, H, BT = A.shape
+    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
+
+    Ai = torch.zeros_like(A, dtype=output_dtype)
+    if BT == 16:
+        merge_fn = solve_tril_16x16_kernel
+    elif BT == 32:
+        merge_fn = merge_16x16_to_32x32_inverse_kernel
+    elif BT == 64:
+        merge_fn = merge_16x16_to_64x64_inverse_kernel
+
+    # Ensure Triton allocator is set for TMA kernels that require scratch memory
+    if is_tma_supported:
+        _ensure_triton_allocator()
+
+    merge_fn[NT, B * H](
+        A=A,
+        Ai=Ai,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        T=T,
+        H=H,
+        BT=BT,
+        USE_TMA=is_tma_supported,
+        DOT_PRECISION=FLA_TRIL_PRECISION,
+    )
+    return Ai
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py
new file mode 100644
index 0000000000..cd7c2e3aeb
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py
@@ -0,0 +1,179 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+# ruff: noqa: E501
+import contextlib
+import functools
+import logging
+import os
+from collections.abc import Callable
+from enum import Enum
+from typing import Any, Literal
+
+import torch
+
+import triton
+
+logger = logging.getLogger(__name__)
+
+COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
+FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
+FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"
+
+SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
+
+
+def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+    """
+    A decorator that caches the most recent results of a function with tensor inputs.
+
+    This decorator stores the outputs of the decorated function for recently seen sets of input tensors.
+    The cache holds a fixed number of entries (8 here); when it is full, the oldest entry is evicted.
+
+    Args:
+        fn (Callable[..., torch.Tensor]):
+            The function to be decorated. It should take tensor inputs and return tensor outputs.
+
+    Returns:
+        Callable[..., torch.Tensor]:
+            A wrapped version of the input function with bounded LRU-style caching.
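+
+    Example (illustrative; note that hits require the *same tensor objects*, since
+    matching is by identity, not by value):
+
+        >>> @tensor_cache
+        ... def cumsum_f32(x: torch.Tensor) -> torch.Tensor:
+        ...     return x.float().cumsum(0)
+        >>> y1 = cumsum_f32(t)  # computed
+        >>> y2 = cumsum_f32(t)  # served from cache: `t` is the same object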
+ """ + + cache_entries: tuple[tuple | None, dict | None, Any] = [] + cache_size = 8 + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal cache_entries + for i, entry in enumerate(cache_entries): + last_args, last_kwargs, last_result = entry + if ( + len(args) == len(last_args) + and len(kwargs) == len(last_kwargs) + and all(a is b for a, b in zip(args, last_args)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + cache_entries = cache_entries[:i] + cache_entries[i + 1 :] + [(args, kwargs, last_result)] + return last_result + + result = fn(*args, **kwargs) + + if len(cache_entries) >= cache_size: + cache_entries = cache_entries[1:] + cache_entries.append((args, kwargs, result)) + return result + + return wrapper + + +def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = torch.cuda.device(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + return "cpu" + + +@functools.cache +def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: + device = get_available_device() + mapping = { + "cuda": "nvidia", + "hip": "amd", + "xpu": "intel", + } + # return the mapped value, or the original if not found + return mapping.get(device, device) + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. +# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. 
+device = "cuda" +device_torch_lib = getattr(torch, device, None) +device_platform = _check_platform() + +is_amd = device_platform == "amd" +is_intel = device_platform == "intel" +is_nvidia = device_platform == "nvidia" +is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0) +is_nvidia_hopper = is_nvidia and ( + "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9 +) +use_cuda_graph = True +is_gather_supported = hasattr(triton.language, "gather") +is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and ( + hasattr(triton.language, "_experimental_make_tensor_descriptor") + or hasattr(triton.language, "make_tensor_descriptor") +) + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] + for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py new file mode 100644 index 0000000000..08bb00e644 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + p_u = tl.make_block_ptr( + u + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_w = tl.make_block_ptr( + w + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: torch.LongTensor | None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u diff --git a/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py new file mode 100644 index 0000000000..6413158a66 --- /dev/null +++ 
b/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py @@ -0,0 +1,186 @@ +import torch + +import triton +import triton.language as tl + +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _fused_add_gemma_rmsnorm_kernel( + x_ptr, + r_ptr, + w_ptr, + y_ptr, + x_stride0, + x_stride1, + r_stride0, + r_stride1, + y_stride0, + y_stride1, + N: tl.constexpr, + EPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Fused in-place residual add + Gemma RMSNorm. + + For each row: + 1. sum = x + residual (written back to x in-place) + 2. rstd = 1 / sqrt(mean(sum²) + eps) + 3. y = sum * rstd * (w + 1.0) (Gemma-style) + """ + row = tl.program_id(0) + x_ptr = x_ptr + row * x_stride0 + r_ptr = r_ptr + row * r_stride0 + y_ptr = y_ptr + row * y_stride0 + + # Pass 1: compute sum = x + residual, write back to x, accumulate sum² for variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + r = tl.load(r_ptr + cols * r_stride1, mask=mask, other=0.0).to(tl.float32) + s = x + r + # Write sum back to x (in-place residual add) + tl.store(x_ptr + cols * x_stride1, s.to(x_ptr.dtype.element_ty), mask=mask) + _var += s * s + + var = tl.sum(_var, axis=0) / N + rstd = 1.0 / tl.sqrt(var + EPS) + + # Pass 2: normalize and apply Gemma-style linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + # Re-read x (now contains sum); hot in L2 from the write in pass 1 + s = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + w = tl.load(w_ptr + cols, mask=mask).to(tl.float32) + y = s * rstd * (w + 1.0) + tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask) + + +def _get_fused_add_gemma_rmsnorm_configs(): + """Generate configurations for autotuning fused add + Gemma RMSNorm kernel.""" + configs = [] + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]: + for num_warps in [1, 2, 4, 8]: + configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": 1}) + return configs + + +def _get_fused_add_gemma_rmsnorm_static_key(x: torch.Tensor, w: torch.Tensor): + """Generate static key for caching autotuned configurations.""" + N = x.shape[-1] + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(w.dtype), + "N": N, + } + + +@autotune( + kernel_name="fused_add_gemma_rmsnorm:v1", + configs_gen_func=_get_fused_add_gemma_rmsnorm_configs, + static_key_func=_get_fused_add_gemma_rmsnorm_static_key, + run_key_func=lambda x: x.shape[-1], + mutates_args=["x"], +) +def fused_add_gemma_rmsnorm(x, residual, w, eps, out=None, run_config: dict = None): + """Fused in-place residual add + Gemma RMSNorm. 
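+
+    Per row, matching the kernel above:
+
+        s = x + residual                             # written back into x in-place
+        y = s * rsqrt(mean(s**2) + eps) * (w + 1.0)  # Gemma-style (w + 1.0) scale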
+ + x: [M, N] - modified in-place (x += residual) + residual: [M, N] - residual to add (will be viewed as [-1, N]) + w: [N] - norm weight (Gemma-style: applies w + 1.0) + eps: float + out: [M, N] - output buffer (allocated if None) + Returns: out + """ + N = x.shape[-1] + y = torch.empty_like(x) if out is None else out + x_arg = x.view(-1, N) + r_arg = residual.view(-1, N) + y_arg = y.view(-1, N) + + M = x_arg.shape[0] + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This fused_add_gemma_rmsnorm doesn't support feature dim >= 64KB.") + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1} + + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + _fused_add_gemma_rmsnorm_kernel[(M,)]( + x_arg, + r_arg, + w, + y_arg, + x_stride0=x_arg.stride(0), + x_stride1=x_arg.stride(1), + r_stride0=r_arg.stride(0), + r_stride1=r_arg.stride(1), + y_stride0=y_arg.stride(0), + y_stride1=y_arg.stride(1), + N=N, + EPS=eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + return y + + +def _fused_add_gemma_rmsnorm_torch(x, residual, weight, eps): + """Reference implementation for correctness testing.""" + original_dtype = x.dtype + x = x.to(torch.float32) + residual = residual.to(torch.float32) + s = x + residual + normed = s * torch.rsqrt(s.pow(2).mean(-1, keepdim=True) + eps) + out = normed * (1.0 + weight.float()) + return s.to(original_dtype), out.to(original_dtype) + + +def test_fused_add_gemma_rmsnorm(M=128, N=2048, dtype=torch.bfloat16, eps=1e-5, device="cuda"): + """Verify fused kernel matches separate add + gemma_rmsnorm.""" + x_shape = (M, N) + w_shape = (N,) + weight = torch.rand(w_shape, dtype=dtype, device=device) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + residual = 0.1 * torch.randn(x_shape, dtype=dtype, device=device) + + # Clone x for reference (since fused modifies x in-place) + x_ref = x.clone() + x_fused = x.clone() + + # Reference: separate add + norm + x_ref_sum, y_ref = _fused_add_gemma_rmsnorm_torch(x_ref, residual, weight, eps) + + # Fused kernel + y_fused = fused_add_gemma_rmsnorm(x_fused, residual, weight, eps) + + # Check x was modified in-place (x += residual) + print(f"Test: M={M}, N={N}, dtype={dtype}") + print(f" x in-place max delta: {torch.max(torch.abs(x_fused - x_ref_sum)):.6e}") + print(f" output max delta: {torch.max(torch.abs(y_fused - y_ref)):.6e}") + + atol = 1e-2 if dtype == torch.float32 else 5e-2 + assert torch.allclose(x_fused, x_ref_sum, atol=atol, rtol=0), "x in-place update mismatch!" + assert torch.allclose(y_fused, y_ref, atol=atol, rtol=0), "output mismatch!" 
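+    # The looser 5e-2 tolerance above for half precision absorbs the rounding
+    # from storing the sum in bf16/fp16 (pass 1) and re-reading it in pass 2.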
+ print(" PASSED") + + +if __name__ == "__main__": + test_fused_add_gemma_rmsnorm(M=1, N=2048) + test_fused_add_gemma_rmsnorm(M=128, N=2048) + test_fused_add_gemma_rmsnorm(M=1, N=2048, dtype=torch.float16) + test_fused_add_gemma_rmsnorm(M=64, N=4096, dtype=torch.float32) + print("All tests passed!") diff --git a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py new file mode 100644 index 0000000000..c816a20013 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py @@ -0,0 +1,87 @@ +# Adapted from https://github.com/sgl-project/sglang/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from lightllm.common.triton_utils.autotuner import autotune + + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +# beta_output = b.sigmoid() +@triton.jit +def fused_gdn_gating_kernel( + g, + beta_output, + A_log, + a, + b, + dt_bias, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_d = tl.program_id(0), tl.program_id(1) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_b = tl.load(b + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) + tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask) + + +def _get_fused_gdn_gating_configs(): + return [{"BLK_HEADS": bh, "num_warps": nw} for bh in [4, 8, 16, 32, 64] for nw in [1, 2, 4]] + + +def _get_fused_gdn_gating_static_key(a: torch.Tensor): + # group by head size and input dtype + return {"NUM_HEADS": a.shape[1], "a_dtype": str(a.dtype)} + + +@autotune( + kernel_name="fused_gdn_gating:v1", + configs_gen_func=_get_fused_gdn_gating_configs, + static_key_func=_get_fused_gdn_gating_static_key, + run_key_func=lambda a: a.shape[0], +) +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, + run_config: Optional[dict] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + if run_config is None: + run_config = {"BLK_HEADS": 8, "num_warps": 1} + + batch, num_heads = a.shape + grid = (batch, triton.cdiv(num_heads, run_config["BLK_HEADS"])) + g = torch.empty(batch, num_heads, dtype=torch.float32, device=a.device) + beta_output = torch.empty(batch, num_heads, dtype=torch.float32, device=a.device) + fused_gdn_gating_kernel[grid]( + g, + beta_output, + A_log, + a, + b, + dt_bias, + num_heads, + beta, + threshold, + run_config["BLK_HEADS"], + num_warps=run_config["num_warps"], + ) + return g, beta_output diff --git a/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py new file mode 100644 index 0000000000..f37d4911af --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py @@ -0,0 +1,163 @@ +""" +Fused QKV projection and GDN gating computation. + +This kernel fuses: +1. 
Linear projection (matmul with weight)
+2. Output reorganization (split and reshape)
+3. Gating computation (g and beta from a, b)
+
+Of these, this module currently implements the gating computation (step 3) as a
+single fused kernel with pre-allocated outputs; the projection and split/reshape
+steps are performed by the caller, so the gating step costs one kernel launch.
+"""
+
+import torch
+import triton
+import triton.language as tl
+from typing import Tuple, Optional
+from lightllm.common.triton_utils.autotuner import autotune
+
+
+@triton.jit
+def _fused_gdn_gating_only_kernel(
+    # Output pointers
+    g_ptr,
+    beta_ptr,
+    # Input pointers
+    a_ptr,
+    b_ptr,
+    A_log_ptr,
+    dt_bias_ptr,
+    # Dimensions
+    batch_size,
+    num_heads,
+    # Constants
+    beta_const: tl.constexpr,
+    threshold: tl.constexpr,
+    BLOCK_BATCH: tl.constexpr,
+    BLOCK_HEADS: tl.constexpr,
+):
+    """
+    Fused kernel for GDN gating computation with better memory access patterns.
+
+    Computes:
+    - g = -exp(A_log) * softplus(a + dt_bias)
+    - beta = sigmoid(b)
+    """
+    pid_batch = tl.program_id(0)
+    pid_head = tl.program_id(1)
+
+    batch_offs = pid_batch * BLOCK_BATCH + tl.arange(0, BLOCK_BATCH)
+    head_offs = pid_head * BLOCK_HEADS + tl.arange(0, BLOCK_HEADS)
+
+    batch_mask = batch_offs < batch_size
+    head_mask = head_offs < num_heads
+    mask = batch_mask[:, None] & head_mask[None, :]
+
+    # Load A_log and dt_bias (broadcast across batch)
+    A_log = tl.load(A_log_ptr + head_offs, mask=head_mask, other=0.0)
+    dt_bias = tl.load(dt_bias_ptr + head_offs, mask=head_mask, other=0.0)
+
+    # Load a and b
+    offs = batch_offs[:, None] * num_heads + head_offs[None, :]
+    a = tl.load(a_ptr + offs, mask=mask, other=0.0)
+    b = tl.load(b_ptr + offs, mask=mask, other=0.0)
+
+    # Compute g = -exp(A_log) * softplus(a + dt_bias)
+    x = a.to(tl.float32) + dt_bias.to(tl.float32)
+    softplus_x = tl.where(beta_const * x <= threshold, (1.0 / beta_const) * tl.log(1.0 + tl.exp(beta_const * x)), x)
+    g = -tl.exp(A_log.to(tl.float32)) * softplus_x
+
+    # Compute beta = sigmoid(b)
+    beta_out = tl.sigmoid(b.to(tl.float32))
+
+    # Store outputs with layout [1, batch, num_heads]
+    out_offs = batch_offs[:, None] * num_heads + head_offs[None, :]
+    tl.store(g_ptr + out_offs, g.to(g_ptr.dtype.element_ty), mask=mask)
+    tl.store(beta_ptr + out_offs, beta_out.to(beta_ptr.dtype.element_ty), mask=mask)
+
+
+def _get_fused_gating_configs():
+    """Generate autotuning configurations."""
+    configs = []
+    for block_batch in [1, 4, 8, 16]:
+        for block_heads in [8, 16, 32]:
+            for num_warps in [2, 4, 8]:
+                configs.append(
+                    {
+                        "BLOCK_BATCH": block_batch,
+                        "BLOCK_HEADS": block_heads,
+                        "num_warps": num_warps,
+                    }
+                )
+    return configs
+
+
+def _get_fused_gating_static_key(a: torch.Tensor):
+    return {"dtype": str(a.dtype), "num_heads": a.shape[1]}
+
+
+def _get_fused_gating_run_key(a: torch.Tensor):
+    return a.shape[0]
+
+
+@autotune(
+    kernel_name="fused_gdn_gating_v2:v1",
+    configs_gen_func=_get_fused_gating_configs,
+    static_key_func=_get_fused_gating_static_key,
+    run_key_func=_get_fused_gating_run_key,
+    mutates_args=["g", "beta"],
+)
+def fused_gdn_gating_v2(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    A_log: torch.Tensor,
+    dt_bias: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    beta_const: float = 1.0,
+    threshold: float = 20.0,
+    run_config: Optional[dict] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Optimized GDN gating with pre-allocated output tensors.
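+
+    Example (illustrative; `batch` and `num_heads` stand in for the shapes
+    documented in Args below, CUDA assumed):
+
+        g = torch.empty(1, batch, num_heads, dtype=torch.float32, device="cuda")
+        beta = torch.empty_like(g)
+        fused_gdn_gating_v2(a, b, A_log, dt_bias, g, beta)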
+ + Args: + a: Input tensor [batch, num_heads] + b: Input tensor [batch, num_heads] + A_log: Log of A parameter [num_heads] + dt_bias: Bias for dt [num_heads] + g: Output tensor [1, batch, num_heads] (pre-allocated) + beta: Output tensor [1, batch, num_heads] (pre-allocated) + beta_const: Beta constant for softplus (default: 1.0) + threshold: Threshold for softplus approximation (default: 20.0) + run_config: Optional autotuning configuration + + Returns: + Tuple of (g, beta) - same tensors passed in, now filled + """ + batch_size, num_heads = a.shape + + if run_config is None: + run_config = {"BLOCK_BATCH": 8, "BLOCK_HEADS": 16, "num_warps": 4} + + grid = ( + triton.cdiv(batch_size, run_config["BLOCK_BATCH"]), + triton.cdiv(num_heads, run_config["BLOCK_HEADS"]), + ) + + _fused_gdn_gating_only_kernel[grid]( + g, + beta, + a, + b, + A_log, + dt_bias, + batch_size, + num_heads, + beta_const, + threshold, + run_config["BLOCK_BATCH"], + run_config["BLOCK_HEADS"], + num_warps=run_config["num_warps"], + ) + + return g, beta diff --git a/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py b/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py new file mode 100644 index 0000000000..5f4433fb34 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py @@ -0,0 +1,400 @@ +""" +Fused Split-Copy Triton Kernels for GDN Decode Path + +Replaces multiple separate .copy_() calls with single kernel launches to reduce +kernel launch overhead in the decode hot path (36 GDN layers per step). + +Kernel 1 (fused_split_copy_qkvzba): 4 copies → 1 kernel + Splits GEMM output [batch, total_dim] into qkv, z, b, a destination buffers. + +Kernel 2 (fused_split_copy_qkv): 3 copies → 1 kernel + Splits conv1d output [batch, qkv_dim] into q, k, v destination buffers. + Handles non-contiguous source (stride(0) != total_dim from column slicing). +""" + +import torch +import triton +import triton.language as tl + + +# ============================================================================= +# Kernel 1: Fused split-copy for qkv, z, b, a from GEMM output +# ============================================================================= + + +@triton.jit +def _fused_split_copy_qkvzba_kernel( + # Source pointer (contiguous GEMM output) + src_ptr, + # Destination pointers (pre-allocated contiguous buffers) + dst_qkv_ptr, + dst_z_ptr, + dst_b_ptr, + dst_a_ptr, + # Row strides + src_stride0, + dst_qkv_stride0, + dst_z_stride0, + dst_b_stride0, + dst_a_stride0, + # Segment boundaries (cumulative): [0, qkv_dim) [qkv_dim, z_end) [z_end, b_end) [b_end, total_dim) + qkv_dim, + z_end, + b_end, + total_dim, + # Block size + BLOCK_N: tl.constexpr, +): + """ + One program per (row, column_block). Loads a BLOCK_N chunk from the source row, + then conditionally stores to the correct destination based on column position. 
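+
+    For example, with qkv_dim=512, z_end=768, b_end=770, total_dim=772 (the
+    sizes exercised in the test below), column 769 falls in [z_end, b_end)
+    and is therefore written to dst_b at column 769 - z_end = 1.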
+ + Grid: (batch, cdiv(total_dim, BLOCK_N)) + """ + row = tl.program_id(0) + col_block = tl.program_id(1) + + col_start = col_block * BLOCK_N + cols = col_start + tl.arange(0, BLOCK_N) + mask = cols < total_dim + + # Load source chunk + data = tl.load(src_ptr + row * src_stride0 + cols, mask=mask) + + # Store to qkv destination: columns [0, qkv_dim) + qkv_mask = mask & (cols < qkv_dim) + tl.store(dst_qkv_ptr + row * dst_qkv_stride0 + cols, data, mask=qkv_mask) + + # Store to z destination: columns [qkv_dim, z_end) + z_mask = mask & (cols >= qkv_dim) & (cols < z_end) + tl.store(dst_z_ptr + row * dst_z_stride0 + (cols - qkv_dim), data, mask=z_mask) + + # Store to b destination: columns [z_end, b_end) + b_mask = mask & (cols >= z_end) & (cols < b_end) + tl.store(dst_b_ptr + row * dst_b_stride0 + (cols - z_end), data, mask=b_mask) + + # Store to a destination: columns [b_end, total_dim) + a_mask = mask & (cols >= b_end) + tl.store(dst_a_ptr + row * dst_a_stride0 + (cols - b_end), data, mask=a_mask) + + +def fused_split_copy_qkvzba( + src: torch.Tensor, + dst_qkv: torch.Tensor, + dst_z: torch.Tensor, + dst_b: torch.Tensor, + dst_a: torch.Tensor, + qkv_dim: int, + z_dim: int, + b_dim: int, + a_dim: int, +): + """ + Fused split-copy from GEMM output into 4 contiguous destination buffers. + + Replaces: + conv_buf.copy_(mixed_qkvzba[:, :qkv_dim]) + z_buf.view(batch, -1).copy_(mixed_qkvzba[:, qkv_dim:z_end]) + b_buf.copy_(mixed_qkvzba[:, z_end:b_end]) + a_buf.copy_(mixed_qkvzba[:, b_end:]) + + Args: + src: [batch, total_dim] contiguous source (GEMM output) + dst_qkv: [batch, qkv_dim] contiguous destination for conv1d input + dst_z: [batch, z_dim] contiguous destination (z_buf viewed flat) + dst_b: [batch, b_dim] contiguous destination + dst_a: [batch, a_dim] contiguous destination + qkv_dim: width of qkv segment (tp_key_dim * 2 + tp_value_dim) + z_dim: width of z segment (tp_value_dim) + b_dim: width of b segment (tp_num_v_heads) + a_dim: width of a segment (tp_num_v_heads) + """ + total_dim = qkv_dim + z_dim + b_dim + a_dim + z_end = qkv_dim + z_dim + b_end = z_end + b_dim + + batch = src.shape[0] + BLOCK_N = 128 + num_col_blocks = triton.cdiv(total_dim, BLOCK_N) + + grid = (batch, num_col_blocks) + + _fused_split_copy_qkvzba_kernel[grid]( + src, + dst_qkv, + dst_z, + dst_b, + dst_a, + src.stride(0), + dst_qkv.stride(0), + dst_z.stride(0), + dst_b.stride(0), + dst_a.stride(0), + qkv_dim, + z_end, + b_end, + total_dim, + BLOCK_N=BLOCK_N, + num_warps=4, + ) + + +# ============================================================================= +# Kernel 2: Fused split-copy for q, k, v from conv1d output +# ============================================================================= + + +@triton.jit +def _fused_split_copy_qkv_kernel( + # Source pointer (may be non-contiguous column slice) + src_ptr, + # Destination pointers (contiguous buffers) + dst_q_ptr, + dst_k_ptr, + dst_v_ptr, + # Row strides + src_stride0, + dst_q_stride0, + dst_k_stride0, + dst_v_stride0, + # Segment boundaries: [0, q_dim) [q_dim, qk_end) [qk_end, total_dim) + q_dim, + qk_end, + total_dim, + # Block size + BLOCK_N: tl.constexpr, +): + """ + One program per (row, column_block). Loads a BLOCK_N chunk from the source row, + then conditionally stores to q, k, or v destination. + + Supports non-contiguous source via src_stride0 (stride may be > total_dim + when source is a column slice of a larger tensor). 
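+
+    For example, if src is wider[:, :512] for a wider tensor of shape
+    [batch, 576], then total_dim = 512 while src_stride0 = 576; rows are
+    still addressed correctly because only src_stride0 advances between rows.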
+ + Grid: (batch, cdiv(total_dim, BLOCK_N)) + """ + row = tl.program_id(0) + col_block = tl.program_id(1) + + col_start = col_block * BLOCK_N + cols = col_start + tl.arange(0, BLOCK_N) + mask = cols < total_dim + + # Load source chunk (use src_stride0 for row advancement) + data = tl.load(src_ptr + row * src_stride0 + cols, mask=mask) + + # Store to q destination: columns [0, q_dim) + q_mask = mask & (cols < q_dim) + tl.store(dst_q_ptr + row * dst_q_stride0 + cols, data, mask=q_mask) + + # Store to k destination: columns [q_dim, qk_end) + k_mask = mask & (cols >= q_dim) & (cols < qk_end) + tl.store(dst_k_ptr + row * dst_k_stride0 + (cols - q_dim), data, mask=k_mask) + + # Store to v destination: columns [qk_end, total_dim) + v_mask = mask & (cols >= qk_end) + tl.store(dst_v_ptr + row * dst_v_stride0 + (cols - qk_end), data, mask=v_mask) + + +def fused_split_copy_qkv( + src: torch.Tensor, + dst_q: torch.Tensor, + dst_k: torch.Tensor, + dst_v: torch.Tensor, + q_dim: int, + k_dim: int, + v_dim: int, + src_stride0: int, +): + """ + Fused split-copy from conv1d output into 3 contiguous q/k/v buffers. + + Replaces: + q_split, k_split, v_split = torch.split(mixed_qkv, [...], dim=-1) + q_buf.view(batch, -1).copy_(q_split) + k_buf.view(batch, -1).copy_(k_split) + v_buf.view(batch, -1).copy_(v_split) + + Args: + src: [batch, total_dim] source tensor (may be non-contiguous if column slice) + dst_q: [batch, q_dim] contiguous destination + dst_k: [batch, k_dim] contiguous destination + dst_v: [batch, v_dim] contiguous destination + q_dim: width of q segment (tp_key_dim) + k_dim: width of k segment (tp_key_dim) + v_dim: width of v segment (tp_value_dim) + src_stride0: row stride of source (may be > q_dim+k_dim+v_dim) + """ + total_dim = q_dim + k_dim + v_dim + qk_end = q_dim + k_dim + + batch = src.shape[0] + BLOCK_N = 128 + num_col_blocks = triton.cdiv(total_dim, BLOCK_N) + + grid = (batch, num_col_blocks) + + _fused_split_copy_qkv_kernel[grid]( + src, + dst_q, + dst_k, + dst_v, + src_stride0, + dst_q.stride(0), + dst_k.stride(0), + dst_v.stride(0), + q_dim, + qk_end, + total_dim, + BLOCK_N=BLOCK_N, + num_warps=4, + ) + + +# ============================================================================= +# Test / Verification +# ============================================================================= + + +def test_fused_split_copy(): + """Verify fused kernels produce identical results to separate .copy_() calls.""" + torch.manual_seed(42) + device = "cuda" + dtype = torch.bfloat16 + + print("=" * 60) + print("Testing fused_split_copy_qkvzba") + print("=" * 60) + + # Typical dimensions for Qwen3-Coder-Next with TP=4 + # tp_key_dim=128, tp_value_dim=256, tp_num_v_heads=2 + qkv_dim = 128 + 128 + 256 # q + k + v = 512 + z_dim = 256 + b_dim = 2 + a_dim = 2 + total_dim = qkv_dim + z_dim + b_dim + a_dim # 772 + + for batch in [1, 4, 8, 32]: + src = torch.randn(batch, total_dim, dtype=dtype, device=device) + + # Reference: separate copies + ref_qkv = src[:, :qkv_dim].clone() + ref_z = src[:, qkv_dim : qkv_dim + z_dim].clone() + ref_b = src[:, qkv_dim + z_dim : qkv_dim + z_dim + b_dim].clone() + ref_a = src[:, qkv_dim + z_dim + b_dim :].clone() + + # Fused kernel + dst_qkv = torch.empty(batch, qkv_dim, dtype=dtype, device=device) + dst_z = torch.empty(batch, z_dim, dtype=dtype, device=device) + dst_b = torch.empty(batch, b_dim, dtype=dtype, device=device) + dst_a = torch.empty(batch, a_dim, dtype=dtype, device=device) + fused_split_copy_qkvzba(src, dst_qkv, dst_z, dst_b, dst_a, qkv_dim, z_dim, b_dim, a_dim) 
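+
+        # The split-copy only moves bytes, so the checks below use torch.equal
+        # (bitwise equality) rather than allclose.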
+ + assert torch.equal(dst_qkv, ref_qkv), f"qkv mismatch at batch={batch}" + assert torch.equal(dst_z, ref_z), f"z mismatch at batch={batch}" + assert torch.equal(dst_b, ref_b), f"b mismatch at batch={batch}" + assert torch.equal(dst_a, ref_a), f"a mismatch at batch={batch}" + print(f" batch={batch:3d}: PASS") + + print() + print("=" * 60) + print("Testing fused_split_copy_qkv") + print("=" * 60) + + q_dim = 128 + k_dim = 128 + v_dim = 256 + qkv_dim = q_dim + k_dim + v_dim # 512 + + for batch in [1, 4, 8, 32]: + # Test with contiguous source + src = torch.randn(batch, qkv_dim, dtype=dtype, device=device) + + ref_q = src[:, :q_dim].clone() + ref_k = src[:, q_dim : q_dim + k_dim].clone() + ref_v = src[:, q_dim + k_dim :].clone() + + dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) + dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) + dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) + fused_split_copy_qkv(src, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src.stride(0)) + + assert torch.equal(dst_q, ref_q), f"q mismatch at batch={batch} (contiguous)" + assert torch.equal(dst_k, ref_k), f"k mismatch at batch={batch} (contiguous)" + assert torch.equal(dst_v, ref_v), f"v mismatch at batch={batch} (contiguous)" + print(f" batch={batch:3d} (contiguous src): PASS") + + # Test with non-contiguous source (column slice of wider tensor) + wider = torch.randn(batch, qkv_dim + 64, dtype=dtype, device=device) + src_nc = wider[:, :qkv_dim] # Non-contiguous: stride(0) = qkv_dim + 64 + assert src_nc.stride(0) == qkv_dim + 64, "expected non-contiguous slice" + + ref_q = src_nc[:, :q_dim].clone() + ref_k = src_nc[:, q_dim : q_dim + k_dim].clone() + ref_v = src_nc[:, q_dim + k_dim :].clone() + + dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) + dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) + dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) + fused_split_copy_qkv(src_nc, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src_nc.stride(0)) + + assert torch.equal(dst_q, ref_q), f"q mismatch at batch={batch} (non-contiguous)" + assert torch.equal(dst_k, ref_k), f"k mismatch at batch={batch} (non-contiguous)" + assert torch.equal(dst_v, ref_v), f"v mismatch at batch={batch} (non-contiguous)" + print(f" batch={batch:3d} (non-contiguous src): PASS") + + print() + print("=" * 60) + print("Testing edge cases") + print("=" * 60) + + # Edge case: different dimension ratios (small q/k, large v) + q_dim, k_dim, v_dim = 32, 32, 512 + qkv_dim = q_dim + k_dim + v_dim + batch = 2 + src = torch.randn(batch, qkv_dim, dtype=dtype, device=device) + + dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) + dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) + dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) + fused_split_copy_qkv(src, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src.stride(0)) + + assert torch.equal(dst_q, src[:, :q_dim]) + assert torch.equal(dst_k, src[:, q_dim : q_dim + k_dim]) + assert torch.equal(dst_v, src[:, q_dim + k_dim :]) + print(" asymmetric dims (32, 32, 512): PASS") + + # Edge case: float32 dtype + src_f32 = torch.randn(4, 772, dtype=torch.float32, device=device) + dst_qkv = torch.empty(4, 512, dtype=torch.float32, device=device) + dst_z = torch.empty(4, 256, dtype=torch.float32, device=device) + dst_b = torch.empty(4, 2, dtype=torch.float32, device=device) + dst_a = torch.empty(4, 2, dtype=torch.float32, device=device) + fused_split_copy_qkvzba(src_f32, dst_qkv, dst_z, dst_b, dst_a, 512, 256, 2, 2) + + 
assert torch.equal(dst_qkv, src_f32[:, :512]) + assert torch.equal(dst_z, src_f32[:, 512:768]) + assert torch.equal(dst_b, src_f32[:, 768:770]) + assert torch.equal(dst_a, src_f32[:, 770:]) + print(" float32 dtype: PASS") + + # Edge case: float16 dtype + src_f16 = torch.randn(4, 772, dtype=torch.float16, device=device) + dst_qkv = torch.empty(4, 512, dtype=torch.float16, device=device) + dst_z = torch.empty(4, 256, dtype=torch.float16, device=device) + dst_b = torch.empty(4, 2, dtype=torch.float16, device=device) + dst_a = torch.empty(4, 2, dtype=torch.float16, device=device) + fused_split_copy_qkvzba(src_f16, dst_qkv, dst_z, dst_b, dst_a, 512, 256, 2, 2) + + assert torch.equal(dst_qkv, src_f16[:, :512]) + assert torch.equal(dst_z, src_f16[:, 512:768]) + assert torch.equal(dst_b, src_f16[:, 768:770]) + assert torch.equal(dst_a, src_f16[:, 770:]) + print(" float16 dtype: PASS") + + print() + print("All tests passed!") + + +if __name__ == "__main__": + test_fused_split_copy() diff --git a/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py new file mode 100644 index 0000000000..89db5e00cb --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py @@ -0,0 +1,174 @@ +import triton +import triton.language as tl +import torch +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.heuristics( + { + "HAS_BIAS": lambda args: args["B"] is not None, + } +) +@triton.jit +def gated_rmsnorm_forward_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch (required, not optional) + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
+ row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + Z += row * stride_z_row + group * N + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute variance (RMS norm doesn't use mean) + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + # RMS norm: compute variance directly without mean subtraction + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + # RMS norm: normalize without mean subtraction + x_hat = x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _get_gated_rmsnorm_configs(): + """Generate configurations for autotuning gated RMSNorm kernel.""" + configs = [] + # Different BLOCK_N sizes (powers of 2) + for block_n in [64, 128, 256, 512, 1024, 2048, 4096]: + # Different number of warps + for num_warps in [1, 2, 4, 8]: + # Skip configurations that are likely to be inefficient + if block_n >= 2048 and num_warps > 4: + continue + if block_n <= 128 and num_warps > 2: + continue + configs.append({"BLOCK_N": block_n, "num_warps": num_warps}) + return configs + + +def _get_gated_rmsnorm_static_key(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor): + """Generate static key for caching autotuned configurations.""" + M, N = x.shape + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(weight.dtype), + "N": N, + "has_bias": bias is not None, + } + + +@autotune( + kernel_name="gated_rmsnorm_forward:v1", + configs_gen_func=_get_gated_rmsnorm_configs, + static_key_func=_get_gated_rmsnorm_static_key, + run_key_func=lambda x: x.shape[0], +) +def gated_rmsnorm_forward( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + z: torch.Tensor, + out: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + run_config: dict = None, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + # z is required for gated_rmsnorm + assert z is not None, "z cannot be None for gated_rmsnorm_forward" + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + # For RMS norm, we still need rstd for the kernel + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + 
run_config = {"BLOCK_N": BLOCK_N, "num_warps": num_warps} + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + + # Validate BLOCK_N against group_size + if group_size > BLOCK_N: + # Fall back to largest valid BLOCK_N + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + grid = (M, ngroups) + gated_rmsnorm_forward_kernel[grid]( + x, + out, + weight, + bias, + z, + rstd, + x.stride(0), + out.stride(0), + z.stride(0), + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + num_warps=num_warps, + ) + return out diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py new file mode 100644 index 0000000000..5a39debaa9 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py @@ -0,0 +1,1333 @@ +""" +Optimized GDN Decode MTP (Multi-Token Prediction) Kernel + +This module provides an optimized Triton kernel for GDN decode with MTP support, +eliminating the need for sequential Python loops and reducing memory operations. + +Key optimizations: +1. Fused data reorganization from interleaved to batched layout +2. Parallel processing of all batch items with proper state indexing +3. Auto-tuned configurations for different batch sizes and model dimensions +""" + +import torch +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _reorganize_mtp_data_kernel( + # Input pointers (interleaved layout: [step0_batch0, step0_batch1, ..., step1_batch0, ...]) + src_ptr, + # Output pointers (batched layout: [batch0_step0, batch0_step1, ..., batch1_step0, ...]) + dst_ptr, + # Dimensions + batch_size, + mtp_size, + dim_size, + # Strides + src_stride_token, + src_stride_dim, + dst_stride_token, + dst_stride_dim, + # Block sizes + BLOCK_DIM: tl.constexpr, +): + """ + Reorganize data from interleaved MTP layout to batched layout. + + Input layout: [step0_batch0, step0_batch1, ..., step0_batchN, step1_batch0, ...] + Output layout: [batch0_step0, batch0_step1, ..., batch0_stepM, batch1_step0, ...] + + This enables efficient processing with the recurrent kernel. 
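+
+    For example, with batch_size=2 and mtp_size=2, source token order
+    [s0b0, s0b1, s1b0, s1b1] becomes [b0s0, b0s1, b1s0, b1s1].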
+ """ + batch_idx = tl.program_id(0) + step_idx = tl.program_id(1) + block_dim_idx = tl.program_id(2) + + # Calculate source and destination token indices + src_token_idx = step_idx * batch_size + batch_idx + dst_token_idx = batch_idx * mtp_size + step_idx + + # Calculate dimension offsets + dim_start = block_dim_idx * BLOCK_DIM + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + mask = dim_offsets < dim_size + + # Load from source (interleaved layout) + src_offset = src_token_idx * src_stride_token + dim_offsets * src_stride_dim + data = tl.load(src_ptr + src_offset, mask=mask, other=0.0) + + # Store to destination (batched layout) + dst_offset = dst_token_idx * dst_stride_token + dim_offsets * dst_stride_dim + tl.store(dst_ptr + dst_offset, data, mask=mask) + + +@triton.jit +def _reorganize_mtp_data_back_kernel( + # Input pointers (batched layout): [batch_size, mtp_size, num_heads, head_dim] + src_ptr, + # Output pointers (interleaved layout): [total_tokens, 1, num_heads, head_dim] + dst_ptr, + # Dimensions + batch_size, + mtp_size, + num_heads, + head_dim, + # Strides for src: [batch_size, mtp_size, num_heads, head_dim] + src_stride_batch, + src_stride_mtp, + src_stride_head, + src_stride_dim, + # Strides for dst: [total_tokens, 1, num_heads, head_dim] + dst_stride_token, + dst_stride_seq, + dst_stride_head, + dst_stride_dim, + # Block sizes + BLOCK_HEAD: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """ + Reorganize output data from batched layout back to interleaved layout. + + Input shape: [batch_size, mtp_size, num_heads, head_dim] + Output shape: [batch_size * mtp_size, 1, num_heads, head_dim] (interleaved) + + Mapping: src[b, s, h, d] -> dst[s * batch_size + b, 0, h, d] + """ + batch_idx = tl.program_id(0) + step_idx = tl.program_id(1) + block_idx = tl.program_id(2) + + # Decompose block_idx into head and dim blocks + num_dim_blocks = tl.cdiv(head_dim, BLOCK_DIM) + block_head_idx = block_idx // num_dim_blocks + block_dim_idx = block_idx % num_dim_blocks + + # Calculate destination token index (interleaved) + dst_token_idx = step_idx * batch_size + batch_idx + + # Calculate offsets + head_start = block_head_idx * BLOCK_HEAD + dim_start = block_dim_idx * BLOCK_DIM + + head_offsets = head_start + tl.arange(0, BLOCK_HEAD) + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + mask = head_mask[:, None] & dim_mask[None, :] + + # Load from source (batched layout): [batch_size, mtp_size, num_heads, head_dim] + src_base = src_ptr + batch_idx * src_stride_batch + step_idx * src_stride_mtp + src_offset = head_offsets[:, None] * src_stride_head + dim_offsets[None, :] * src_stride_dim + data = tl.load(src_base + src_offset, mask=mask, other=0.0) + + # Store to destination (interleaved layout): [total_tokens, 1, num_heads, head_dim] + # The seq dimension (1) is skipped since it's always 0 + dst_base = dst_ptr + dst_token_idx * dst_stride_token + dst_offset = head_offsets[:, None] * dst_stride_head + dim_offsets[None, :] * dst_stride_dim + tl.store(dst_base + dst_offset, data, mask=mask) + + +def _get_reorganize_mtp_configs(): + """Generate candidate configurations for MTP data reorganization.""" + configs = [] + for block_dim in [64, 128, 256, 512]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3, 4]: + configs.append( + { + "BLOCK_DIM": block_dim, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_reorganize_static_key(src: torch.Tensor, mtp_size: int): + """Static 
key based on tensor properties.""" + return { + "dtype": str(src.dtype), + "mtp_size": mtp_size, + } + + +def _get_reorganize_run_key(src: torch.Tensor, mtp_size: int): + """Run key based on batch size and dimension.""" + total_tokens = src.shape[0] + batch_size = total_tokens // mtp_size + dim_size = src.shape[-1] + return f"{batch_size}_{dim_size}" + + +@autotune( + kernel_name="gdn_decode_mtp_reorganize:v1", + configs_gen_func=_get_reorganize_mtp_configs, + static_key_func=_get_reorganize_static_key, + run_key_func=_get_reorganize_run_key, + mutates_args=["dst"], +) +def reorganize_mtp_to_batched( + src: torch.Tensor, + dst: torch.Tensor, + mtp_size: int, + run_config: dict = None, +): + """ + Reorganize data from interleaved MTP layout to batched layout. + + Args: + src: Input tensor with interleaved layout [total_tokens, dim] + Layout: [step0_batch0, step0_batch1, ..., step1_batch0, ...] + dst: Output tensor with batched layout [total_tokens, dim] + Layout: [batch0_step0, batch0_step1, ..., batch1_step0, ...] + mtp_size: Number of MTP steps + run_config: Auto-tuned configuration + """ + total_tokens = src.shape[0] + batch_size = total_tokens // mtp_size + dim_size = src.shape[-1] + + if run_config is None: + BLOCK_DIM = triton.next_power_of_2(min(dim_size, 256)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_DIM = run_config["BLOCK_DIM"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_dim = triton.cdiv(dim_size, BLOCK_DIM) + + grid = (batch_size, mtp_size, num_blocks_dim) + + _reorganize_mtp_data_kernel[grid]( + src, + dst, + batch_size, + mtp_size, + dim_size, + src.stride(0), + src.stride(-1) if src.ndim > 1 else 1, + dst.stride(0), + dst.stride(-1) if dst.ndim > 1 else 1, + BLOCK_DIM=BLOCK_DIM, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_reorganize_back_configs(): + """Generate candidate configurations for MTP output reorganization.""" + configs = [] + for block_head in [4, 8, 16, 32]: + for block_dim in [32, 64, 128]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3]: + if block_head * block_dim <= 4096: # Limit shared memory + configs.append( + { + "BLOCK_HEAD": block_head, + "BLOCK_DIM": block_dim, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_reorganize_back_static_key( + src: torch.Tensor, + batch_size: int, + mtp_size: int, + num_heads: int, + head_dim: int, +): + """Static key for output reorganization.""" + return { + "dtype": str(src.dtype), + "mtp_size": mtp_size, + "num_heads": num_heads, + "head_dim": head_dim, + } + + +def _get_reorganize_back_run_key( + src: torch.Tensor, + batch_size: int, + mtp_size: int, + num_heads: int, + head_dim: int, +): + """Run key for output reorganization.""" + return batch_size + + +@autotune( + kernel_name="gdn_decode_mtp_reorganize_back:v1", + configs_gen_func=_get_reorganize_back_configs, + static_key_func=_get_reorganize_back_static_key, + run_key_func=_get_reorganize_back_run_key, + mutates_args=["dst"], +) +def reorganize_mtp_output_to_interleaved( + src: torch.Tensor, + dst: torch.Tensor, + batch_size: int, + mtp_size: int, + num_heads: int, + head_dim: int, + run_config: dict = None, +): + """ + Reorganize output from batched layout back to interleaved layout. 
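+
+    A rough PyTorch equivalent of this kernel (editor-added sketch; `src`
+    shaped as documented in Args below):
+
+        dst = src.transpose(0, 1).reshape(batch_size * mtp_size, 1, num_heads, head_dim)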
+ + Args: + src: Input tensor [batch_size, mtp_size, num_heads, head_dim] (4D) + dst: Output tensor [batch_size * mtp_size, 1, num_heads, head_dim] (4D) + batch_size: Number of batch items + mtp_size: Number of MTP steps + num_heads: Number of attention heads + head_dim: Head dimension + run_config: Auto-tuned configuration + + Mapping: src[b, s, h, d] -> dst[s * batch_size + b, 0, h, d] + """ + if run_config is None: + BLOCK_HEAD = min(triton.next_power_of_2(num_heads), 16) + BLOCK_DIM = min(triton.next_power_of_2(head_dim), 64) + num_warps = 4 + num_stages = 2 + else: + BLOCK_HEAD = run_config["BLOCK_HEAD"] + BLOCK_DIM = run_config["BLOCK_DIM"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_head_blocks = triton.cdiv(num_heads, BLOCK_HEAD) + num_dim_blocks = triton.cdiv(head_dim, BLOCK_DIM) + num_blocks_total = num_head_blocks * num_dim_blocks + + grid = (batch_size, mtp_size, num_blocks_total) + + # src is 4D: [batch_size, mtp_size, num_heads, head_dim] + # dst is 4D: [total_tokens, 1, num_heads, head_dim] + _reorganize_mtp_data_back_kernel[grid]( + src, + dst, + batch_size, + mtp_size, + num_heads, + head_dim, + src.stride(0), # batch stride + src.stride(1), # mtp stride + src.stride(2), # head stride + src.stride(3), # dim stride + dst.stride(0), # token stride + dst.stride(1), # seq stride (=1) + dst.stride(2), # head stride + dst.stride(3), # dim stride + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_DIM=BLOCK_DIM, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@triton.jit +def _prepare_mtp_indices_kernel( + # Input indices (per-step buffer indices) + buffer_idx_ptr, + # Output 2D indices for recurrent kernel + output_idx_ptr, + # Dimensions + batch_size, + mtp_size, + # Strides + input_stride, + output_stride_batch, + output_stride_step, +): + """ + Prepare 2D indices for the fused recurrent kernel. + + Input: mtp_size tensors of shape [batch_size] (buffer indices for each step) + Output: 2D tensor [batch_size, mtp_size] for ssm_state_indices + """ + batch_idx = tl.program_id(0) + step_idx = tl.program_id(1) + + # Load the buffer index for this batch and step + buffer_idx = tl.load(buffer_idx_ptr + step_idx * input_stride + batch_idx) + + # Store to the 2D output + output_offset = batch_idx * output_stride_batch + step_idx * output_stride_step + tl.store(output_idx_ptr + output_offset, buffer_idx) + + +def prepare_mtp_state_indices( + mtp_buffer_idx_list: list, + batch_size: int, + device: torch.device, +) -> torch.Tensor: + """ + Prepare 2D state indices for the fused recurrent kernel. 
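+
+    Hypothetical usage (editor-added; values invented for illustration):
+
+        step0 = torch.tensor([3, 7])   # buffer ids for MTP step 0
+        step1 = torch.tensor([4, 8])   # buffer ids for MTP step 1
+        prepare_mtp_state_indices([step0, step1], batch_size=2, device=step0.device)
+        # -> tensor([[3, 4], [7, 8]]) of shape [batch_size, mtp_size]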
+
+    Args:
+        mtp_buffer_idx_list: List of buffer index tensors, one per MTP step
+        batch_size: Number of batch items
+        device: Target device
+
+    Returns:
+        2D tensor of shape [batch_size, mtp_size] for ssm_state_indices
+    """
+
+    # Stack indices to create [mtp_size, batch_size] tensor
+    stacked_indices = torch.stack(mtp_buffer_idx_list, dim=0)
+
+    # Transpose to get [batch_size, mtp_size]
+    return stacked_indices.T.contiguous()
+
+
+@triton.jit
+def _fused_conv1d_mtp_step_kernel(
+    # Input/output data
+    mixed_qkv_ptr,
+    # Conv state buffer
+    conv_states_ptr,
+    # Conv weight and bias
+    conv_weight_ptr,
+    conv_bias_ptr,
+    # Buffer indices (one per MTP step, each [batch_size])
+    buffer_indices_ptr,
+    next_buffer_indices_ptr,
+    # Dimensions
+    batch_size,
+    dim_size,
+    conv_width,
+    # Step info
+    step_idx,
+    mtp_size,
+    is_last_step: tl.constexpr,
+    # Strides
+    qkv_stride_token,
+    qkv_stride_dim,
+    state_stride_buffer,
+    state_stride_dim,
+    state_stride_width,
+    weight_stride_dim,
+    weight_stride_width,
+    # Block sizes
+    BLOCK_DIM: tl.constexpr,
+    ACTIVATION_SILU: tl.constexpr,
+):
+    """
+    Fused kernel for conv1d update in MTP decode.
+
+    Handles one MTP step for all batch items:
+    1. Reads the current conv state
+    2. Updates it with the new input
+    3. Computes the conv1d output
+    4. Optionally copies the state to the next MTP step
+    """
+    batch_idx = tl.program_id(0)
+    block_dim_idx = tl.program_id(1)
+
+    # Calculate token index in interleaved layout
+    token_idx = step_idx * batch_size + batch_idx
+
+    # Load buffer indices
+    cur_buffer_idx = tl.load(buffer_indices_ptr + batch_idx).to(tl.int64)
+
+    # Calculate dimension offsets
+    dim_start = block_dim_idx * BLOCK_DIM
+    dim_offsets = dim_start + tl.arange(0, BLOCK_DIM)
+    dim_mask = dim_offsets < dim_size
+
+    # Load input value
+    input_offset = token_idx * qkv_stride_token + dim_offsets * qkv_stride_dim
+    input_val = tl.load(mixed_qkv_ptr + input_offset, mask=dim_mask, other=0.0)
+
+    # Load conv bias
+    bias_val = tl.load(conv_bias_ptr + dim_offsets, mask=dim_mask, other=0.0)
+
+    # Compute conv1d output and update state
+    output_val = bias_val
+    state_base = conv_states_ptr + cur_buffer_idx * state_stride_buffer
+
+    # Process each position in the conv window
+    for w in range(conv_width):
+        # Load weight for this position
+        weight_offset = dim_offsets * weight_stride_dim + w * weight_stride_width
+        weight_val = tl.load(conv_weight_ptr + weight_offset, mask=dim_mask, other=0.0)
+
+        if w < conv_width - 1:
+            # Load from state buffer
+            state_offset = dim_offsets * state_stride_dim + w * state_stride_width
+            state_val = tl.load(state_base + state_offset, mask=dim_mask, other=0.0)
+            output_val += state_val * weight_val
+        else:
+            # Use current input for the last position
+            output_val += input_val * weight_val
+
+    # Update conv state: shift the window left, then insert the new input at
+    # the end. The shift must run in ascending order so each slot is read
+    # before the slot it came from is overwritten; a descending loop would
+    # insert the new input first and then smear it across the whole window.
+    for w in range(conv_width - 1):
+        if w < conv_width - 2:
+            # Shift state
+            src_offset = dim_offsets * state_stride_dim + (w + 1) * state_stride_width
+            dst_offset = dim_offsets * state_stride_dim + w * state_stride_width
+            val = tl.load(state_base + src_offset, mask=dim_mask, other=0.0)
+            tl.store(state_base + dst_offset, val, mask=dim_mask)
+        else:
+            # Insert new input at the end
+            state_offset = dim_offsets * state_stride_dim + w * state_stride_width
+            tl.store(state_base + state_offset, input_val, mask=dim_mask)
+
+    # Apply activation (SiLU)
+    if ACTIVATION_SILU:
+        output_val = output_val * tl.sigmoid(output_val)
+
+    # Store output
+    tl.store(mixed_qkv_ptr + input_offset, output_val, mask=dim_mask)
+
+    #
Copy state to next step if not last + if not is_last_step: + next_buffer_idx = tl.load(next_buffer_indices_ptr + batch_idx).to(tl.int64) + next_state_base = conv_states_ptr + next_buffer_idx * state_stride_buffer + + for w in range(conv_width - 1): + state_offset = dim_offsets * state_stride_dim + w * state_stride_width + val = tl.load(state_base + state_offset, mask=dim_mask, other=0.0) + tl.store(next_state_base + state_offset, val, mask=dim_mask) + + +def _get_conv1d_mtp_configs(): + """Generate candidate configurations for conv1d MTP kernel.""" + configs = [] + for block_dim in [64, 128, 256, 512]: + for num_warps in [2, 4, 8]: + for num_stages in [1, 2, 3]: + configs.append( + { + "BLOCK_DIM": block_dim, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_conv1d_mtp_static_key( + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + conv_weight: torch.Tensor, + mtp_size: int, +): + """Static key for conv1d MTP kernel.""" + return { + "dtype": str(mixed_qkv.dtype), + "dim_size": mixed_qkv.shape[-1], + "conv_width": conv_weight.shape[-1], + "mtp_size": mtp_size, + } + + +def _get_conv1d_mtp_run_key( + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + conv_weight: torch.Tensor, + mtp_size: int, +): + """Run key for conv1d MTP kernel.""" + total_tokens = mixed_qkv.shape[0] + batch_size = total_tokens // mtp_size + return batch_size + + +@autotune( + kernel_name="gdn_conv1d_mtp:v1", + configs_gen_func=_get_conv1d_mtp_configs, + static_key_func=_get_conv1d_mtp_static_key, + run_key_func=_get_conv1d_mtp_run_key, + mutates_args=["mixed_qkv", "conv_states"], +) +def fused_conv1d_mtp_update( + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + mtp_buffer_idx_list: list, + mtp_size: int, + activation_silu: bool = True, + run_config: dict = None, +): + """ + Fused conv1d update for all MTP steps. 
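+
+    For one step and one batch row, the kernel is roughly equivalent to the
+    following PyTorch sketch (editor-added; `x` is that row of `mixed_qkv`,
+    and the cross-step state copy is omitted):
+
+        state = conv_states[buf_idx]                          # [dim, conv_width-1]
+        window = torch.cat([state, x.unsqueeze(-1)], dim=-1)  # [dim, conv_width]
+        y = (window * conv_weight).sum(-1) + conv_bias
+        y = torch.nn.functional.silu(y)                       # if activation_silu
+        conv_states[buf_idx] = window[:, 1:]                  # shift in the new input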
+ + Args: + mixed_qkv: Input tensor [batch_size * mtp_size, dim] (interleaved) + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] + conv_weight: Conv weights [dim, conv_width] + conv_bias: Conv bias [dim] + mtp_buffer_idx_list: List of buffer index tensors per step + mtp_size: Number of MTP steps + activation_silu: Whether to apply SiLU activation + run_config: Auto-tuned configuration + """ + total_tokens = mixed_qkv.shape[0] + batch_size = total_tokens // mtp_size + dim_size = mixed_qkv.shape[-1] + conv_width = conv_weight.shape[-1] + + if run_config is None: + BLOCK_DIM = triton.next_power_of_2(min(dim_size, 256)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_DIM = run_config["BLOCK_DIM"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_dim = triton.cdiv(dim_size, BLOCK_DIM) + + for step_idx in range(mtp_size): + is_last_step = step_idx == mtp_size - 1 + cur_indices = mtp_buffer_idx_list[step_idx] + next_indices = mtp_buffer_idx_list[step_idx + 1] if not is_last_step else cur_indices + + grid = (batch_size, num_blocks_dim) + + _fused_conv1d_mtp_step_kernel[grid]( + mixed_qkv, + conv_states, + conv_weight, + conv_bias, + cur_indices, + next_indices, + batch_size, + dim_size, + conv_width, + step_idx, + mtp_size, + is_last_step, + mixed_qkv.stride(0), + mixed_qkv.stride(-1) if mixed_qkv.ndim > 1 else 1, + conv_states.stride(0), + conv_states.stride(1), + conv_states.stride(2), + conv_weight.stride(0), + conv_weight.stride(1), + BLOCK_DIM=BLOCK_DIM, + ACTIVATION_SILU=activation_silu, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@triton.jit +def _copy_ssm_state_kernel( + # SSM state buffer + ssm_states_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + num_heads, + key_dim, + value_dim, + # Strides + state_stride_buffer, + state_stride_head, + state_stride_key, + state_stride_value, + # Block sizes + BLOCK_KEY: tl.constexpr, + BLOCK_VALUE: tl.constexpr, +): + """ + Copy SSM states from source indices to destination indices. 
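+
+    Semantically (editor's note) this is a row-level gather/scatter over the
+    buffer axis, roughly:
+
+        ssm_states[dst_indices] = ssm_states[src_indices]
+
+    assuming the source and destination index sets do not overlap.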
+ """ + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + block_idx = tl.program_id(2) + + # Calculate block positions + num_value_blocks = tl.cdiv(value_dim, BLOCK_VALUE) + block_key_idx = block_idx // num_value_blocks + block_value_idx = block_idx % num_value_blocks + + key_start = block_key_idx * BLOCK_KEY + value_start = block_value_idx * BLOCK_VALUE + + key_offsets = key_start + tl.arange(0, BLOCK_KEY) + value_offsets = value_start + tl.arange(0, BLOCK_VALUE) + + key_mask = key_offsets < key_dim + value_mask = value_offsets < value_dim + mask = key_mask[:, None] & value_mask[None, :] + + # Load indices + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # Calculate offsets + src_base = ssm_states_ptr + src_idx * state_stride_buffer + head_idx * state_stride_head + dst_base = ssm_states_ptr + dst_idx * state_stride_buffer + head_idx * state_stride_head + + offsets = key_offsets[:, None] * state_stride_key + value_offsets[None, :] * state_stride_value + + # Copy data + data = tl.load(src_base + offsets, mask=mask, other=0.0) + tl.store(dst_base + offsets, data, mask=mask) + + +@triton.jit +def _copy_conv_state_kernel( + # Conv state buffer [num_buffers, dim, conv_width-1] + conv_states_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + dim_size, + width_size, + num_width_blocks, # Precomputed to avoid runtime division + # Strides + state_stride_buffer, + state_stride_dim, + state_stride_width, + # Block sizes + BLOCK_DIM: tl.constexpr, + BLOCK_WIDTH: tl.constexpr, +): + """ + Copy conv states from source indices to destination indices. + + Conv state shape: [num_buffers, dim, conv_width-1] + """ + batch_idx = tl.program_id(0) + block_idx = tl.program_id(1) + + # Calculate block positions using precomputed num_width_blocks + block_dim_idx = block_idx // num_width_blocks + block_width_idx = block_idx % num_width_blocks + + dim_start = block_dim_idx * BLOCK_DIM + width_start = block_width_idx * BLOCK_WIDTH + + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + width_offsets = width_start + tl.arange(0, BLOCK_WIDTH) + + dim_mask = dim_offsets < dim_size + width_mask = width_offsets < width_size + mask = dim_mask[:, None] & width_mask[None, :] + + # Load indices + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # Calculate offsets + src_base = conv_states_ptr + src_idx * state_stride_buffer + dst_base = conv_states_ptr + dst_idx * state_stride_buffer + + offsets = dim_offsets[:, None] * state_stride_dim + width_offsets[None, :] * state_stride_width + + # Copy data + data = tl.load(src_base + offsets, mask=mask, other=0.0) + tl.store(dst_base + offsets, data, mask=mask) + + +def _get_conv_copy_configs(): + """Generate candidate configurations for conv state copy.""" + configs = [] + for block_dim in [64, 128, 256]: + for block_width in [2, 4, 8]: + for num_warps in [2, 4]: + configs.append( + { + "BLOCK_DIM": block_dim, + "BLOCK_WIDTH": block_width, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_conv_copy_static_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for conv copy.""" + return { + "dtype": str(conv_states.dtype), + "dim_size": conv_states.shape[1], + "width_size": conv_states.shape[2], + } + + +def _get_conv_copy_run_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for conv 
copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_conv_state_copy:v1", + configs_gen_func=_get_conv_copy_configs, + static_key_func=_get_conv_copy_static_key, + run_key_func=_get_conv_copy_run_key, + mutates_args=["conv_states"], +) +def copy_conv_states( + conv_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Copy conv states from source indices to destination indices. + + Args: + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + batch_size = src_indices.shape[0] + dim_size = conv_states.shape[1] + width_size = conv_states.shape[2] + + if run_config is None: + BLOCK_DIM = triton.next_power_of_2(min(dim_size, 128)) + BLOCK_WIDTH = triton.next_power_of_2(min(width_size, 4)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_DIM = run_config["BLOCK_DIM"] + BLOCK_WIDTH = run_config["BLOCK_WIDTH"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_dim_blocks = triton.cdiv(dim_size, BLOCK_DIM) + num_width_blocks = triton.cdiv(width_size, BLOCK_WIDTH) + num_blocks_total = num_dim_blocks * num_width_blocks + + grid = (batch_size, num_blocks_total) + + _copy_conv_state_kernel[grid]( + conv_states, + src_indices, + dst_indices, + batch_size, + dim_size, + width_size, + num_width_blocks, # Pass precomputed value + conv_states.stride(0), + conv_states.stride(1), + conv_states.stride(2), + BLOCK_DIM=BLOCK_DIM, + BLOCK_WIDTH=BLOCK_WIDTH, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_ssm_copy_configs(): + """Generate candidate configurations for SSM state copy.""" + configs = [] + for block_key in [16, 32, 64]: + for block_value in [16, 32, 64, 128]: + for num_warps in [2, 4, 8]: + if block_key * block_value <= 4096: + configs.append( + { + "BLOCK_KEY": block_key, + "BLOCK_VALUE": block_value, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_ssm_copy_static_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for SSM copy.""" + return { + "dtype": str(ssm_states.dtype), + "num_heads": ssm_states.shape[1], + "key_dim": ssm_states.shape[2], + "value_dim": ssm_states.shape[3], + } + + +def _get_ssm_copy_run_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for SSM copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_ssm_state_copy:v1", + configs_gen_func=_get_ssm_copy_configs, + static_key_func=_get_ssm_copy_static_key, + run_key_func=_get_ssm_copy_run_key, + mutates_args=["ssm_states"], +) +def copy_ssm_states( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Copy SSM states from source indices to destination indices. 
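+
+    Hypothetical usage (editor-added; buffer count and shapes invented):
+
+        states = torch.zeros(64, 16, 128, 128, device="cuda")
+        src = torch.tensor([0, 1], device="cuda", dtype=torch.int32)
+        dst = torch.tensor([10, 11], device="cuda", dtype=torch.int32)
+        copy_ssm_states(states, src, dst)  # states[10] <- states[0], states[11] <- states[1]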
+ + Args: + ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + batch_size = src_indices.shape[0] + num_heads = ssm_states.shape[1] + key_dim = ssm_states.shape[2] + value_dim = ssm_states.shape[3] + + if run_config is None: + BLOCK_KEY = triton.next_power_of_2(min(key_dim, 32)) + BLOCK_VALUE = triton.next_power_of_2(min(value_dim, 64)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_KEY = run_config["BLOCK_KEY"] + BLOCK_VALUE = run_config["BLOCK_VALUE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_key_blocks = triton.cdiv(key_dim, BLOCK_KEY) + num_value_blocks = triton.cdiv(value_dim, BLOCK_VALUE) + num_blocks_total = num_key_blocks * num_value_blocks + + grid = (batch_size, num_heads, num_blocks_total) + + _copy_ssm_state_kernel[grid]( + ssm_states, + src_indices, + dst_indices, + batch_size, + num_heads, + key_dim, + value_dim, + ssm_states.stride(0), + ssm_states.stride(1), + ssm_states.stride(2), + ssm_states.stride(3), + BLOCK_KEY=BLOCK_KEY, + BLOCK_VALUE=BLOCK_VALUE, + num_warps=num_warps, + num_stages=num_stages, + ) + + +# ============================================================================= +# Optimized Flat Copy Kernels (for contiguous memory) +# ============================================================================= +# These kernels leverage the fact that both conv_states and ssm_states are +# contiguous in memory, allowing us to flatten the inner dimensions and use +# efficient 1D vectorized copy patterns. + + +@triton.jit +def _copy_state_flat_kernel( + # State buffer pointer (flattened view) + state_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + flat_size, # Total elements per buffer entry (flattened inner dims) + # Strides + stride_buffer, # Stride to next buffer entry (in elements) + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Optimized flat copy kernel for contiguous state buffers. + + Instead of using 2D/3D block patterns with stride calculations, this kernel + treats each buffer entry as a flat 1D array and uses vectorized loads/stores + for efficient memory transfer. 
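+
+    Editor's sketch of the equivalent operation on a flattened view:
+
+        flat = states.view(states.shape[0], -1)  # [num_buffers, flat_size]
+        flat[dst_idx] = flat[src_idx]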
+ + Grid: (batch_size, num_blocks) where num_blocks = ceil(flat_size / BLOCK_SIZE) + """ + batch_idx = tl.program_id(0) + block_idx = tl.program_id(1) + + # Calculate element range for this block + elem_start = block_idx * BLOCK_SIZE + elem_offsets = elem_start + tl.arange(0, BLOCK_SIZE) + elem_mask = elem_offsets < flat_size + + # Load buffer indices for this batch item + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # Calculate source and destination base pointers + src_base = state_ptr + src_idx * stride_buffer + dst_base = state_ptr + dst_idx * stride_buffer + + # Vectorized copy + data = tl.load(src_base + elem_offsets, mask=elem_mask, other=0.0) + tl.store(dst_base + elem_offsets, data, mask=elem_mask) + + +@triton.jit +def _copy_states_fused_kernel( + # Conv state buffer (flattened view) + conv_state_ptr, + # SSM state buffer (flattened view) + ssm_state_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + conv_flat_size, # Total elements per conv buffer entry + ssm_flat_size, # Total elements per ssm buffer entry + # Strides (in elements) + conv_stride_buffer, + ssm_stride_buffer, + # Block sizes + CONV_BLOCK_SIZE: tl.constexpr, + SSM_BLOCK_SIZE: tl.constexpr, +): + """ + Fused kernel to copy both conv_states and ssm_states in a single launch. + + This reduces kernel launch overhead by processing both state copies together. + Each thread block handles one batch item and copies both states sequentially. + + Grid: (batch_size, max(conv_blocks, ssm_blocks)) + """ + batch_idx = tl.program_id(0) + block_idx = tl.program_id(1) + + # Load buffer indices (same for both conv and ssm) + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # ========== Copy Conv State ========== + conv_num_blocks = tl.cdiv(conv_flat_size, CONV_BLOCK_SIZE) + if block_idx < conv_num_blocks: + conv_elem_start = block_idx * CONV_BLOCK_SIZE + conv_elem_offsets = conv_elem_start + tl.arange(0, CONV_BLOCK_SIZE) + conv_mask = conv_elem_offsets < conv_flat_size + + conv_src_base = conv_state_ptr + src_idx * conv_stride_buffer + conv_dst_base = conv_state_ptr + dst_idx * conv_stride_buffer + + conv_data = tl.load(conv_src_base + conv_elem_offsets, mask=conv_mask, other=0.0) + tl.store(conv_dst_base + conv_elem_offsets, conv_data, mask=conv_mask) + + # ========== Copy SSM State ========== + ssm_num_blocks = tl.cdiv(ssm_flat_size, SSM_BLOCK_SIZE) + if block_idx < ssm_num_blocks: + ssm_elem_start = block_idx * SSM_BLOCK_SIZE + ssm_elem_offsets = ssm_elem_start + tl.arange(0, SSM_BLOCK_SIZE) + ssm_mask = ssm_elem_offsets < ssm_flat_size + + ssm_src_base = ssm_state_ptr + src_idx * ssm_stride_buffer + ssm_dst_base = ssm_state_ptr + dst_idx * ssm_stride_buffer + + ssm_data = tl.load(ssm_src_base + ssm_elem_offsets, mask=ssm_mask, other=0.0) + tl.store(ssm_dst_base + ssm_elem_offsets, ssm_data, mask=ssm_mask) + + +def _get_flat_copy_configs(): + """Generate candidate configurations for flat copy kernel.""" + configs = [] + # Larger block sizes for better memory throughput on contiguous data + for block_size in [256, 512, 1024, 2048]: + for num_warps in [4, 8]: + configs.append( + { + "BLOCK_SIZE": block_size, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_conv_flat_copy_static_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for conv flat copy.""" + return { + 
"dtype": str(conv_states.dtype), + "flat_size": conv_states.shape[1] * conv_states.shape[2], + } + + +def _get_conv_flat_copy_run_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for conv flat copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_conv_state_flat_copy:v1", + configs_gen_func=_get_flat_copy_configs, + static_key_func=_get_conv_flat_copy_static_key, + run_key_func=_get_conv_flat_copy_run_key, + mutates_args=["conv_states"], +) +def copy_conv_states_flat( + conv_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Optimized flat copy for conv states leveraging contiguous memory. + + Args: + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] (MUST be contiguous) + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + assert conv_states.is_contiguous(), "conv_states must be contiguous for flat copy" + + batch_size = src_indices.shape[0] + # Flatten inner dimensions + flat_size = conv_states.shape[1] * conv_states.shape[2] + stride_buffer = conv_states.stride(0) + + if run_config is None: + BLOCK_SIZE = 1024 + num_warps = 4 + num_stages = 2 + else: + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks = triton.cdiv(flat_size, BLOCK_SIZE) + grid = (batch_size, num_blocks) + + _copy_state_flat_kernel[grid]( + conv_states, + src_indices, + dst_indices, + batch_size, + flat_size, + stride_buffer, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_ssm_flat_copy_static_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for ssm flat copy.""" + return { + "dtype": str(ssm_states.dtype), + "flat_size": ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3], + } + + +def _get_ssm_flat_copy_run_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for ssm flat copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_ssm_state_flat_copy:v1", + configs_gen_func=_get_flat_copy_configs, + static_key_func=_get_ssm_flat_copy_static_key, + run_key_func=_get_ssm_flat_copy_run_key, + mutates_args=["ssm_states"], +) +def copy_ssm_states_flat( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Optimized flat copy for SSM states leveraging contiguous memory. 
+ + Args: + ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] (MUST be contiguous) + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + assert ssm_states.is_contiguous(), "ssm_states must be contiguous for flat copy" + + batch_size = src_indices.shape[0] + # Flatten inner dimensions (num_heads * key_dim * value_dim) + flat_size = ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3] + stride_buffer = ssm_states.stride(0) + + if run_config is None: + BLOCK_SIZE = 1024 + num_warps = 4 + num_stages = 2 + else: + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks = triton.cdiv(flat_size, BLOCK_SIZE) + grid = (batch_size, num_blocks) + + _copy_state_flat_kernel[grid]( + ssm_states, + src_indices, + dst_indices, + batch_size, + flat_size, + stride_buffer, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_fused_copy_configs(): + """Generate candidate configurations for fused copy kernel.""" + configs = [] + # Use power-of-2 block sizes for both conv and ssm + for conv_block in [256, 512, 1024]: + for ssm_block in [256, 512, 1024]: + for num_warps in [4, 8]: + configs.append( + { + "CONV_BLOCK_SIZE": conv_block, + "SSM_BLOCK_SIZE": ssm_block, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_fused_copy_static_key( + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for fused copy.""" + return { + "conv_dtype": str(conv_states.dtype), + "ssm_dtype": str(ssm_states.dtype), + "conv_flat_size": conv_states.shape[1] * conv_states.shape[2], + "ssm_flat_size": ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3], + } + + +def _get_fused_copy_run_key( + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for fused copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_states_fused_copy:v1", + configs_gen_func=_get_fused_copy_configs, + static_key_func=_get_fused_copy_static_key, + run_key_func=_get_fused_copy_run_key, + mutates_args=["conv_states", "ssm_states"], +) +def copy_states_fused( + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Fused copy for both conv and SSM states in a single kernel launch. + + This reduces kernel launch overhead by processing both state copies together. 
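+
+    Editor's note: the unfused equivalent is two separate launches, roughly:
+
+        copy_conv_states_flat(conv_states, src_indices, dst_indices)
+        copy_ssm_states_flat(ssm_states, src_indices, dst_indices)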
+ + Args: + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] (MUST be contiguous) + ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] (MUST be contiguous) + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + assert conv_states.is_contiguous(), "conv_states must be contiguous for fused copy" + assert ssm_states.is_contiguous(), "ssm_states must be contiguous for fused copy" + + batch_size = src_indices.shape[0] + + # Flatten inner dimensions + conv_flat_size = conv_states.shape[1] * conv_states.shape[2] + ssm_flat_size = ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3] + + conv_stride_buffer = conv_states.stride(0) + ssm_stride_buffer = ssm_states.stride(0) + + if run_config is None: + CONV_BLOCK_SIZE = 512 + SSM_BLOCK_SIZE = 512 + num_warps = 4 + num_stages = 2 + else: + CONV_BLOCK_SIZE = run_config["CONV_BLOCK_SIZE"] + SSM_BLOCK_SIZE = run_config["SSM_BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + # Grid covers both conv and ssm blocks + conv_num_blocks = triton.cdiv(conv_flat_size, CONV_BLOCK_SIZE) + ssm_num_blocks = triton.cdiv(ssm_flat_size, SSM_BLOCK_SIZE) + max_blocks = max(conv_num_blocks, ssm_num_blocks) + grid = (batch_size, max_blocks) + + _copy_states_fused_kernel[grid]( + conv_states, + ssm_states, + src_indices, + dst_indices, + batch_size, + conv_flat_size, + ssm_flat_size, + conv_stride_buffer, + ssm_stride_buffer, + CONV_BLOCK_SIZE=CONV_BLOCK_SIZE, + SSM_BLOCK_SIZE=SSM_BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py new file mode 100644 index 0000000000..0a2b4bd662 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py @@ -0,0 +1,141 @@ +import torch + +import triton +import triton.language as tl + +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _gemma_rmsnorm_fwd_kernel( + x_ptr, + w_ptr, + y_ptr, + x_stride0, + x_stride1, + y_stride0, + y_stride1, + N: tl.constexpr, + EPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + x_ptr = x_ptr + row * x_stride0 + y_ptr = y_ptr + row * y_stride0 + + _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + cols * x_stride1, mask=cols < N, other=0.0).to(tl.float32) + _sum += x * x + + var = tl.sum(_sum, axis=0) / N + rstd = 1 / tl.sqrt(var + EPS) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(w_ptr + cols, mask=mask).to(tl.float32) + x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + x_hat = x * rstd + w = w + 1.0 + y = x_hat * w + # Write output + tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask) + + +def _get_gemma_rmsnorm_configs(): + """Generate configurations for autotuning gemma RMSNorm kernel.""" + configs = [] + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]: + for num_warps in [1, 2, 4, 8]: + # num_stages has minimal impact on this simple kernel, use 1 + configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": 1}) + return configs + + +def _get_gemma_rmsnorm_static_key(x: torch.Tensor, w: 
torch.Tensor): + """Generate static key for caching autotuned configurations.""" + N = x.shape[-1] + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(w.dtype), + "N": N, + } + + +@autotune( + kernel_name="gemma_rmsnorm_forward:v1", + configs_gen_func=_get_gemma_rmsnorm_configs, + static_key_func=_get_gemma_rmsnorm_static_key, + run_key_func=lambda x: x.shape[-1], +) +def gemma_rmsnorm_forward(x, w, eps, out=None, run_config: dict = None): + # Inplace gemma RMS Norm + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + N = x.shape[-1] + y = torch.empty_like(x) if out is None else out + x_arg = x.view(-1, N) + y_arg = y.view(-1, N) + + M, _ = x_arg.shape + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This gemma rmsnorm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1} + + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + _gemma_rmsnorm_fwd_kernel[(M,)]( + x_arg, + w, + y_arg, + x_stride0=x.stride(0), + x_stride1=x.stride(1), + y_stride0=y.stride(0), + y_stride1=y.stride(1), + N=N, + EPS=eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + return y + + +def _gemma_rmsnorm_fwd_torch(x, weight, eps): + original_dtype = x.dtype + x = x.to(torch.float32) + x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + x = x * (1.0 + weight.float()) + return x.to(original_dtype) + + +def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + # forward pass + y_tri = gemma_rmsnorm_forward(x, weight, eps) + y_ref = _gemma_rmsnorm_fwd_torch(x, weight, eps) + + # compare + print("type:", y_tri.dtype, y_ref.dtype) + print("max delta:", torch.max(torch.abs(y_tri - y_ref))) + # Use appropriate tolerance based on dtype + atol = 1e-2 if dtype == torch.float32 else 5e-2 + assert torch.allclose(y_tri, y_ref, atol=atol, rtol=0) + return diff --git a/lightllm/models/qwen3next_mtp/__init__.py b/lightllm/models/qwen3next_mtp/__init__.py new file mode 100644 index 0000000000..779237817d --- /dev/null +++ b/lightllm/models/qwen3next_mtp/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel + +__all__ = ["Qwen3NextMTPModel"] diff --git a/lightllm/models/qwen3next_mtp/layer_infer/__init__.py b/lightllm/models/qwen3next_mtp/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py new file mode 100644 index 0000000000..2918fca79c --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py @@ -0,0 +1,16 @@ +import torch +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight +from 
lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + +class Qwen3NextMTPPostLayerInfer(LlamaPostLayerInfer): + """ + Qwen3Next MTP Post Layer Inference. + Uses gemma_rmsnorm for normalization (same as Qwen3Next). + """ + + def _norm(self, input, infer_state, layer_weight: Qwen3NextMTPPreAndPostLayerWeight) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.final_norm_weight_.weight, self.eps_, out=out) + return out diff --git a/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 0000000000..4fc207648c --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,68 @@ +import torch + +from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + +class Qwen3NextMTPPreLayerInfer(LlamaPreLayerInfer): + """ + Qwen3Next MTP Pre-Layer Inference. + Similar to DeepSeek MTP but with different weight structure. + + MTP forward flow: + 1. Get embedding from input_ids + 2. Get hidden state from main model (passed via infer_state) + 3. Normalize embedding with pre_fc_norm_embedding + 4. Normalize hidden with pre_fc_norm_hidden + 5. Concat normalized embedding and hidden + 6. Project through fc to get hidden_dim output + """ + + def __init__(self, network_config): + super().__init__(network_config) + self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + return + + def _mtp_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + + # Normalize embedding + input_embdings_normed = self.alloc_tensor(input_embdings.shape, input_embdings.dtype) + gemma_rmsnorm_forward( + input_embdings, layer_weight.pre_fc_norm_embedding_weight_.weight, self.eps_, out=input_embdings_normed + ) + + # Normalize hidden state + tgt_embdings_normed = self.alloc_tensor(tgt_embdings.shape, tgt_embdings.dtype) + gemma_rmsnorm_forward( + tgt_embdings, layer_weight.pre_fc_norm_hidden_weight_.weight, self.eps_, out=tgt_embdings_normed + ) + + # Concat normalized embedding and hidden + cat_embdings = torch.cat((input_embdings_normed, tgt_embdings_normed), dim=-1) + + # Project to hidden_size + ans_logics = self.alloc_tensor( + (cat_embdings.shape[0], layer_weight.fc_weight_.shape[1]), dtype=cat_embdings.dtype + ) + torch.mm(cat_embdings, layer_weight.fc_weight_, out=ans_logics) + + return ans_logics + + def context_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_forward(input_embdings, infer_state, layer_weight) + + def token_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight + ): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_forward(input_embdings, infer_state, layer_weight) diff --git 
a/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..03630c17c1 --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py @@ -0,0 +1,30 @@ +from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import Qwen3NextFullAttentionBaseLayerInfer +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3NextMTPTransformerLayerInfer(Qwen3NextFullAttentionBaseLayerInfer): + """ + Qwen3Next MTP Transformer Layer Inference. + MTP layers use full attention (not linear attention) with MoE FFN and shared expert. + Inherits shared methods from Qwen3NextFullAttentionBaseLayerInfer. + """ + + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + self.tp_k_head_num_ = max(self.tp_k_head_num_, 1) + self.tp_v_head_num_ = max(self.tp_v_head_num_, 1) + return + + def _bind_ffn(self): + """MTP always uses shared expert + MoE""" + from functools import partial + import os + + moe_mode = os.environ.get("MOE_MODE", "TP") + if moe_mode == "EP": + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_ep, self) + else: + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_tp, self) + return diff --git a/lightllm/models/qwen3next_mtp/layer_weights/__init__.py b/lightllm/models/qwen3next_mtp/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..8a74ef8567 --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,47 @@ +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpGEMMANormWeight + + +class Qwen3NextMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + self.wte_weight_ = None + self.lm_head_weight_ = None + + hidden_size = network_config["hidden_size"] + # Use Gemma-style normalization for all MTP norm layers + self.final_norm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.norm.weight", + data_type=self.data_type_, + ) + self.pre_fc_norm_embedding_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_embedding.weight", + data_type=self.data_type_, + ) + self.pre_fc_norm_hidden_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_hidden.weight", + data_type=self.data_type_, + ) + return + + def load_hf_weights(self, weights): + if "mtp.fc.weight" in weights: + self.fc_weight_ = self._cuda(weights["mtp.fc.weight"]).t() + + # Load weights for norm weight objects + self.final_norm_weight_.load_hf_weights(weights) + self.pre_fc_norm_embedding_weight_.load_hf_weights(weights) + self.pre_fc_norm_hidden_weight_.load_hf_weights(weights) + + return + + def verify_load(self): + # Verify all norm weights loaded correctly + return ( + self.final_norm_weight_.verify_load() + and self.pre_fc_norm_embedding_weight_.verify_load() + and self.pre_fc_norm_hidden_weight_.verify_load() + ) diff --git a/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py 
b/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..d52da5647d --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,141 @@ +import os +import torch +import math +import numpy as np +from lightllm.common.basemodel import TransformerLayerWeight +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.utils.envs_utils import enable_env_vars +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + RMSNormWeight, + QKRMSNORMWeight, + KVROWNMMWeight, +) +from functools import partial + + +class Qwen3NextMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _init_weight_names(self): + self._q_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.q_proj.weight" + self._q_norm_name = f"mtp.layers.{self.layer_num_}.self_attn.q_norm.weight" + self._q_bias_name = None + self._k_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.k_proj.weight" + self._k_norm_name = f"mtp.layers.{self.layer_num_}.self_attn.k_norm.weight" + self._k_bias_name = None + self._v_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.v_proj.weight" + self._v_bias_name = None + self._kv_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.kv_proj.weight" + self._kv_bias_name = None + self._o_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.o_proj.weight" + self._o_bias_name = None + self._att_norm_weight_name = f"mtp.layers.{self.layer_num_}.input_layernorm.weight" + self._att_norm_bias_name = None + self._ffn_norm_weight_name = f"mtp.layers.{self.layer_num_}.post_attention_layernorm.weight" + self._ffn_norm_bias_name = None + + def _init_qkv(self): + # Override parent's QKVROWNMMWeight which requires kv_head_num % tp == 0. + # Qwen3-Next has few KV heads; KVROWNMMWeight handles repeating. 
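+        # Illustration (editor-added, hedged): with a config like
+        #
+        #     k_head_num, tp_world_size = 2, 8
+        #     assert k_head_num % tp_world_size == 0   # what the parent class requires
+        #
+        # the assert would fail, so the repeat-capable KVROWNMMWeight is used
+        # to replicate KV heads across TP ranks instead of splitting them.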
+ in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.q_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=self._q_weight_name, + data_type=self.data_type_, + bias_names=self._q_bias_name, + quant_method=self.get_quant_method("q_proj"), + ) + self.kv_proj = KVROWNMMWeight( + in_dim=in_dim, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("kv_proj"), + ) + + def _init_weight(self): + self._init_moe() + self._init_shared_expert_weight() + + hidden_size = self.network_config_["hidden_size"] + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + + self._init_qkv() + self._init_o() + self.q_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim, weight_name=self._q_norm_name, data_type=self.data_type_ + ) + self.k_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim, weight_name=self._k_norm_name, data_type=self.data_type_ + ) + self._o_gate_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" + q_out_dim = self.q_head_num_ * self.head_dim + self.o_gate_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[q_out_dim], + weight_names=self._o_gate_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) + return + + def load_hf_weights(self, weights): + self._split_q_with_gate(weights) + super().load_hf_weights(weights) + + def _init_shared_expert_weight(self): + prefix = f"mtp.layers.{self.layer_num_}.mlp.shared_expert" + hidden_size = self.network_config_["hidden_size"] + shared_inter = self.network_config_["shared_expert_intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[shared_inter, shared_inter], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=shared_inter, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + self.shared_expert_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"mtp.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + + def _split_q_with_gate(self, weights): + if self._q_weight_name in weights: + weight = weights[self._q_weight_name] + num_heads = self.q_head_num_ + weight = weight.view(num_heads * 2, self.head_dim, -1) + _q_proj = weight[0::2].reshape(-1, weight.shape[-1]) + _gate_proj = weight[1::2].reshape(-1, weight.shape[-1]) + weights[self._q_weight_name] = _q_proj + weights[self._o_gate_weight_name] = _gate_proj diff --git a/lightllm/models/qwen3next_mtp/model.py b/lightllm/models/qwen3next_mtp/model.py new file mode 100644 index 0000000000..92e4918bea --- /dev/null +++ b/lightllm/models/qwen3next_mtp/model.py @@ -0,0 +1,101 @@ +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel +from lightllm.models.qwen3next_mtp.layer_infer.pre_layer_infer import 
Qwen3NextMTPPreLayerInfer +from lightllm.models.qwen3next_mtp.layer_infer.transformer_layer_infer import Qwen3NextMTPTransformerLayerInfer +from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight +from lightllm.models.qwen3next_mtp.layer_weights.transformer_layer_weight import Qwen3NextMTPTransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel +from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights +from lightllm.models.registry import ModelRegistry + + +@ModelRegistry("qwen3next_mtp") +class Qwen3NextMTPModel(Qwen3NextTpPartModel): + + pre_and_post_weight_class = Qwen3NextMTPPreAndPostLayerWeight + pre_layer_infer_class = Qwen3NextMTPPreLayerInfer + transformer_weight_class = Qwen3NextMTPTransformerLayerWeight + transformer_layer_infer_class = Qwen3NextMTPTransformerLayerInfer + + def __init__(self, kvargs: dict): + self.mtp_n_layers = 1 + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + """Extract main model and memory layer start from kwargs.""" + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mem_layer_start = kvargs.pop("mem_layer_start") + return + + def autotune_layers(self): + return 1 + + def _init_some_value(self): + self.layers_num = self.mtp_n_layers + + def _init_config(self): + super()._init_config() + self.config["n_layers"] = self.mtp_n_layers + self.config["num_hidden_layers"] = self.mtp_n_layers + return + + def _init_custom(self): + """Initialize custom components, sharing cos/sin cache with main model.""" + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + """Share request manager with main model.""" + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + """Share memory manager with main model.""" + self.mem_manager = self.main_model.mem_manager + return + + def _check_mem_size(self): + """Skip mem size check for MTP models since they share memory with main model.""" + self.max_total_token_num = self.mem_manager.size + return + + def _init_weights(self): + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(self.mtp_n_layers) + ] + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + return + + def _init_infer_layer(self): + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + self.layers_infer = [ + self.transformer_layer_infer_class( + i * self.config["full_attention_interval"] - 1, # Ensure full attention layer + network_config=self.config, + ) + for i in range(self.mtp_n_layers) + ] + # Ensure full attention layer + for i, layer in enumerate(self.layers_infer): + layer.layer_num_ = i + self.mem_layer_start + return diff --git a/lightllm/server/api_cli.py 
index 96126744af..4d122f615d 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -128,7 +128,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--tool_call_parser",
         type=str,
-        choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "glm47", "kimi_k2"],
+        choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "glm47", "kimi_k2", "qwen3_coder"],
         default=None,
         help="tool call parser type",
     )
@@ -551,7 +551,15 @@ def make_argument_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument(
         "--mtp_mode",
-        choices=["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att", None],
+        choices=[
+            "vanilla_with_att",
+            "eagle_with_att",
+            "vanilla_no_att",
+            "eagle_no_att",
+            "qwen3next_vanilla",
+            "qwen3next_eagle",
+            None,
+        ],
         default=None,
         help="""Supported MTP modes.
         None: Disables MTP.
@@ -621,6 +629,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=False,
         help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""",
     )
+    parser.add_argument("--mamba_cache_size", type=int, default=3000, help="""The size of the linear attn cache.""")
+    parser.add_argument(
+        "--mamba_ssm_data_type",
+        type=str,
+        choices=["bfloat16", "float32"],
+        default="float32",
+        help="the data type of the mamba ssm state cache",
+    )
     parser.add_argument(
         "--hardware_platform",
         type=str,
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index d91bb1d947..fc14314ae3 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -176,10 +176,24 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
                     multimodal_params_dict["images"].append({"type": "base64", "data": data})
                 else:
                     raise ValueError("Unrecognized image input.")
+            elif img.startswith("file://"):
+                # Local file path with file:// prefix
+                file_path = img[7:]  # Remove "file://" prefix
+                with open(file_path, "rb") as f:
+                    multimodal_params_dict["images"].append(
+                        {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")}
+                    )
             else:
-                raise ValueError(
-                    "Unrecognized image input. Supports local path, http url, base64, and PIL.Image."
-                )
+                # Treat as local file path
+                if os.path.isfile(img):
+                    with open(img, "rb") as f:
+                        multimodal_params_dict["images"].append(
+                            {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")}
+                        )
+                else:
+                    raise ValueError(
+                        "Unrecognized image input. Supports local path, http url, base64, and PIL.Image."
+ ) tools = None if request.tools and request.tool_choice != "none": diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 111def60c2..34dd69c801 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -132,7 +132,8 @@ def normal_or_p_d_start(args): # mtp params check if args.mtp_mode is not None: - assert args.mtp_draft_model_dir is not None + if args.mtp_draft_model_dir is None: + args.mtp_draft_model_dir = [args.model_dir] * args.mtp_step assert args.mtp_step > 0 else: assert args.mtp_draft_model_dir is None diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index a369cf7f7f..d8d2c6ff8b 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -31,7 +31,8 @@ class StartArgs: batch_max_tokens: Optional[int] = field(default=None) eos_id: List[int] = field(default_factory=list) tool_call_parser: Optional[str] = field( - default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]} + default=None, + metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen", "qwen3_coder"]}, ) reasoning_parser: Optional[str] = field( default=None, @@ -54,7 +55,7 @@ class StartArgs: }, ) chat_template: Optional[str] = field(default=None) - running_max_req_size: int = field(default=1000) + running_max_req_size: int = field(default=512) tp: int = field(default=1) dp: int = field(default=1) nnodes: int = field(default=1) @@ -107,7 +108,7 @@ class StartArgs: disable_cudagraph: bool = field(default=False) enable_prefill_cudagraph: bool = field(default=False) prefll_cudagraph_max_handle_token: int = field(default=512) - graph_max_batch_size: int = field(default=256) + graph_max_batch_size: int = field(default=512) graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) graph_max_len_in_batch: int = field(default=0) @@ -134,7 +135,18 @@ class StartArgs: ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) mtp_mode: Optional[str] = field( - default=None, metadata={"choices": ["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att"]} + default=None, + metadata={ + "choices": [ + "vanilla_with_att", + "eagle_with_att", + "vanilla_no_att", + "eagle_no_att", + "qwen3next_vanilla", + "qwen3next_eagle", + None, + ] + }, ) mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) @@ -162,3 +174,7 @@ class StartArgs: # multi_modal enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) + + # hybrid attention model (Qwen3Next) + mamba_cache_size: int = field(default=800) + mamba_ssm_data_type: Optional[str] = field(default="float32", metadata={"choices": ["bfloat16", "float32"]}) diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py new file mode 100644 index 0000000000..2a4fe06628 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -0,0 +1,206 @@ +from typing import Set, Protocol, List, Optional, Tuple + +import torch +from sortedcontainers import SortedSet + +from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode +from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager +from lightllm.utils.log_utils import init_logger + +logger = 
init_logger(__name__)
+
+
+class HybridRadixCache(RadixCache):
+    def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager):
+        super().__init__(unique_name, total_token_num, rank_in_node, kv_cache_mem_manager)
+        assert hasattr(kv_cache_mem_manager, "mamba_cache_mem_manager")
+        self.buffer_mem_manager: MambaCacheManager = kv_cache_mem_manager.mamba_cache_mem_manager
+        self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: (x.buffer_time,))
+
+    def free_radix_cache_to_get_enough_buffer(self, need_buffer_num):
+        if need_buffer_num > self.buffer_mem_manager.can_use_mem_size:
+            need_evict_buffer_num = need_buffer_num - self.buffer_mem_manager.can_use_mem_size
+
+            release_mems = []
+
+            def release_mem(mem_index):
+                release_mems.append(mem_index)
+                return
+
+            release_buffers = []
+
+            def release_buffer(buffer_idx):
+                release_buffers.append(buffer_idx)
+                return
+
+            self._evict_buffer(need_evict_buffer_num, release_buffer, release_mem)
+            self.buffer_mem_manager.free(release_buffers)
+            if len(release_mems) > 0:
+                mem_index = torch.concat(release_mems)
+                self.mem_manager.free(mem_index)
+        return
+
+    def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token_callback):
+        while need_evict_buffer_num > 0:
+            node = self.evict_buffer_set.pop(0)
+            assert node.buffer_idx is not None
+            evict_buffer_callback(node.buffer_idx)
+            node.buffer_idx = None
+            need_evict_buffer_num -= 1
+            # Once a node's buffer_idx becomes None, it can no longer be matched later,
+            # but it must be kept while it still has children or a nonzero ref count; otherwise it should be removed
+            if node.is_leaf() and node.ref_counter == 0:
+                self.evict_tree_set.discard(node)
+                evict_token_callback(node.token_mem_index_value)
+                self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value)
+                parent_node: TreeNode = node.parent
+                parent_node.remove_child(node)
+                if parent_node.is_leaf():
+                    self.evict_tree_set.add(parent_node)
+        return
+
+    def insert_for_hybrid_radix_cache(self, reqs):
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+        reqs_to_insert = [req for req in reqs if req.cur_kv_len < req.get_cur_total_len()]
+
+        if len(reqs_to_insert) == 0:
+            return
+
+        self.free_radix_cache_to_get_enough_buffer(len(reqs_to_insert))
+        req_idxes = torch.tensor([req.req_idx for req in reqs_to_insert], dtype=torch.int64, device="cuda")
+        req_to_buffer_index = g_infer_context.req_manager.req_to_buffer_index
+        # Make contiguous and convert to int64 for Triton kernel compatibility
+        cur_buffer_indexes = req_to_buffer_index[req_idxes, 0].contiguous().to(torch.int64)
+
+        new_buffer_indexes = self.buffer_mem_manager.alloc(len(reqs_to_insert))
+        # Move to CUDA and convert to int64, ensure contiguous
+        new_buffer_indexes_cuda = new_buffer_indexes.to(device="cuda", dtype=torch.int64).contiguous()
+
+        self.buffer_mem_manager.copy_buffer_p2p(cur_buffer_indexes, new_buffer_indexes_cuda)
+
+        for i, req in enumerate(reqs_to_insert):
+            input_token_ids = req.get_input_token_ids()
+            key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
+            value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu()
+            prefix_len, new_shared_kv_node = super().insert(key, value)
+            old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
+            self.dec_node_ref_counter(req.shared_kv_node)
+            self.add_node_ref_counter(new_shared_kv_node)
+            self.add_buffer_idx_to_node(new_shared_kv_node, new_buffer_indexes[i].item())
+            req.extra_need_to_free_token_index.append(
+                g_infer_context.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]
+            )
+            req.shared_kv_node = new_shared_kv_node
+
+    def match_prefix(self, key, update_refs=False):
+        assert len(key) != 0
+        ans_value_list = []
+        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
+        miss_prefix_len = 0
+        evict_token_list = []
+        while tree_node != self.root_node and tree_node.buffer_idx is None:
+            if tree_node.is_leaf():
+                self.evict_tree_set.discard(tree_node)
+
+            # Only update ref_counter when update_refs is True to maintain consistency
+            # with _match_prefix_helper which only increments ref_counter when update_refs=True
+            if update_refs:
+                if tree_node.ref_counter == 1:
+                    self.refed_tokens_num.arr[0] -= len(tree_node.token_mem_index_value)
+                tree_node.ref_counter -= 1  # only decrement the current node, no recursion
+
+            if tree_node.is_leaf() and tree_node.ref_counter == 0:
+                evict_token_list.append(tree_node.token_mem_index_value)
+                self.tree_total_tokens_num.arr[0] -= len(tree_node.token_mem_index_value)
+                parent_node: TreeNode = tree_node.parent
+                parent_node.remove_child(tree_node)
+                if parent_node.is_leaf():
+                    self.evict_tree_set.add(parent_node)
+                tree_node = parent_node
+            else:
+                if tree_node.is_leaf():
+                    self.evict_tree_set.add(tree_node)
+                tree_node = tree_node.parent
+            miss_prefix_len += len(ans_value_list.pop())
+
+        if len(evict_token_list) > 0:
+            evict_token_value = torch.concat(evict_token_list)
+            self.mem_manager.free(evict_token_value)
+
+        if tree_node == self.root_node:
+            return None, miss_prefix_len, None
+
+        update_node = tree_node
+        while update_node != self.root_node:
+            if update_node.buffer_idx is not None:
+                self.evict_buffer_set.discard(update_node)
+                update_node.update_buffer_time()
+                self.evict_buffer_set.add(update_node)
+            update_node = update_node.parent
+
+        value = torch.concat(ans_value_list)
+        return tree_node, miss_prefix_len, value
+
+    def add_buffer_idx_to_node(self, node: TreeNode, buffer_idx: int):
+        """Set buffer_idx for a node and add it to evict_buffer_set."""
+        self.evict_buffer_set.discard(node)
+        if node.is_leaf():
+            self.evict_tree_set.discard(node)
+        if node.buffer_idx is not None:
+            self.buffer_mem_manager.free([node.buffer_idx])
+        node.buffer_idx = buffer_idx
+        node.update_buffer_time()
+        self.evict_buffer_set.add(node)
+        if node.is_leaf():
+            self.evict_tree_set.add(node)
+        return
+
+    def free_radix_cache_to_get_enough_token(self, need_token_num):
+        assert self.mem_manager is not None
+        if need_token_num > self.mem_manager.can_use_mem_size:
+            need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size
+            release_mems = []
+
+            def release_mem(mem_index):
+                release_mems.append(mem_index)
+                return
+
+            release_buffers = []
+
+            def release_buffer(buffer_idx):
+                release_buffers.append(buffer_idx)
+                return
+
+            self.evict(need_evict_token_num, release_buffer, release_mem)
+            mem_index = torch.concat(release_mems)
+            self.mem_manager.free(mem_index)
+            if len(release_buffers) > 0:
+                self.buffer_mem_manager.free(release_buffers)
+        return
+
+    def evict(self, need_remove_tokens, evict_buffer_callback, evict_callback):
+        if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens:
+            assert False, f"""can not free tree tokens {need_remove_tokens},
+                              tree_total_tokens_num {self.tree_total_tokens_num.arr[0]},
+                              refed_tokens_num {self.refed_tokens_num.arr[0]}"""
+        num_evicted = 0
+        while num_evicted < need_remove_tokens:
+            node: TreeNode = self.evict_tree_set.pop(0)
+            assert (
+                node.ref_counter == 0 and
len(node.children) == 0 and node != self.root_node + ), f"error evict tree node state: {node.ref_counter}, {len(node.children)}" + num_evicted += len(node.token_mem_index_value) + evict_callback(node.token_mem_index_value) + if node.buffer_idx is not None: + self.evict_buffer_set.discard(node) + evict_buffer_callback(node.buffer_idx) + node.buffer_idx = None + # update total token num + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + return diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 09bc938f23..3b59401144 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -101,6 +101,14 @@ def get_tokenizer( tokenizer = QWen3VLTokenizer( tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) + elif model_type in ["qwen3_5", "qwen3_5_moe"] and "vision_config" in model_cfg: + from transformers import AutoProcessor + from ..models.qwen3_5.model import QWen3_5Tokenizer + + processor = AutoProcessor.from_pretrained(tokenizer_name) + tokenizer = QWen3_5Tokenizer( + tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg + ) elif model_cfg.get("thinker_config") is not None: from transformers import AutoProcessor diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index 790f185f25..fa0a9e3c71 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -87,6 +87,22 @@ def get_eos_token_ids(model_path: str) -> Optional[List[int]]: except: pass + # Qwen3.5 checkpoints can have an eos_token_id in config that differs from + # tokenizer.eos_token_id. In practice tokenizer.eos_token_id is the reliable + # stop id (<|im_end|>) for detokenization/stop behavior. + try: + config_json = get_config_json(model_path) + model_type = config_json.get("model_type") or config_json.get("text_config", {}).get("model_type") + if model_type in {"qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"}: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) + if tokenizer.eos_token_id is not None: + return [int(tokenizer.eos_token_id)] + except Exception: + # Fall back to config-based lookup below. 
+ pass + eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"]) if isinstance(eos_token_id, int): return [eos_token_id] diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 7a7a9be121..cdafb88873 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -158,7 +158,7 @@ def get_kv_quant_calibration_inference_count(): @lru_cache(maxsize=None) def get_triton_autotune_level(): - return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0)) + return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 1)) g_model_init_done = False From a4ab210fc05f9a47c29b98adf88c1af1eb89f00e Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 20 Feb 2026 03:30:55 +0000 Subject: [PATCH 073/180] refactor: simplify mamba buffer copy and integrate Triton kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Reduce Triton kernels from 6 (1D/2D/3D × p2p/broadcast) to 2 (1D only) by flattening contiguous trailing dimensions via tensor view - Wire up MambaCacheManager to use the Triton kernels instead of PyTorch advanced indexing with Python for-loops - Cast strides to int64 in kernels to prevent pointer arithmetic overflow - Add Qwen3.5 multimodal vision-language model support --- .../common/basemodel/attention_vit/fa3/fp.py | 3 +- .../triton_kernel/mamba_buffer_copy.py | 671 +----------------- .../mamba_cache_mem_manager/cache_manager.py | 91 +-- lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 4 +- lightllm/models/qwen2_vl/vision_process.py | 5 +- lightllm/models/qwen35_moe/model.py | 42 ++ lightllm/models/qwen3_5/__init__.py | 17 + lightllm/models/qwen3_5/infer_struct.py | 110 +++ .../models/qwen3_5/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 121 ++++ .../models/qwen3_5/layer_weights/__init__.py | 0 .../layer_weights/transformer_layer_weight.py | 166 +++++ lightllm/models/qwen3_5/model.py | 229 ++++++ lightllm/server/build_prompt.py | 23 +- lightllm/server/core/objs/sampling_params.py | 52 +- lightllm/server/function_call_parser.py | 224 ++++++ .../router/dynamic_prompt/radix_cache.py | 14 +- .../server/router/model_infer/infer_batch.py | 136 +++- .../model_infer/mode_backend/base_backend.py | 54 +- .../mode_backend/chunked_prefill/impl.py | 32 + .../mode_backend/dp_backend/impl.py | 28 + .../visualserver/model_infer/model_rpc.py | 2 +- test_gsmk.py | 241 +++++++ 23 files changed, 1506 insertions(+), 759 deletions(-) create mode 100644 lightllm/models/qwen35_moe/model.py create mode 100644 lightllm/models/qwen3_5/__init__.py create mode 100644 lightllm/models/qwen3_5/infer_struct.py create mode 100644 lightllm/models/qwen3_5/layer_infer/__init__.py create mode 100644 lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/qwen3_5/layer_weights/__init__.py create mode 100644 lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/qwen3_5/model.py create mode 100644 test_gsmk.py diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index 406ff7408d..d5e623b188 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -45,7 +45,8 @@ def _vit_att_fwd( False, window_size[0], window_size[1], - 0.0, + 0, # attention_chunk + 0.0, # softcap is_rotary_interleaved=False, scheduler_metadata=None, num_splits=1, diff --git 
a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py index b4a91f7861..6a1d8adbd5 100644 --- a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py +++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py @@ -1,10 +1,3 @@ -""" -Optimized Mamba Buffer Copy Kernels with Autotune Support - -This module provides auto-tuned Triton kernels for efficient buffer copying operations -in Mamba-style models, including support for MTP (Multi-Token Prediction) buffer broadcasting. -""" - import torch import triton import triton.language as tl @@ -35,6 +28,10 @@ def _copy_buffer_p2p_1d_kernel( layer_idx = tl.program_id(1) + layer_idx_offset block_d_idx = tl.program_id(2) + # Cast strides to int64 to prevent overflow in pointer arithmetic + stride_layer = stride_layer.to(tl.int64) + stride_index = stride_index.to(tl.int64) + # Load source and destination indices for this pair src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) @@ -58,66 +55,6 @@ def _copy_buffer_p2p_1d_kernel( tl.store(dst_ptr, data, mask=mask) -@triton.jit -def _copy_buffer_p2p_2d_kernel( - src_buffer_ptr, - dst_buffer_ptr, - src_indexes_ptr, - dst_indexes_ptr, - pair_idx_offset, - layer_idx_offset, - stride_layer, - stride_index, - stride_d1, - stride_d2, - d1_size, - d2_size, - num_blocks_d2, - BLOCK_D1: tl.constexpr, - BLOCK_D2: tl.constexpr, -): - """ - Kernel to copy 2D buffer from source indices to destination indices. - - Grid: (num_pairs, layer_num, num_blocks_d1 * num_blocks_d2) - Each program copies one 2D block for one (pair, layer) combination. - """ - pair_idx = tl.program_id(0) + pair_idx_offset - layer_idx = tl.program_id(1) + layer_idx_offset - block_idx = tl.program_id(2) - - # Decompose block_idx into d1 and d2 block indices - block_d1_idx = block_idx // num_blocks_d2 - block_d2_idx = block_idx % num_blocks_d2 - - # Load source and destination indices - src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) - dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) - - # Calculate offsets for this block - d1_start = block_d1_idx * BLOCK_D1 - d2_start = block_d2_idx * BLOCK_D2 - - d1_offsets = d1_start + tl.arange(0, BLOCK_D1) - d2_offsets = d2_start + tl.arange(0, BLOCK_D2) - - # Create mask for valid indices - d1_mask = d1_offsets < d1_size - d2_mask = d2_offsets < d2_size - mask = d1_mask[:, None] & d2_mask[None, :] - - # Calculate base pointers - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index - - # Calculate full offsets - offsets = d1_offsets[:, None] * stride_d1 + d2_offsets[None, :] * stride_d2 - - # Load and store - data = tl.load(base_src + offsets, mask=mask, other=0.0) - tl.store(base_dst + offsets, data, mask=mask) - - @triton.jit def _copy_buffer_broadcast_1d_kernel( src_buffer_ptr, @@ -142,6 +79,10 @@ def _copy_buffer_broadcast_1d_kernel( layer_idx = tl.program_id(1) + layer_idx_offset block_d_idx = tl.program_id(2) + # Cast strides to int64 to prevent overflow in pointer arithmetic + stride_layer = stride_layer.to(tl.int64) + stride_index = stride_index.to(tl.int64) + # Load source index src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) @@ -168,219 +109,6 @@ def _copy_buffer_broadcast_1d_kernel( tl.store(dst_ptr, data, mask=mask) -@triton.jit -def _copy_buffer_broadcast_2d_kernel( - src_buffer_ptr, - dst_buffer_ptr, - 
src_indexes_ptr, - dst_indexes_ptr, - copy_idx_offset, - layer_idx_offset, - stride_layer, - stride_index, - stride_d1, - stride_d2, - d1_size, - d2_size, - num_blocks_d2, - num_dst_per_src, - BLOCK_D1: tl.constexpr, - BLOCK_D2: tl.constexpr, -): - """ - Broadcast kernel for 2D buffer copy (one source to multiple destinations). - - Grid: (num_src, layer_num, num_blocks_d1 * num_blocks_d2) - """ - src_idx_in_batch = tl.program_id(0) + copy_idx_offset - layer_idx = tl.program_id(1) + layer_idx_offset - block_idx = tl.program_id(2) - - # Decompose block_idx - block_d1_idx = block_idx // num_blocks_d2 - block_d2_idx = block_idx % num_blocks_d2 - - # Load source index - src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) - - # Calculate offsets - d1_start = block_d1_idx * BLOCK_D1 - d2_start = block_d2_idx * BLOCK_D2 - - d1_offsets = d1_start + tl.arange(0, BLOCK_D1) - d2_offsets = d2_start + tl.arange(0, BLOCK_D2) - - d1_mask = d1_offsets < d1_size - d2_mask = d2_offsets < d2_size - mask = d1_mask[:, None] & d2_mask[None, :] - - # Calculate source pointer and load data once - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index - offsets = d1_offsets[:, None] * stride_d1 + d2_offsets[None, :] * stride_d2 - data = tl.load(base_src + offsets, mask=mask, other=0.0) - - # Broadcast to all destinations - for dst_offset in range(num_dst_per_src): - dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset - dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) - - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index - tl.store(base_dst + offsets, data, mask=mask) - - -@triton.jit -def _copy_buffer_p2p_3d_kernel( - src_buffer_ptr, - dst_buffer_ptr, - src_indexes_ptr, - dst_indexes_ptr, - pair_idx_offset, - layer_idx_offset, - stride_layer, - stride_index, - stride_d1, - stride_d2, - stride_d3, - d1_size, - d2_size, - d3_size, - num_blocks_d2, - num_blocks_d3, - BLOCK_D1: tl.constexpr, - BLOCK_D2: tl.constexpr, - BLOCK_D3: tl.constexpr, -): - """ - Optimized kernel for 3D data buffer copy (5D tensor: layer, buffer, d1, d2, d3). - - Grid: (num_pairs, layer_num, num_blocks_d1 * num_blocks_d2 * num_blocks_d3) - Each program copies one 3D block for one (pair, layer) combination. 
- """ - pair_idx = tl.program_id(0) + pair_idx_offset - layer_idx = tl.program_id(1) + layer_idx_offset - block_idx = tl.program_id(2) - - # Decompose block_idx into d1, d2, d3 block indices - block_d1_idx = block_idx // (num_blocks_d2 * num_blocks_d3) - temp = block_idx % (num_blocks_d2 * num_blocks_d3) - block_d2_idx = temp // num_blocks_d3 - block_d3_idx = temp % num_blocks_d3 - - # Load source and destination indices for this pair - src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) - dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) - - # Calculate offsets for this block - d1_start = block_d1_idx * BLOCK_D1 - d2_start = block_d2_idx * BLOCK_D2 - d3_start = block_d3_idx * BLOCK_D3 - - d1_offsets = d1_start + tl.arange(0, BLOCK_D1) - d2_offsets = d2_start + tl.arange(0, BLOCK_D2) - d3_offsets = d3_start + tl.arange(0, BLOCK_D3) - - # Create masks for valid indices - d1_mask = d1_offsets < d1_size - d2_mask = d2_offsets < d2_size - d3_mask = d3_offsets < d3_size - - # 3D mask: [BLOCK_D1, BLOCK_D2, BLOCK_D3] - mask = d1_mask[:, None, None] & d2_mask[None, :, None] & d3_mask[None, None, :] - - # Calculate base pointers - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index - - # Calculate full 3D offsets - offsets = ( - d1_offsets[:, None, None] * stride_d1 - + d2_offsets[None, :, None] * stride_d2 - + d3_offsets[None, None, :] * stride_d3 - ) - - # Load and store - data = tl.load(base_src + offsets, mask=mask, other=0.0) - tl.store(base_dst + offsets, data, mask=mask) - - -@triton.jit -def _copy_buffer_broadcast_3d_kernel( - src_buffer_ptr, - dst_buffer_ptr, - src_indexes_ptr, - dst_indexes_ptr, - copy_idx_offset, - layer_idx_offset, - stride_layer, - stride_index, - stride_d1, - stride_d2, - stride_d3, - d1_size, - d2_size, - d3_size, - num_blocks_d2, - num_blocks_d3, - num_dst_per_src, - BLOCK_D1: tl.constexpr, - BLOCK_D2: tl.constexpr, - BLOCK_D3: tl.constexpr, -): - """ - Broadcast kernel for 3D data buffer copy (5D tensor: layer, buffer, d1, d2, d3). - - Grid: (num_src, layer_num, num_blocks_d1 * num_blocks_d2 * num_blocks_d3) - Each program loads once from source and broadcasts to all destinations. 
- """ - src_idx_in_batch = tl.program_id(0) + copy_idx_offset - layer_idx = tl.program_id(1) + layer_idx_offset - block_idx = tl.program_id(2) - - # Decompose block_idx into d1, d2, d3 block indices - block_d1_idx = block_idx // (num_blocks_d2 * num_blocks_d3) - temp = block_idx % (num_blocks_d2 * num_blocks_d3) - block_d2_idx = temp // num_blocks_d3 - block_d3_idx = temp % num_blocks_d3 - - # Load source index - src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) - - # Calculate offsets for this block - d1_start = block_d1_idx * BLOCK_D1 - d2_start = block_d2_idx * BLOCK_D2 - d3_start = block_d3_idx * BLOCK_D3 - - d1_offsets = d1_start + tl.arange(0, BLOCK_D1) - d2_offsets = d2_start + tl.arange(0, BLOCK_D2) - d3_offsets = d3_start + tl.arange(0, BLOCK_D3) - - # Create masks - d1_mask = d1_offsets < d1_size - d2_mask = d2_offsets < d2_size - d3_mask = d3_offsets < d3_size - - mask = d1_mask[:, None, None] & d2_mask[None, :, None] & d3_mask[None, None, :] - - # Calculate source pointer and load data once - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index - - offsets = ( - d1_offsets[:, None, None] * stride_d1 - + d2_offsets[None, :, None] * stride_d2 - + d3_offsets[None, None, :] * stride_d3 - ) - - data = tl.load(base_src + offsets, mask=mask, other=0.0) - - # Broadcast to all destinations for this source - for dst_offset in range(num_dst_per_src): - dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset - dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) - - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index - tl.store(base_dst + offsets, data, mask=mask) - - # ==================== Config Generation Functions ==================== @@ -400,47 +128,6 @@ def _get_buffer_copy_1d_configs(): return configs -def _get_buffer_copy_2d_configs(): - """Generate candidate configurations for 2D buffer copy.""" - configs = [] - for block_d1 in [16, 32, 64, 128]: - for block_d2 in [16, 32, 64, 128, 256]: - for num_warps in [2, 4, 8]: - for num_stages in [2, 3, 4]: - configs.append( - { - "BLOCK_D1": block_d1, - "BLOCK_D2": block_d2, - "num_warps": num_warps, - "num_stages": num_stages, - } - ) - return configs - - -def _get_buffer_copy_3d_configs(): - """Generate candidate configurations for 3D buffer copy (5D tensor).""" - configs = [] - for block_d1 in [8, 16, 32]: - for block_d2 in [8, 16, 32, 64]: - for block_d3 in [8, 16, 32, 64, 128]: - for num_warps in [4, 8]: - for num_stages in [2, 3]: - # Skip configs that are too large for shared memory - if block_d1 * block_d2 * block_d3 > 32768: - continue - configs.append( - { - "BLOCK_D1": block_d1, - "BLOCK_D2": block_d2, - "BLOCK_D3": block_d3, - "num_warps": num_warps, - "num_stages": num_stages, - } - ) - return configs - - # ==================== Static and Run Key Functions ==================== @@ -450,7 +137,7 @@ def _get_buffer_copy_static_key(src_buffer: torch.Tensor): return { "ndim": len(shape), "layer_num": shape[0], - "d_sizes": str(shape[2:]), # Dimension sizes + "d_sizes": str(shape[2:]), "dtype": str(src_buffer.dtype), } @@ -483,7 +170,6 @@ def _copy_buffer_p2p_1d_autotuned( d_size = src_buffer.shape[2] if run_config is None: - # Default config if autotune is disabled BLOCK_D = triton.next_power_of_2(min(d_size, 256)) num_warps = 4 if BLOCK_D > 256 else 2 num_stages = 2 @@ -523,75 +209,6 @@ def _copy_buffer_p2p_1d_autotuned( ) -@autotune( - kernel_name="mamba_buffer_copy_p2p_2d:v1", - configs_gen_func=_get_buffer_copy_2d_configs, - 
static_key_func=_get_buffer_copy_static_key, - run_key_func=_get_buffer_copy_run_key, - mutates_args=["dst_buffer"], -) -def _copy_buffer_p2p_2d_autotuned( - src_buffer: torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, - run_config: dict = None, -): - """Auto-tuned 2D buffer copy.""" - num_pairs = src_indexes.shape[0] - layer_num = src_buffer.shape[0] - d1_size = src_buffer.shape[2] - d2_size = src_buffer.shape[3] - - if run_config is None: - # Default config if autotune is disabled - BLOCK_D1 = triton.next_power_of_2(min(d1_size, 64)) - BLOCK_D2 = triton.next_power_of_2(min(d2_size, 128)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_D1 = run_config["BLOCK_D1"] - BLOCK_D2 = run_config["BLOCK_D2"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) - num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) - num_blocks_total = num_blocks_d1 * num_blocks_d2 - - MAX_GRID_SIZE = 65535 - - for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): - pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) - pair_chunk_size = pair_chunk_end - pair_chunk_start - - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) - layer_chunk_size = layer_chunk_end - layer_chunk_start - - grid = (pair_chunk_size, layer_chunk_size, num_blocks_total) - - _copy_buffer_p2p_2d_kernel[grid]( - src_buffer, - dst_buffer, - src_indexes, - dst_indexes, - pair_chunk_start, - layer_chunk_start, - src_buffer.stride(0), - src_buffer.stride(1), - src_buffer.stride(2), - src_buffer.stride(3), - d1_size, - d2_size, - num_blocks_d2, - BLOCK_D1=BLOCK_D1, - BLOCK_D2=BLOCK_D2, - num_warps=num_warps, - num_stages=num_stages, - ) - - @autotune( kernel_name="mamba_buffer_broadcast_1d:v1", configs_gen_func=_get_buffer_copy_1d_configs, @@ -653,231 +270,19 @@ def _copy_buffer_broadcast_1d_autotuned( ) -@autotune( - kernel_name="mamba_buffer_broadcast_2d:v1", - configs_gen_func=_get_buffer_copy_2d_configs, - static_key_func=_get_buffer_copy_static_key, - run_key_func=_get_buffer_copy_run_key, - mutates_args=["dst_buffer"], -) -def _copy_buffer_broadcast_2d_autotuned( - src_buffer: torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, - run_config: dict = None, -): - """Auto-tuned 2D buffer broadcast (one src to multiple dst).""" - num_src = src_indexes.shape[0] - layer_num = src_buffer.shape[0] - d1_size = src_buffer.shape[2] - d2_size = src_buffer.shape[3] - num_dst_per_src = dst_indexes.shape[0] // num_src - - if run_config is None: - BLOCK_D1 = triton.next_power_of_2(min(d1_size, 64)) - BLOCK_D2 = triton.next_power_of_2(min(d2_size, 128)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_D1 = run_config["BLOCK_D1"] - BLOCK_D2 = run_config["BLOCK_D2"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) - num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) - num_blocks_total = num_blocks_d1 * num_blocks_d2 - - MAX_GRID_SIZE = 65535 - - for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): - src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) - src_chunk_size = src_chunk_end - src_chunk_start - - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) - layer_chunk_size = layer_chunk_end - layer_chunk_start - - grid = 
(src_chunk_size, layer_chunk_size, num_blocks_total) - - _copy_buffer_broadcast_2d_kernel[grid]( - src_buffer, - dst_buffer, - src_indexes, - dst_indexes, - src_chunk_start, - layer_chunk_start, - src_buffer.stride(0), - src_buffer.stride(1), - src_buffer.stride(2), - src_buffer.stride(3), - d1_size, - d2_size, - num_blocks_d2, - num_dst_per_src, - BLOCK_D1=BLOCK_D1, - BLOCK_D2=BLOCK_D2, - num_warps=num_warps, - num_stages=num_stages, - ) - - -@autotune( - kernel_name="mamba_buffer_copy_p2p_3d:v1", - configs_gen_func=_get_buffer_copy_3d_configs, - static_key_func=_get_buffer_copy_static_key, - run_key_func=_get_buffer_copy_run_key, - mutates_args=["dst_buffer"], -) -def _copy_buffer_p2p_3d_autotuned( - src_buffer: torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, - run_config: dict = None, -): - """Auto-tuned 3D data buffer copy (5D tensor).""" - num_pairs = src_indexes.shape[0] - layer_num = src_buffer.shape[0] - d1_size = src_buffer.shape[2] - d2_size = src_buffer.shape[3] - d3_size = src_buffer.shape[4] - - if run_config is None: - BLOCK_D1 = triton.next_power_of_2(min(d1_size, 16)) - BLOCK_D2 = triton.next_power_of_2(min(d2_size, 32)) - BLOCK_D3 = triton.next_power_of_2(min(d3_size, 64)) - num_warps = 4 if BLOCK_D1 * BLOCK_D2 * BLOCK_D3 > 4096 else 8 - num_stages = 2 - else: - BLOCK_D1 = run_config["BLOCK_D1"] - BLOCK_D2 = run_config["BLOCK_D2"] - BLOCK_D3 = run_config["BLOCK_D3"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) - num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) - num_blocks_d3 = triton.cdiv(d3_size, BLOCK_D3) - num_blocks_total = num_blocks_d1 * num_blocks_d2 * num_blocks_d3 - - MAX_GRID_SIZE = 65535 - - for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): - pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) - pair_chunk_size = pair_chunk_end - pair_chunk_start - - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) - layer_chunk_size = layer_chunk_end - layer_chunk_start - - grid = (pair_chunk_size, layer_chunk_size, num_blocks_total) - - _copy_buffer_p2p_3d_kernel[grid]( - src_buffer, - dst_buffer, - src_indexes, - dst_indexes, - pair_chunk_start, - layer_chunk_start, - src_buffer.stride(0), - src_buffer.stride(1), - src_buffer.stride(2), - src_buffer.stride(3), - src_buffer.stride(4), - d1_size, - d2_size, - d3_size, - num_blocks_d2, - num_blocks_d3, - BLOCK_D1=BLOCK_D1, - BLOCK_D2=BLOCK_D2, - BLOCK_D3=BLOCK_D3, - num_warps=num_warps, - num_stages=num_stages, - ) - - -@autotune( - kernel_name="mamba_buffer_broadcast_3d:v1", - configs_gen_func=_get_buffer_copy_3d_configs, - static_key_func=_get_buffer_copy_static_key, - run_key_func=_get_buffer_copy_run_key, - mutates_args=["dst_buffer"], -) -def _copy_buffer_broadcast_3d_autotuned( - src_buffer: torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, - run_config: dict = None, -): - """Auto-tuned 3D data buffer broadcast (5D tensor, one src to multiple dst).""" - num_src = src_indexes.shape[0] - layer_num = src_buffer.shape[0] - d1_size = src_buffer.shape[2] - d2_size = src_buffer.shape[3] - d3_size = src_buffer.shape[4] - num_dst_per_src = dst_indexes.shape[0] // num_src - - if run_config is None: - BLOCK_D1 = triton.next_power_of_2(min(d1_size, 16)) - BLOCK_D2 = triton.next_power_of_2(min(d2_size, 32)) - BLOCK_D3 = 
triton.next_power_of_2(min(d3_size, 64)) - num_warps = 4 if BLOCK_D1 * BLOCK_D2 * BLOCK_D3 > 4096 else 8 - num_stages = 2 - else: - BLOCK_D1 = run_config["BLOCK_D1"] - BLOCK_D2 = run_config["BLOCK_D2"] - BLOCK_D3 = run_config["BLOCK_D3"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) - num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) - num_blocks_d3 = triton.cdiv(d3_size, BLOCK_D3) - num_blocks_total = num_blocks_d1 * num_blocks_d2 * num_blocks_d3 - - MAX_GRID_SIZE = 65535 - - for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): - src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) - src_chunk_size = src_chunk_end - src_chunk_start +# ==================== Unified Interface ==================== - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) - layer_chunk_size = layer_chunk_end - layer_chunk_start - grid = (src_chunk_size, layer_chunk_size, num_blocks_total) +def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor: + """Flatten all dimensions after [layer_num, buffer_size] into one. - _copy_buffer_broadcast_3d_kernel[grid]( - src_buffer, - dst_buffer, - src_indexes, - dst_indexes, - src_chunk_start, - layer_chunk_start, - src_buffer.stride(0), - src_buffer.stride(1), - src_buffer.stride(2), - src_buffer.stride(3), - src_buffer.stride(4), - d1_size, - d2_size, - d3_size, - num_blocks_d2, - num_blocks_d3, - num_dst_per_src, - BLOCK_D1=BLOCK_D1, - BLOCK_D2=BLOCK_D2, - BLOCK_D3=BLOCK_D3, - num_warps=num_warps, - num_stages=num_stages, - ) - - -# ==================== Unified Interface ==================== + For a contiguous buffer of shape [L, B, d1, d2, ...], returns a view + of shape [L, B, d1*d2*...]. This is a zero-copy operation. + """ + if buffer.ndim == 3: + return buffer + L, B = buffer.shape[:2] + return buffer.view(L, B, -1) def copy_buffer_p2p( @@ -889,7 +294,8 @@ def copy_buffer_p2p( """ Copy buffers from source indices to destination indices with auto-tuning. - Supports 3D (conv states), 4D (standard buffers), and 5D (SSM states) buffers. + Supports any buffer shape [layer_num, buffer_size, ...] as long as the + trailing dimensions are contiguous (which is the default for torch.zeros). Args: src_buffer: Source buffer tensor [layer_num, buffer_size, ...] 
@@ -901,20 +307,9 @@ def copy_buffer_p2p( assert src_indexes.shape == dst_indexes.shape assert len(src_indexes.shape) == 1 - if len(src_buffer.shape) == 3: - # 1D case: (layer_num, buffer_size, d) - _copy_buffer_p2p_1d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) - - elif len(src_buffer.shape) == 4: - # 2D case: (layer_num, buffer_size, d1, d2) - _copy_buffer_p2p_2d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) - - elif len(src_buffer.shape) == 5: - # 5D case: (layer_num, buffer_size, d1, d2, d3) - Use Triton kernel for zero extra memory - _copy_buffer_p2p_3d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) - - else: - raise ValueError(f"Unsupported buffer shape: {src_buffer.shape}") + src_flat = _flatten_trailing_dims(src_buffer) + dst_flat = _flatten_trailing_dims(dst_buffer) + _copy_buffer_p2p_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes) def copy_buffer_broadcast( @@ -939,23 +334,11 @@ def copy_buffer_broadcast( assert len(dst_indexes.shape) == 2, f"dst_indexes must be 2D, got shape {dst_indexes.shape}" num_src = src_indexes.shape[0] - assert num_src == dst_indexes.shape[0], f"Mismatch: src_indexes {num_src} vs dst_indexes {dst_indexes.shape[0]}" # Flatten dst_indexes for kernel dst_indexes_flat = dst_indexes.reshape(-1).contiguous() - if len(src_buffer.shape) == 3: - # 1D case - _copy_buffer_broadcast_1d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) - - elif len(src_buffer.shape) == 4: - # 2D case - _copy_buffer_broadcast_2d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) - - elif len(src_buffer.shape) == 5: - # 5D case: (layer_num, buffer_size, d1, d2, d3) - Use Triton kernel for zero extra memory - _copy_buffer_broadcast_3d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) - - else: - raise ValueError(f"Unsupported buffer shape: {src_buffer.shape}") + src_flat = _flatten_trailing_dims(src_buffer) + dst_flat = _flatten_trailing_dims(dst_buffer) + _copy_buffer_broadcast_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes_flat) diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py index 348b14192c..272a999bb1 100644 --- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -6,6 +6,7 @@ from lightllm.utils.dist_utils import get_current_rank_in_node from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from lightllm.common.allocator_utils import TokenAllocator +from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import copy_buffer_p2p, copy_buffer_broadcast from lightllm.utils.log_utils import init_logger from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt @@ -56,67 +57,20 @@ def get_mamba_cache(self, layer_idx: int): return conv_state, ssm_state def copy_buffer_p2p(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor): - """ - Copy buffers from source indices to destination indices using optimized Triton kernel. 
- - Args: - src_buffer_indexes: Source buffer indices (1D tensor) - dst_buffer_indexes: Destination buffer indices (1D tensor) - """ - assert src_buffer_indexes.dim() == 1 - assert dst_buffer_indexes.dim() == 1 - assert src_buffer_indexes.shape[0] == dst_buffer_indexes.shape[0] - - # Validate indices are within valid range [0, size] (size+1 is the buffer dim) - max_valid_idx = self.size # HOLD_BUFFER_INDEX = size is valid - src_max = src_buffer_indexes.max().item() if src_buffer_indexes.numel() > 0 else -1 - src_min = src_buffer_indexes.min().item() if src_buffer_indexes.numel() > 0 else -1 - dst_max = dst_buffer_indexes.max().item() if dst_buffer_indexes.numel() > 0 else -1 - dst_min = dst_buffer_indexes.min().item() if dst_buffer_indexes.numel() > 0 else -1 - - if src_min < 0 or src_max > max_valid_idx or dst_min < 0 or dst_max > max_valid_idx: - logger.error( - f"Invalid buffer indices: src=[{src_min}, {src_max}], dst=[{dst_min}, {dst_max}], " - f"valid range=[0, {max_valid_idx}], conv shape={self.conv_state_cache.buffer.shape}, " - f"ssm shape={self.ssm_state_cache.buffer.shape}" - ) - raise ValueError("Invalid buffer indices for copy_buffer_p2p") - - # Use PyTorch advanced indexing for buffer copy (safer than Triton for complex shapes) - # The buffer shape is [layer_num, buffer_size, *shape] - # We need to copy all layers for the given buffer indices - src_idx = src_buffer_indexes.long() - dst_idx = dst_buffer_indexes.long() - - # Copy conv_state: [layer_num, buffer_size, d1, d2] - self.conv_state_cache.buffer[:, dst_idx, ...] = self.conv_state_cache.buffer[:, src_idx, ...] - - # Copy ssm_state: [layer_num, buffer_size, d1, d2, d3] - self.ssm_state_cache.buffer[:, dst_idx, ...] = self.ssm_state_cache.buffer[:, src_idx, ...] - return + copy_buffer_p2p( + self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes + ) + copy_buffer_p2p( + self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes + ) def copy_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): - assert src_buffer_index.dim() == 1 - assert dst_buffer_indexes.dim() == 2 - assert src_buffer_index.shape[0] == dst_buffer_indexes.shape[0] - - # Use PyTorch advanced indexing for broadcast copy - # src_buffer_index: [num_src] - # dst_buffer_indexes: [num_src, num_dst_per_src] - src_idx = src_buffer_index.long() - dst_idx = dst_buffer_indexes.long() - - # Broadcast each source to all its destinations - # For each (src, dst_group), copy buffer[src] to buffer[dst1], buffer[dst2], ... - num_src, num_dst_per_src = dst_idx.shape - for i in range(num_src): - src = src_idx[i : i + 1] # Keep as 1D tensor with 1 element - dsts = dst_idx[i, :] # 1D tensor with num_dst_per_src elements - # Copy conv_state - self.conv_state_cache.buffer[:, dsts, ...] = self.conv_state_cache.buffer[:, src, ...] - # Copy ssm_state - self.ssm_state_cache.buffer[:, dsts, ...] = self.ssm_state_cache.buffer[:, src, ...] 
- return + copy_buffer_broadcast( + self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_index, dst_buffer_indexes + ) + copy_buffer_broadcast( + self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes + ) def copy_ssm_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): """ @@ -125,22 +79,9 @@ def copy_ssm_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_i This is used for MTP mode where each buffer maintains its own independent conv state, but SSM states need to be synchronized. """ - assert src_buffer_index.dim() == 1 - assert dst_buffer_indexes.dim() == 2 - assert src_buffer_index.shape[0] == dst_buffer_indexes.shape[0] - - # Use PyTorch advanced indexing for SSM-only broadcast copy - src_idx = src_buffer_index.long() - dst_idx = dst_buffer_indexes.long() - - # Broadcast each source to all its destinations (SSM only) - num_src = dst_idx.shape[0] - for i in range(num_src): - src = src_idx[i : i + 1] - dsts = dst_idx[i, :] - # Only copy ssm_state, NOT conv_state - self.ssm_state_cache.buffer[:, dsts, ...] = self.ssm_state_cache.buffer[:, src, ...] - return + copy_buffer_broadcast( + self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes + ) def free(self, free_index: Union[torch.Tensor, List[int]]): """ diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 7156a5ce23..825a985b46 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -227,14 +227,14 @@ def _init_datatype(self): def rot_pos_emb(self, grid_thw): pos_ids = [] s = self.spatial_merge_size - for _, h, w in grid_thw: + for t, h, w in grid_thw: pos_shape = (h // s, s, w // s, s) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1)) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() cos_full, sin_full = self.rotary_pos_emb(max_grid_size) diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index f2cd38ec8e..bc313fe467 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -187,7 +187,10 @@ def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torc if image.mode != "RGB": image = image.convert("RGB") image_arr = np.asarray(image, dtype=np.uint8) - image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True) + # Copy to ensure writable array (avoids PyTorch warning for read-only NumPy arrays) + image_data = ( + torch.from_numpy(image_arr.copy()).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True) + ) grouped_images, grouped_images_index = group_images_by_shape( [image_data], disable_grouping=self.disable_grouping diff --git a/lightllm/models/qwen35_moe/model.py b/lightllm/models/qwen35_moe/model.py new file mode 100644 index 0000000000..ee149f3a81 --- /dev/null +++ b/lightllm/models/qwen35_moe/model.py @@ -0,0 +1,42 @@ +import os +import json + +from lightllm.models.qwen3_vl.model import QWen3VLTokenizer +from lightllm.models.registry import 
ModelRegistry +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel +from lightllm.common.build_utils import repair_config +from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights + + +class QWen35Tokenizer(QWen3VLTokenizer): + def __init__(self, tokenizer=None, image_processor=None, **kwargs): + super().__init__(tokenizer, image_processor, **kwargs) + + +@ModelRegistry(["qwen3_5"], is_multimodal=True) +class Qwen35MoeTpPartModel(Qwen3NextTpPartModel): + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["text_config"] + + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + repair_config(self.config, same_names=["intermediate_size", "moe_intermediate_size"]) + + # Handle fine-tuning config if present + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + + def _load_hf_weights(self): + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] + return diff --git a/lightllm/models/qwen3_5/__init__.py b/lightllm/models/qwen3_5/__init__.py new file mode 100644 index 0000000000..47667a92d5 --- /dev/null +++ b/lightllm/models/qwen3_5/__init__.py @@ -0,0 +1,17 @@ +""" +Qwen3.5 Multimodal Model Module + +Provides Qwen3.5 multimodal models with hybrid attention and vision-language support. +""" + +from .model import ( + Qwen3_5TpPartModel, + Qwen3_5MOETpPartModel, + QWen3_5Tokenizer, +) + +__all__ = [ + "Qwen3_5TpPartModel", + "Qwen3_5MOETpPartModel", + "QWen3_5Tokenizer", +] diff --git a/lightllm/models/qwen3_5/infer_struct.py b/lightllm/models/qwen3_5/infer_struct.py new file mode 100644 index 0000000000..9ce407cacf --- /dev/null +++ b/lightllm/models/qwen3_5/infer_struct.py @@ -0,0 +1,110 @@ +""" +Qwen3.5 Multimodal Inference State + +This module provides inference state for Qwen3.5 multimodal model that combines: +- Qwen3Next features (output gating, MTP-aware batching, hybrid attention buffer management) +- Qwen3VL multimodal support (mrope position encoding for images/videos) +""" + +import torch +from typing import List + +from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen35InferStateInfo(Qwen2VLInferStateInfo): + """ + Inference state for Qwen3.5 multimodal model with: + - gate_value attribute for output gating in full attention layers + - MTP-aware batching for multi-token prediction + - Custom buffer management for hybrid attention (full + linear) + - mrope position encoding support for multimodal inputs + """ + + def __init__(self): + super().__init__() + # For output gating in full attention layers (from Qwen3Next) + self.gate_value = None + # MTP-aware attributes (from Qwen3Next) + self.b_att_seq_len = None + self.att_batch_size = None + self.real_req_idx = None + self.mtp_buffer_idx_list = None + self.b_buffer_idx = None + + def _compute_mrope_delta(self, images: List) -> int: + """Compute the position delta for mrope based on image tokens. 
+
+        The position delta is the sum of all image position deltas (grid_thwd[3])
+        which accounts for the extra position IDs consumed by multimodal content.
+        """
+        position_delta = 0
+        for image in images:
+            position_delta += image["grid_thwd"][3]
+        return position_delta
+
+    def init_some_extra_state(self, model):
+        """Initialize Qwen3.5-specific state including mrope and MTP support"""
+        # Record the rope type first; the mrope position encoding itself is
+        # handled below with the corrected delta computation
+        rope_scaling = model.config.get("rope_scaling", {})
+        self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+
+        # Call the base InferStateInfo.init_some_extra_state first to set up
+        # basic state, bypassing the parents' own mrope handling
+        from lightllm.common.basemodel.infer_struct import InferStateInfo
+
+        InferStateInfo.init_some_extra_state(self, model)
+
+        # Now handle mrope position encoding with corrected delta computation
+        if self.is_prefill:
+            self.position_ids = self.get_mrope_position(self.multimodal_params)
+        else:
+            # Decode phase: compute correct mrope delta
+            b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])]
+            for batch_idx, p in enumerate(self.multimodal_params):
+                b_position_delta[batch_idx] = self._compute_mrope_delta(p.get("images", []))
+
+            position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device)
+            self.position_ids = position_ids.unsqueeze(0).expand(3, -1)
+
+        self.position_ids = self.position_ids.contiguous()
+        self.position_cos = model._cos_cached[self.position_ids]
+        self.position_sin = model._sin_cached[self.position_ids]
+
+        # Now handle MTP-aware batching (from Qwen3Next)
+        args_mtp_step = get_env_start_args().mtp_step
+        mtp_size = args_mtp_step + 1
+
+        if self.is_prefill:
+            # Prefill: Standard initialization
+            self.b_att_seq_len = self.b_seq_len
+            self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous()
+        else:
+            # Decode: MTP-aware handling
+            # In MTP mode, each request has (mtp_step + 1) tokens
+            # att_batch_size is the number of unique requests
+            self.att_batch_size = self.batch_size // mtp_size
+
+            # Use only the sequence lengths for the last token of each MTP group
+            if args_mtp_step > 0:
+                self.b_att_seq_len = self.b_seq_len[args_mtp_step::mtp_size].contiguous()
+                self.real_req_idx = self.b_req_idx[args_mtp_step::mtp_size]
+            else:
+                self.b_att_seq_len = self.b_seq_len
+                self.real_req_idx = self.b_req_idx
+
+            # Buffer indices for Mamba cache (conv and SSM states)
+            self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.real_req_idx, :].flatten().contiguous()
+
+            # Create per-step buffer indices for MTP
+            if args_mtp_step > 0:
+                buffer_idx_list = []
+                for step_id in range(mtp_size):
+                    buffer_idx_list.append(self.b_buffer_idx[step_id::mtp_size].tolist())
+                self.mtp_buffer_idx_list = torch.tensor(
+                    buffer_idx_list, dtype=torch.int32, device=self.b_buffer_idx.device
+                )
+
+        return
diff --git a/lightllm/models/qwen3_5/layer_infer/__init__.py b/lightllm/models/qwen3_5/layer_infer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py
new file mode 100644
index 0000000000..3bbc0ee3be
--- /dev/null
+++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,121 @@
+import torch
+import torch.distributed as dist
+from typing import Tuple
+
+from lightllm.models.qwen3next.layer_infer.transformer_layer_infer
import ( + Qwen3NextFullAttentionTransformerLayerInfer, + Qwen3NextGatedDeltaNetTransformerLayerInfer, +) +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextFullAttentionTransformerLayerWeight, + Qwen3NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen35FullAttentionTransformerLayerInfer(Qwen3NextFullAttentionTransformerLayerInfer): + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + # Initialize mrope section from config + rope_scaling = network_config.get("rope_scaling", {}) + mrope_section = rope_scaling.get("mrope_section", [11, 11, 10]) + self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") + + def _get_qkv( + self, + input: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + input = input.view(-1, self.embed_dim_) + + # Q and gate projection + if not infer_state.is_prefill: + q_gate_buf = self._get_decode_buffer( + "q_gate_out", + (self._graph_max_batch_size, self.tp_q_gate_dim), + input.dtype, + input.device, + )[: input.size(0)] + q_gate = layer_weight.q_gate_proj.mm(input, out=q_gate_buf) + kv_buf = self._get_decode_buffer( + "kv_out", + (self._graph_max_batch_size, self.tp_kv_dim), + input.dtype, + input.device, + )[: input.size(0)] + kv_out = layer_weight.kv_proj.mm(input, out=kv_buf) + else: + q_gate = layer_weight.q_gate_proj.mm(input) + kv_out = layer_weight.kv_proj.mm(input) + + q_dim = self.tp_q_head_num_ * self.head_dim_ + q = q_gate[:, :q_dim].contiguous() + # In-place sigmoid for gate + infer_state.gate_value = q_gate[:, q_dim:].sigmoid_() + cache_kv = kv_out.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + # Q normalization (in-place) + from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + gemma_rmsnorm_forward( + q.view(-1, self.head_dim_), + layer_weight.q_norm_weight_.weight, + eps=self.eps_, + out=q.view(-1, self.head_dim_), + ) + + k_input = cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]) + if not infer_state.is_prefill: + k_normed = self._get_decode_buffer( + "k_norm_out", + (self._graph_max_batch_size * self.tp_k_head_num_, cache_kv.shape[-1]), + k_input.dtype, + k_input.device, + )[: k_input.shape[0]] + gemma_rmsnorm_forward(k_input, layer_weight.k_norm_weight_.weight, eps=self.eps_, out=k_normed) + else: + k_normed = gemma_rmsnorm_forward(k_input, layer_weight.k_norm_weight_.weight, eps=self.eps_) + cache_kv[:, : self.tp_k_head_num_, :] = k_normed.view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) + + if hasattr(infer_state, "position_cos") and infer_state.position_cos is not None: + rotary_dim = int(self.head_dim_ * self.partial_rotary_factor) + + q_rotary = q.view(-1, self.tp_q_head_num_, self.head_dim_)[:, :, :rotary_dim].contiguous() + k_rotary = cache_kv[:, : self.tp_k_head_num_, :rotary_dim].contiguous() + + mrope_triton_fused( + q_rotary, + k_rotary, + infer_state.position_cos, + infer_state.position_sin, + self.mrope_section, + is_interleaved=True, # Qwen3 uses interleaved mrope + ) + + q.view(-1, self.tp_q_head_num_, self.head_dim_)[:, :, :rotary_dim] = q_rotary + cache_kv[:, : self.tp_k_head_num_, :rotary_dim] = k_rotary 
+ else: + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + partial_rotary_factor=self.partial_rotary_factor, + ) + + return q, cache_kv + + +class Qwen35GatedDeltaNetTransformerLayerInfer(Qwen3NextGatedDeltaNetTransformerLayerInfer): + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + rope_scaling = network_config.get("rope_scaling", {}) + mrope_section = rope_scaling.get("mrope_section", [11, 11, 10]) + self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") diff --git a/lightllm/models/qwen3_5/layer_weights/__init__.py b/lightllm/models/qwen3_5/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..ca1f9d992e --- /dev/null +++ b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py @@ -0,0 +1,166 @@ +import torch + +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextFullAttentionTransformerLayerWeight, + Qwen3NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def split_fused_expert_weights(weights, layer_num, moe_intermediate_size): + layer_prefix = f"model.layers.{layer_num}." + keys = list(weights.keys()) + gate_up_count = 0 + down_count = 0 + num_experts = 0 + + for k in keys: + if not k.startswith(layer_prefix): + continue + + if "mlp.experts.gate_up_proj" in k: + fused_weight = weights.pop(k) # [num_experts, 2*inter_size, hidden_size] + num_experts = fused_weight.shape[0] + + prefix = k.rsplit(".gate_up_proj", 1)[0] + + gate_weight = fused_weight[:, :moe_intermediate_size, :] + up_weight = fused_weight[:, moe_intermediate_size:, :] + + for expert_idx in range(num_experts): + weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx] + weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx] + + gate_up_count += 1 + + elif "mlp.experts.down_proj" in k: + down_weight = weights.pop(k) # [num_experts, hidden_size, inter_size] + num_experts = down_weight.shape[0] + + prefix = k.rsplit(".down_proj", 1)[0] + + for expert_idx in range(num_experts): + weights[f"{prefix}.{expert_idx}.down_proj.weight"] = down_weight[expert_idx] + + down_count += 1 + + +class Qwen35NextFullAttentionTransformerLayerWeight(Qwen3NextFullAttentionTransformerLayerWeight): + def load_hf_weights(self, weights): + self._split_fused_expert_weights(weights) + super().load_hf_weights(weights) + + def _split_fused_expert_weights(self, weights): + moe_intermediate_size = self.network_config_.get("moe_intermediate_size") + if moe_intermediate_size is None: + moe_intermediate_size = self.network_config_.get("intermediate_size") + + if moe_intermediate_size is None: + logger.warning( + f"Layer {self.layer_num_}: Cannot find moe_intermediate_size in config, " + "skipping fused expert weight splitting" + ) + return + + layer_prefix = f"model.layers.{self.layer_num_}.mlp.experts" + has_fused_weights = any(layer_prefix in k and ("gate_up_proj" in k or "down_proj" in k) for k in weights.keys()) + + if has_fused_weights: + 
split_fused_expert_weights(weights, self.layer_num_, moe_intermediate_size) + + +class Qwen35NextGatedDeltaNetTransformerLayerWeight(Qwen3NextGatedDeltaNetTransformerLayerWeight): + def _init_gdn_weight(self): + # Initialize everything from parent first, then override only linear_in_proj. + super()._init_gdn_weight() + + prefix = f"model.layers.{self.layer_num_}.linear_attn" + hidden_size = self.network_config_["hidden_size"] + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + + # NOTE: keep grouped layout directly (q, k, v, z, b, a). + self.linear_in_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[ + qk_dim, + qk_dim, + v_dim, + v_dim, + self.linear_num_v_heads, + self.linear_num_v_heads, + ], + weight_names=[ + f"{prefix}.in_proj_q.weight", + f"{prefix}.in_proj_k.weight", + f"{prefix}.in_proj_v.weight", + f"{prefix}.in_proj_z.weight", + f"{prefix}.in_proj_b.weight", + f"{prefix}.in_proj_a.weight", + ], + data_type=self.data_type_, + quant_method=self.get_quant_method("in_proj_weight"), + ) + + def load_hf_weights(self, weights): + self._split_fused_expert_weights(weights) + super().load_hf_weights(weights) + + def _preprocess_weight(self, weights): + # Keep parent conv1d preprocessing path. + linear_conv1d_weight_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.weight" + linear_conv1d_bias_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.bias" + + if linear_conv1d_weight_name in weights: + weights[linear_conv1d_weight_name] = self._parse_linear_conv1d( + weights[linear_conv1d_weight_name].squeeze(1) + ) + if linear_conv1d_bias_name in weights: + weights[linear_conv1d_bias_name] = self._parse_linear_conv1d(weights[linear_conv1d_bias_name]) + + self._split_linear_in_proj_qkv(weights) + + def _split_linear_in_proj_qkv(self, weights): + prefix = f"model.layers.{self.layer_num_}.linear_attn" + qkv_name = f"{prefix}.in_proj_qkv.weight" + if qkv_name not in weights: + return + + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + expected_rows = 2 * qk_dim + v_dim + + qkv = weights[qkv_name] + if qkv.shape[0] != expected_rows: + logger.warning( + f"Layer {self.layer_num_}: unexpected in_proj_qkv shape " + f"{tuple(qkv.shape)}, expected first dim {expected_rows}; skip split" + ) + return + + q, k, v = torch.split(qkv, [qk_dim, qk_dim, v_dim], dim=0) + weights[f"{prefix}.in_proj_q.weight"] = q + weights[f"{prefix}.in_proj_k.weight"] = k + weights[f"{prefix}.in_proj_v.weight"] = v + del weights[qkv_name] + + def _split_fused_expert_weights(self, weights): + moe_intermediate_size = self.network_config_.get("moe_intermediate_size") + if moe_intermediate_size is None: + moe_intermediate_size = self.network_config_.get("intermediate_size") + + if moe_intermediate_size is None: + logger.warning( + f"Layer {self.layer_num_}: Cannot find moe_intermediate_size in config, " + "skipping fused expert weight splitting" + ) + return + + layer_prefix = f"model.layers.{self.layer_num_}.mlp.experts" + has_fused_weights = any(layer_prefix in k and ("gate_up_proj" in k or "down_proj" in k) for k in weights.keys()) + + if has_fused_weights: + split_fused_expert_weights(weights, self.layer_num_, moe_intermediate_size) diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py new file mode 100644 index 0000000000..fdbccdf787 --- /dev/null +++ b/lightllm/models/qwen3_5/model.py @@ -0,0 +1,229 @@ +import os +import json +import time 
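For reference, here is a minimal standalone sketch of the split that the `split_fused_expert_weights` helper above performs on a fused `gate_up_proj` tensor. The toy sizes and key names are illustrative only; the shapes follow the `[num_experts, 2 * inter_size, hidden_size]` layout noted in the patch's comments:

```python
import torch

def split_gate_up(fused: torch.Tensor, inter_size: int) -> dict:
    # fused: [num_experts, 2 * inter_size, hidden_size]
    gate = fused[:, :inter_size, :]  # per-expert gate_proj weights
    up = fused[:, inter_size:, :]    # per-expert up_proj weights
    out = {}
    for e in range(fused.shape[0]):
        out[f"experts.{e}.gate_proj.weight"] = gate[e]
        out[f"experts.{e}.up_proj.weight"] = up[e]
    return out

# Toy check: 4 experts, inter_size=8, hidden_size=16.
weights = split_gate_up(torch.randn(4, 2 * 8, 16), inter_size=8)
assert weights["experts.0.gate_proj.weight"].shape == (8, 16)
assert weights["experts.3.up_proj.weight"].shape == (8, 16)
```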
+import gc +from safetensors import safe_open +from tqdm import tqdm +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel +from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( + Qwen35NextFullAttentionTransformerLayerWeight, + Qwen35NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.models.qwen3_vl.model import QWen3VLTokenizer +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import ( + Qwen35FullAttentionTransformerLayerInfer, + Qwen35GatedDeltaNetTransformerLayerInfer, +) +from lightllm.models.qwen3_5.infer_struct import Qwen35InferStateInfo +from lightllm.common.build_utils import repair_config +from lightllm.utils.log_utils import init_logger +import lightllm.utils.petrel_helper as utils + +logger = init_logger(__name__) + + +class QWen3_5Tokenizer(QWen3VLTokenizer): + """ + Tokenizer for Qwen3.5 multimodal model. + + Inherits all multimodal tokenization logic from Qwen3VL, + including image and video token handling. + """ + + def __init__(self, tokenizer=None, image_processor=None, **kwargs): + super().__init__(tokenizer, image_processor, **kwargs) + + +@ModelRegistry(["qwen3_5"], is_multimodal=True) +class Qwen3_5TpPartModel(Qwen3NextTpPartModel): + """ + Qwen3.5 Multimodal Model (Dense Variant) + + This model combines: + - Hybrid attention from Qwen3Next (Gated Delta Networks + Full Attention) + - Multimodal capabilities from Qwen3VL (image/video processing) + - Dense MLP layers (non-MoE) + + Architecture: + - Every Nth layer uses full attention (config: full_attention_interval) + - Other layers use linear attention (Gated Delta Networks) + - Vision encoder processes images/videos before text model + - Multimodal embeddings merged with text embeddings + """ + + # Override to use multimodal pre-layer for vision processing + pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer + + # Override to use multimodal pre/post weights (includes vision weights) + pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight + + # Override to use Qwen3.5 infer state with mrope support + infer_state_class = Qwen35InferStateInfo + + def __init__(self, kvargs): + """ + Initialize Qwen3.5 model. + + Args: + kvargs: Dictionary containing: + - weight_dir: Path to model weights + - max_total_token_num: Maximum total tokens + - Additional model configuration + """ + super().__init__(kvargs) + logger.info("Initialized Qwen3.5 multimodal model") + + def _init_config(self): + """ + Load and parse Qwen3.5 configuration. + + Qwen3.5 uses a nested config structure: + { + "model_type": "qwen3_5", + "text_config": { ... }, + "vision_config": { ... } + } + + This method extracts the text_config for the language model + and stores vision_config for multimodal processing. + """ + config_path = os.path.join(self.weight_dir_, "config.json") + + with open(config_path, "r") as json_file: + all_config = json.load(json_file) + + # Extract text config for language model + self.config = all_config["text_config"] + + # Store vision config for multimodal components + self.vision_config = all_config.get("vision_config", None) + + if self.vision_config is None: + logger.warning("No vision_config found in checkpoint. 
" "Multimodal features may not work correctly.") + + # Apply standard config repairs + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + + # Qwen3.5 uses layer_types array instead of decoder_sparse_step for MoE placement + # Set default for decoder_sparse_step (used by inherited Qwen3Next weight initialization) + # Default to 1 meaning all layers with num_experts > 0 use MoE + if "decoder_sparse_step" not in self.config: + self.config["decoder_sparse_step"] = 1 + + # Ensure mlp_only_layers exists (default to empty list) + if "mlp_only_layers" not in self.config: + self.config["mlp_only_layers"] = [] + + # Qwen3.5 MoE uses moe_intermediate_size instead of intermediate_size + # Set intermediate_size for compatibility with base layer weight classes + if "intermediate_size" not in self.config: + if "moe_intermediate_size" in self.config: + self.config["intermediate_size"] = self.config["moe_intermediate_size"] + else: + # Default fallback: 4x hidden_size (common in transformer architectures) + self.config["intermediate_size"] = self.config.get("hidden_size", 4096) * 4 + + # Qwen3.5 stores RoPE config under text_config.rope_parameters. + # Qwen3Next/llama infer path expects flattened keys like rope_theta and + # partial_rotary_factor on the main config dict. + rope_parameters = self.config.get("rope_parameters") + if isinstance(rope_parameters, dict): + if "rope_theta" in rope_parameters and "rope_theta" not in self.config: + self.config["rope_theta"] = rope_parameters["rope_theta"] + if "partial_rotary_factor" in rope_parameters and "partial_rotary_factor" not in self.config: + self.config["partial_rotary_factor"] = rope_parameters["partial_rotary_factor"] + # Preserve the richer RoPE metadata in the expected field when absent. + if "rope_scaling" not in self.config: + self.config["rope_scaling"] = rope_parameters + + # MoE routing parameters - set defaults for Qwen3.5 compatibility + if "norm_topk_prob" not in self.config: + self.config["norm_topk_prob"] = True # Standard default for MoE models + + # Handle fine-tuning config if present + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + + # Calculate num_kv_heads for KV cache memory management + # Required by parent class _init_mem_manager() in Qwen3NextTpPartModel + self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) + + def _init_weights(self): + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) + num_full_attention_layers = self.config["full_attention_interval"] + self.trans_layers_weight = [ + ( + Qwen35NextFullAttentionTransformerLayerWeight( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + if (i + 1) % num_full_attention_layers == 0 + else Qwen35NextGatedDeltaNetTransformerLayerWeight( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + ) + for i in range(self.config["n_layer"]) + ] + + def _init_infer_layer(self): + """ + Initialize inference layers for Qwen3.5 multimodal model. + + Uses mrope-enabled transformer layers to properly handle image/video + tokens with 3D position encoding (temporal, height, width). + + This overrides the parent class to use Qwen35* layer classes instead + of Qwen3Next* layer classes. 
+ """ + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + num_full_attention_layers = self.config["full_attention_interval"] + + self.layers_infer = [ + ( + Qwen35FullAttentionTransformerLayerInfer(i, network_config=self.config) + if (i + 1) % num_full_attention_layers == 0 + else Qwen35GatedDeltaNetTransformerLayerInfer(i, network_config=self.config) + ) + for i in range(self.config["n_layer"]) + ] + + +@ModelRegistry(["qwen3_5_moe"], is_multimodal=True) +class Qwen3_5MOETpPartModel(Qwen3_5TpPartModel): + """ + Qwen3.5-MoE Multimodal Model (Mixture of Experts Variant) + + Extends Qwen3.5 with sparse expert routing: + - Same hybrid attention architecture as Qwen3.5 + - MoE layers replace dense MLP layers + - Expert routing handled by Qwen3NextSparseMoeBlock (inherited) + + The MoE variant is automatically configured by inheriting from + Qwen3NextTpPartModel, which inherits from Qwen3MOEModel. + + No additional configuration needed - MoE support is built-in. + """ + + def __init__(self, kvargs): + """ + Initialize Qwen3.5-MoE model. + + Args: + kvargs: Dictionary containing: + - weight_dir: Path to model weights + - max_total_token_num: Maximum total tokens + - Additional model configuration + """ + super().__init__(kvargs) + logger.info("Initialized Qwen3.5-MoE multimodal model with expert routing") diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index f770459a55..5356da4caf 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -15,8 +15,28 @@ def init_tokenizer(args): async def build_prompt(request, tools) -> str: global tokenizer + import json + # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] + + # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility + # Qwen's chat template expects arguments to be a dict (uses |items filter) + # but OpenAI format sends arguments as a JSON string + for msg in messages: + tool_calls = msg.get("tool_calls") + if tool_calls and isinstance(tool_calls, list): + for tool_call in tool_calls: + func = tool_call.get("function") + if func and isinstance(func, dict): + args = func.get("arguments") + if isinstance(args, str) and args: + try: + func["arguments"] = json.loads(args) + except (json.JSONDecodeError, TypeError): + # Keep original string if not valid JSON + pass + kwargs = {"conversation": messages} if request.character_settings: kwargs["character_settings"] = request.character_settings @@ -32,7 +52,8 @@ async def build_prompt(request, tools) -> str: # This except branch will be triggered when the chosen model # has a different tools input format that is not compatiable # with openAI's apply_chat_template tool_call format, like Mistral. 
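A minimal, self-contained sketch of the `function.arguments` normalization added to `build_prompt` above; the message content and the `get_weather` tool are made up for illustration:

```python
import json

# OpenAI wire format delivers function.arguments as a JSON string, while
# Qwen's chat template iterates it with the |items filter, i.e. it needs a dict.
messages = [
    {
        "role": "assistant",
        "tool_calls": [
            {"id": "call_0", "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
        ],
    }
]

for msg in messages:
    for tool_call in msg.get("tool_calls") or []:
        func = tool_call.get("function")
        if isinstance(func, dict) and isinstance(func.get("arguments"), str) and func["arguments"]:
            try:
                func["arguments"] = json.loads(func["arguments"])
            except (json.JSONDecodeError, TypeError):
                pass  # keep the original string when it is not valid JSON

assert messages[0]["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"}
```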
-        tools = [t if "function" in t else {"function": t} for t in tools]
+        if tools is not None:
+            tools = [t if "function" in t else {"function": t} for t in tools]
         input_str = tokenizer.apply_chat_template(
             **kwargs,
             tokenize=True,
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index d955aa6a87..99331c061c 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -333,15 +333,31 @@ class SamplingParams(ctypes.Structure):
 
     def init(self, tokenizer, **kwargs):
         super().__init__()
+        # Drop None-valued entries from kwargs so they do not override the defaults below
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
         self.best_of = kwargs.get("best_of", 1)
         self.n = kwargs.get("n", self.best_of)
-        self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample)
-        self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty)
-        self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty)
-        self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty)
-        self.temperature = kwargs.get("temperature", SamplingParams._temperature)
-        self.top_p = kwargs.get("top_p", SamplingParams._top_p)
-        self.top_k = kwargs.get("top_k", SamplingParams._top_k)
+        do_sample = kwargs.get("do_sample", SamplingParams._do_sample)
+        self.do_sample = False if do_sample is None else do_sample
+
+        presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty)
+        self.presence_penalty = 0.0 if presence_penalty is None else presence_penalty
+
+        frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty)
+        self.frequency_penalty = 0.0 if frequency_penalty is None else frequency_penalty
+
+        repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty)
+        self.repetition_penalty = 1.0 if repetition_penalty is None else repetition_penalty
+
+        temperature = kwargs.get("temperature", SamplingParams._temperature)
+        self.temperature = 1.0 if temperature is None else temperature
+
+        top_p = kwargs.get("top_p", SamplingParams._top_p)
+        self.top_p = 1.0 if top_p is None else top_p
+
+        top_k = kwargs.get("top_k", SamplingParams._top_k)
+        self.top_k = -1 if top_k is None else top_k
         self.ignore_eos = kwargs.get("ignore_eos", False)
         self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
         self.max_new_tokens = kwargs.get("max_new_tokens", 16)
@@ -408,13 +424,35 @@ def init(self, tokenizer, **kwargs):
     def load_generation_cfg(cls, weight_dir):
         try:
             generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict()
+            # Some checkpoints store null sampling fields in generation_config.json.
+            # Keep robust numeric defaults instead of propagating None into ctypes fields.
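Both the null-stripping in `init()` and the coalescing here guard against a `dict.get` subtlety: the default is returned only when the key is absent, not when the key is present with an explicit `None` (as happens when a JSON body or config file contains `"temperature": null`). A small illustration with made-up values:

```python
params = {"temperature": None, "top_p": 0.9}

assert params.get("temperature", 1.0) is None  # key present: default NOT applied
assert params.get("top_k", -1) == -1           # key absent: default applied

# Stripping None entries first restores the intended fallback behavior:
params = {k: v for k, v in params.items() if v is not None}
assert params.get("temperature", 1.0) == 1.0
assert params["top_p"] == 0.9
```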
cls._do_sample = generation_cfg.get("do_sample", False) + if cls._do_sample is None: + cls._do_sample = False + cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) + if cls._presence_penalty is None: + cls._presence_penalty = 0.0 + cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) + if cls._frequency_penalty is None: + cls._frequency_penalty = 0.0 + cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) + if cls._repetition_penalty is None: + cls._repetition_penalty = 1.0 + cls._temperature = generation_cfg.get("temperature", 1.0) + if cls._temperature is None: + cls._temperature = 1.0 + cls._top_p = generation_cfg.get("top_p", 1.0) + if cls._top_p is None: + cls._top_p = 1.0 + cls._top_k = generation_cfg.get("top_k", -1) + if cls._top_k is None: + cls._top_k = -1 except: pass diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 9214715b1d..4c494d138b 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ast import json import orjson import logging @@ -1443,6 +1444,228 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami return StreamingParseResult(normal_text="", calls=calls) +class Qwen3CoderDetector(BaseFormatDetector): + """ + Detector for Qwen3-Coder XML-style function call format. + + Format Structure: + ``` + + + + value1 + + + value2 + + + + ``` + + Key differences from Qwen25Detector (JSON-based): + - Parameters are XML key-value pairs, not JSON objects + - Function name is embedded in the tag attribute + - Values need schema-aware type conversion (string by default) + + Reference: https://docs.vllm.ai/projects/recipes/en/latest/Qwen/Qwen3-Coder-480B-A35B.html + """ + + def __init__(self): + super().__init__() + self.bot_token = "" + self.eot_token = "" + self.tool_call_separator = "\n" + + # Regex patterns + self.tool_call_block_regex = re.compile(r"(.*?)", re.DOTALL) + self.function_regex = re.compile(r"||(?=)|$)", re.DOTALL + ) + self._normal_text_buffer = "" + + def has_tool_call(self, text: str) -> bool: + return " Dict: + """Extract parameter type configuration from tool definitions.""" + for tool in tools: + if tool.function.name == func_name and tool.function.parameters: + params = tool.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + return {} + + def _convert_param_value(self, value: str, param_name: str, param_config: Dict, func_name: str) -> Any: + """Convert parameter value based on schema type. 
Safe alternative to eval().""" + if value.lower() == "null": + return None + + if param_name not in param_config: + return value + + prop = param_config.get(param_name, {}) + param_type = str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string" + + if param_type in ("string", "str", "enum"): + return value + elif param_type.startswith("int") or param_type == "integer": + try: + return int(value) + except (ValueError, TypeError): + return value + elif param_type in ("number", "float", "double"): + try: + fv = float(value) + return int(fv) if fv == int(fv) else fv + except (ValueError, TypeError): + return value + elif param_type in ("boolean", "bool"): + return value.lower() == "true" + elif param_type in ("object", "array"): + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError, ValueError): + try: + return ast.literal_eval(value) + except (ValueError, SyntaxError, TypeError): + return value + return value + + def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional[ToolCallItem]: + """Parse a single ... block into a ToolCallItem.""" + try: + end_index = function_str.index(">") + except ValueError: + return None + + func_name = function_str[:end_index].strip() + tool_indices = self._get_tool_indices(tools) + if func_name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {func_name}") + return None + + parameters_text = function_str[end_index + 1 :] + param_config = self._get_param_config(func_name, tools) + param_dict = {} + + for match in self.parameter_regex.findall(parameters_text): + try: + idx = match.index(">") + except ValueError: + continue + param_name = match[:idx].strip() + param_value = match[idx + 1 :] + # Strip leading/trailing newlines from value + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name) + + return ToolCallItem( + tool_index=tool_indices[func_name], + name=func_name, + parameters=json.dumps(param_dict, ensure_ascii=False), + ) + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + + if " StreamingParseResult: + """Streaming incremental parsing for Qwen3-Coder XML tool calls.""" + self._buffer += new_text + current_text = self._buffer + + if not self.has_tool_call(current_text): + partial_len = self._ends_with_partial_token(current_text, self.bot_token) + if partial_len: + return StreamingParseResult() + self._buffer = "" + cleaned = new_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=cleaned) + + # Check for complete tool call blocks + if self.eot_token in current_text: + result = self.detect_and_parse(current_text, tools) + last_end = current_text.rfind(self.eot_token) + if last_end != -1: + self._buffer = current_text[last_end + len(self.eot_token) :].lstrip() + else: + self._buffer = "" + self.current_tool_id = -1 + self.current_tool_name_sent = False + return result + + # Partial tool call - try to extract function name for early streaming + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + calls = [] + tool_call_start = current_text.find(self.bot_token) + if tool_call_start == -1: + return StreamingParseResult() + + content_after = current_text[tool_call_start + 
len(self.bot_token) :] + func_prefix = "") + if gt_pos == -1: + return StreamingParseResult() + + func_name = after_func[:gt_pos].strip() + + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + if func_name and func_name in self._tool_indices and not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = {"name": func_name, "arguments": {}} + + return StreamingParseResult(normal_text="", calls=calls) + + class FunctionCallParser: """ Parser for function/tool calls in model outputs. @@ -1461,6 +1684,7 @@ class FunctionCallParser: "mistral": MistralDetector, "qwen": Qwen25Detector, "qwen25": Qwen25Detector, + "qwen3_coder": Qwen3CoderDetector, } def __init__(self, tools: List[Tool], tool_call_parser: str): diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 88b099459b..4403dba517 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -31,6 +31,12 @@ def __init__(self): self.node_value_len = 0 self.node_prefix_total_len = 0 + # Used by hybrid attention models (e.g., Qwen3Next) to track + # a per-request buffer_idx alongside the token-level KV cache. + # Pure attention models keep buffer_idx as None. + self.buffer_idx = None + self.buffer_time = time_gen.generate_time_id() + def get_compare_key(self): return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) @@ -78,6 +84,9 @@ def remove_child(self, child_node: "TreeNode"): def update_time(self): self.time_id = time_gen.generate_time_id() + def update_buffer_time(self): + self.buffer_time = time_gen.generate_time_id() + def is_leaf(self): return len(self.children) == 0 @@ -103,10 +112,10 @@ class RadixCache: unique_name 主要用于解决单机,多实列部署时的shm冲突 """ - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None, kv_cache_mem_manager=None): from lightllm.common.kv_cache_mem_manager import MemoryManager - self.mem_manager: MemoryManager = mem_manager + self.mem_manager: MemoryManager = kv_cache_mem_manager if kv_cache_mem_manager is not None else mem_manager self._key_dtype = torch.int64 self._value_dtype = torch.int64 @@ -359,6 +368,7 @@ def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: or parent_node.ref_counter != 0 or len(parent_node.children) != 1 or child_node.ref_counter != 0 + or parent_node.buffer_idx is not None ): return None diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 4b8b3c538f..57241de967 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -7,10 +7,11 @@ from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Callable, Any -from lightllm.common.req_manager import ReqManager +from lightllm.common.req_manager import ReqManager, ReqManagerForMamba from lightllm.utils.infer_utils import mark_start, mark_end from 
lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache from lightllm.utils.log_utils import init_logger from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.common.basemodel.infer_lock import g_infer_state_lock @@ -32,10 +33,13 @@ class InferenceContext: infer_req_ids = None vocab_size = None cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None + mtp_step: int = 0 overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream + use_mamba_model: bool = False + def register( self, backend, @@ -43,6 +47,7 @@ def register( radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int, + use_mamba_model: bool = False, ): self.args = get_env_start_args() from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend @@ -57,6 +62,14 @@ def register( self.infer_req_ids = [] self.vocab_size = vocab_size + + self.use_mamba_model = use_mamba_model + if self.use_mamba_model: + assert self.radix_cache is None or isinstance( + self.radix_cache, HybridRadixCache + ), "Mamba model only support HybridRadixCache" + assert isinstance(self.req_manager, ReqManagerForMamba), "Mamba model only support ReqManagerForMamba" + self.mtp_step = get_env_start_args().mtp_step return def init_cpu_embed_cache_client(self): @@ -73,6 +86,27 @@ def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream: self.cpu_kv_cache_stream = torch.cuda.Stream() return self.cpu_kv_cache_stream + def _alloc_and_copy_req_buffers(self, req_objs: List["InferReq"]) -> None: + """Allocate and copy buffers for requests. 
Delegates to req_manager which handles model-specific logic.""" + if not req_objs: + return + + if self.radix_cache is not None and hasattr(self.radix_cache, "free_radix_cache_to_get_enough_buffer"): + self.radix_cache.free_radix_cache_to_get_enough_buffer(len(req_objs) * (self.mtp_step + 1)) + + request_indices_gpu = torch.tensor([r.req_idx for r in req_objs], device="cuda", dtype=torch.int64) + self.req_manager.alloc_buffer_for_req(request_indices_gpu) + + if self.radix_cache is None: + return + + copy_data = [(r.req_idx, r.shared_kv_node.buffer_idx) for r in req_objs if r.shared_kv_node is not None] + if copy_data: + copy_indices, copy_buffers = zip(*copy_data) + copy_indices_tensor = torch.tensor(copy_indices, device="cuda", dtype=torch.int64) + copy_buffers_tensor = torch.tensor(copy_buffers, device="cuda", dtype=torch.int64) + self.req_manager.copy_buffer_from_another_buffer(copy_buffers_tensor, copy_indices_tensor) + def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]: req_objs = [] request_ids = [] @@ -111,9 +145,15 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: slave_req: InferReq = slave_req slave_req.related_master_req = master_req + self._alloc_and_copy_req_buffers(req_objs) + return req_objs def free_a_req_mem(self, free_token_index: List, req: "InferReq"): + # If no KV cache has been allocated yet, there's nothing to free + if req.cur_kv_len == 0: + return + if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) else: @@ -122,7 +162,8 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): # .cpu() 是 流内阻塞操作 value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - prefix_len, _ = self.radix_cache.insert(key, value) + prefix_len, node = self.radix_cache.insert(key, value) + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: @@ -130,6 +171,50 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None + def free_a_req_mem_for_mamba(self, free_token_index: List, req: "InferReq") -> bool: + # 返回该请求的 mamba buffer 是否需要手动释放 + if req.cur_kv_len == 0: + return True + + if self.radix_cache is None: + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) + else: + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() + + prefix_len, node = self.radix_cache.insert(key, value) + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) + if req.shared_kv_node is not None: + assert req.shared_kv_node.node_prefix_total_len <= prefix_len + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + req.shared_kv_node = None + + if len(req.extra_need_to_free_token_index) > 0: + free_token_index.extend(req.extra_need_to_free_token_index) + req.extra_need_to_free_token_index = [] + + if node.buffer_idx is None: + req_to_buffer_index = 
self.req_manager.req_to_buffer_index
+            buffer_idx = req_to_buffer_index[req.req_idx, 0].item()
+            self.radix_cache.add_buffer_idx_to_node(node, buffer_idx)
+            # This request's buffer has been inserted into the radix cache, so it must not be freed manually
+            return False
+        return True
+
+    def _free_req_mem_and_buffers(self, free_token_index: List, free_buffer_index: List, req: "InferReq"):
+        """Free the KV cache and buffer memory held by a request."""
+        if self.use_mamba_model:
+            need_free_base_buffer = self.free_a_req_mem_for_mamba(free_token_index, req)
+            req_to_buffer_index = self.req_manager.req_to_buffer_index
+            if need_free_base_buffer:
+                free_buffer_index.extend(req_to_buffer_index[req.req_idx, :].tolist())
+            elif self.mtp_step > 0:
+                free_buffer_index.extend(req_to_buffer_index[req.req_idx, 1:].tolist())
+        else:
+            self.free_a_req_mem(free_token_index, req)
+
     def _save_promptcache_kvbuffer(self):
         """
         save prompt cache kv buffer
@@ -151,19 +236,23 @@ def _filter(self, finished_request_ids: List[int]):
 
         free_req_index = []
         free_token_index = []
+        free_buffer_index = []
         for request_id in finished_request_ids:
            req: InferReq = self.requests_mapping.pop(request_id)
             if self.args.diverse_mode:
                 req.clear_master_slave_state()
-            self.free_a_req_mem(free_token_index, req)
-
+            self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req)
             free_req_index.append(req.req_idx)
             # logger.info(f"infer release req id {req.shm_req.request_id}")
             req.shm_req.shm_infer_released = True
             self.shm_req_manager.put_back_req_obj(req.shm_req)
 
-        free_token_index = custom_cat(free_token_index)
-        self.req_manager.free(free_req_index, free_token_index)
+        if len(free_token_index) != 0:
+            free_token_index = custom_cat(free_token_index)
+            self.req_manager.free(free_req_index, free_token_index)
+
+        if self.use_mamba_model and len(free_buffer_index) != 0:
+            self.req_manager.free_buffer(free_buffer_index)
 
         finished_req_ids_set = set(finished_request_ids)
         self.infer_req_ids = [_id for _id in self.infer_req_ids if _id not in finished_req_ids_set]
@@ -191,12 +280,15 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
         if pause_reqs:
             g_infer_state_lock.acquire()
 
+            pause_req_indices = []
             free_token_index = []
+            free_buffer_index = []
             for req in pause_reqs:
+                pause_req_indices.append(req.req_idx)
                 if self.args.diverse_mode:
                     # When a request is paused, the diverse-mode master/slave relations must be cleared
                     req.clear_master_slave_state()
-                self.free_a_req_mem(free_token_index, req)
+                self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req)
                 req.cur_kv_len = 0
                 req.shm_req.shm_cur_kv_len = req.cur_kv_len
                 assert req.wait_pause is True
@@ -209,13 +301,16 @@
             free_token_index = custom_cat(free_token_index)
             self.req_manager.free_token(free_token_index)
 
+            if self.use_mamba_model and len(free_buffer_index) != 0:
+                self.req_manager.free_buffer(free_buffer_index)
+
             g_infer_state_lock.release()
         return self
 
     def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int):
         if paused_reqs:
             g_infer_state_lock.acquire()
-
+            recovered_reqs = []
             for req in paused_reqs:
                 prefill_need_token_num = req.get_cur_total_len()
                 if prefill_need_token_num > can_alloc_token_num:
@@ -226,7 +321,9 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo
                 if is_master_in_dp:
                     req.shm_req.is_paused = False
                 can_alloc_token_num -= prefill_need_token_num
+                recovered_reqs.append(req)
 
+            self._alloc_and_copy_req_buffers(recovered_reqs)
             g_infer_state_lock.release()
         return
 
@@ -351,6 +448,11 @@ def __init__(
self.nixl_pd_task_failed_num: int = 0 self.nixl_trans_device_id: int = -1 + # 在开启radix cache的情况下,用于标记命中情况,用于插入算法 + self.mamba_model_match_len = 0 + self.mamba_buffer_insert_len = 0 + self.extra_need_to_free_token_index = [] + # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED @@ -402,7 +504,7 @@ def _match_radix_cache(self): input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 - share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) + share_node, miss_prefix_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) if share_node is not None: self.shared_kv_node = share_node ready_cache_len = share_node.node_prefix_total_len @@ -411,6 +513,13 @@ def _match_radix_cache(self): self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 + if g_infer_context.use_mamba_model: + MAMBA_PREFILL_BLOCK_SIZE = 128 + MAMBA_MIN_INSERT_LEN = 1024 + miss_prefix_len = miss_prefix_len - miss_prefix_len % MAMBA_PREFILL_BLOCK_SIZE + if miss_prefix_len > MAMBA_MIN_INSERT_LEN: + self.mamba_buffer_insert_len = miss_prefix_len + self.shm_req.shm_cur_kv_len = self.cur_kv_len return @@ -458,13 +567,18 @@ def get_input_token_ids(self): return self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] def get_chuncked_input_token_ids(self): - chunked_start = self.cur_kv_len - chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) + # 复用 get_chuncked_input_token_len 的逻辑,保持一致性 + chunked_end = self.get_chuncked_input_token_len() return self.shm_req.shm_prompt_ids.arr[0:chunked_end] def get_chuncked_input_token_len(self): chunked_start = self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) + + if self.mamba_buffer_insert_len > 0: + chunked_end = min(self.get_cur_total_len(), chunked_start + self.mamba_buffer_insert_len) + self.mamba_buffer_insert_len = 0 + return chunked_end def set_next_gen_token_id(self, next_token_id: int, logprob: float, output_len: int): diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 8b085c45ed..0ba4b9248c 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -9,7 +9,6 @@ from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger from lightllm.models import get_model -from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -42,6 +41,7 @@ from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel from lightllm.models.mistral_mtp.model import MistralMTPModel +from lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel from lightllm.models.glm4_moe_lite_mtp.model import 
Glm4MoeLiteMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token @@ -172,12 +172,16 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) + + self.use_buffer_manager = getattr(self.model, "use_buffer_manager", False) + + radix_cache_class = self.model.get_radix_cache_class() self.radix_cache = ( - RadixCache( + radix_cache_class( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, - mem_manager=self.model.mem_manager, + kv_cache_mem_manager=self.model.mem_manager, ) if self.use_dynamic_prompt_cache else None @@ -189,12 +193,18 @@ def init_model(self, kvargs): self.logger.info(f"loaded model class {self.model.__class__}") + # Check if the model uses Mamba (linear attention) layers + from lightllm.common.req_manager import ReqManagerForMamba + + use_mamba_model = isinstance(self.model.req_manager, ReqManagerForMamba) + g_infer_context.register( backend=self, req_manager=self.model.req_manager, radix_cache=self.radix_cache, shm_req_manager=self.shm_req_manager, vocab_size=self.model.vocab_size, + use_mamba_model=use_mamba_model, ) # 初始化 dp 模式使用的通信 tensor, 对于非dp模式,不会使用到 @@ -287,21 +297,33 @@ def decode(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): raise NotImplementedError() def init_mtp_draft_model(self, main_kvargs: dict): - # 当前只支持 deepseekv3 模式的 mtp + # Support deepseekv3 and qwen3_next MTP modes self.mtp_step = self.args.mtp_step - self.draft_models: List[Deepseek3MTPModel] = [] + self.draft_models = [] os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" - if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: + if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att", "qwen3next_vanilla"]: num_mtp_modules = self.args.mtp_step - elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: + elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att", "qwen3next_eagle"]: num_mtp_modules = 1 else: assert False, f"error mtp mode {self.args.mtp_mode}" for i in range(num_mtp_modules): + # Get MTP model config first to calculate mem_layer_start mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) + + # Calculate mem_layer_start: main model layers + previous MTP model layers + # For models with integrated MTP (like qwen3_next), each MTP module has 1 layer + # For models with separate MTP configs, use the config's num_hidden_layers + model_type = mtp_model_cfg.get("model_type", "") + if model_type == "qwen3_next": + # Qwen3Next has integrated MTP with 1 layer per module + mtp_layers_per_module = 1 + else: + mtp_layers_per_module = mtp_model_cfg["num_hidden_layers"] + mem_layer_start = self.model.config["num_hidden_layers"] + i * mtp_layers_per_module mtp_model_kvargs = { "weight_dir": self.args.mtp_draft_model_dir[i], "max_total_token_num": self.model.mem_manager.size, @@ -314,7 +336,7 @@ def init_mtp_draft_model(self, main_kvargs: dict): "data_type": main_kvargs.get("data_type", "float16"), "graph_max_batch_size": main_kvargs.get("graph_max_batch_size", 16), "graph_max_len_in_batch": main_kvargs.get("graph_max_len_in_batch", 8196), - "disable_cudagraph": main_kvargs.get("disable_cudagraph", False), + "disable_cudagraph": True, # Disable CUDA graphs for MTP draft models "mem_fraction": main_kvargs["mem_fraction"], "batch_max_tokens": 
main_kvargs.get("batch_max_tokens", None),
+            "quant_type": main_kvargs.get("quant_type", None),
@@ -322,23 +344,27 @@
                 "run_mode": "normal",
                 "main_model": self.model,
                 "mtp_previous_draft_models": self.draft_models.copy(),
+                "mem_layer_start": mem_layer_start,
+                "mtp_index": i,
             }
-            mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i])
-            if mtp_model_cfg["model_type"] == "deepseek_v3":
-                assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
+            # Select MTP model class based on model type
+            model_type = mtp_model_cfg.get("model_type", "")
+            if model_type == "deepseek_v3":
                 self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
-            elif mtp_model_cfg["model_type"] == "qwen3_moe":
+            elif model_type == "qwen3_moe":
                 assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"]
                 self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs))
-            elif mtp_model_cfg["model_type"] == "mistral":
+            elif model_type == "mistral":
                 assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"]
                 self.draft_models.append(MistralMTPModel(mtp_model_kvargs))
+            elif model_type == "qwen3_next":
+                self.draft_models.append(Qwen3NextMTPModel(mtp_model_kvargs))
             elif mtp_model_cfg["model_type"] == "glm4_moe_lite":
                 assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
                 self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs))
             else:
-                assert False, f"error mtp mode {mtp_model_cfg['model_type']}"
+                raise ValueError(f"Unsupported MTP model type: {model_type}")
 
             self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
         return
diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
index a8a5224ebc..3cabd97baa 100644
--- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
@@ -24,6 +24,7 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache
 from .control_state import ControlState
 
 logger = init_logger(__name__)
 
@@ -50,6 +51,14 @@ def __init__(self) -> None:
         self.classed_req_strict_prefill = False
         return
 
+    def _maybe_insert_hybrid_radix_cache(self, run_reqs: List[InferReq]):
+        # Insert hybrid radix cache entries if applicable; used by hybrid attention models.
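Stepping back to the `mem_layer_start` computation in `init_mtp_draft_model` above, a toy walk-through under assumed layer counts (48 main-model layers, two qwen3_next MTP modules of one layer each):

```python
# KV cache layers are laid out as [main model | MTP module 0 | MTP module 1 | ...].
num_main_layers = 48         # hypothetical main model depth
mtp_layers_per_module = 1    # qwen3_next: one integrated layer per MTP module

for i in range(2):
    mem_layer_start = num_main_layers + i * mtp_layers_per_module
    print(f"draft module {i} uses KV cache layers starting at {mem_layer_start}")
# -> 48 for module 0, 49 for module 1
```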
+ if self.use_buffer_manager and self.radix_cache is not None: + torch.cuda.synchronize() + g_infer_state_lock.acquire() + self.radix_cache.insert_for_hybrid_radix_cache(run_reqs) + g_infer_state_lock.release() + def infer_loop(self): torch.cuda.set_device(get_current_device_id()) try: @@ -136,6 +145,9 @@ def prefill_normal( extra_post_req_handle_func=self.extra_post_req_handle_func, nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) + + self._maybe_insert_hybrid_radix_cache(run_reqs) + # 第四阶段 event_pack.notify_pre_post_handle() return @@ -219,6 +231,8 @@ def prefill_mtp( nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) + self._maybe_insert_hybrid_radix_cache(run_reqs) + # 第四阶段 event_pack.notify_pre_post_handle() return @@ -258,6 +272,24 @@ def decode_mtp( key="mtp_accept_len", gpu_tensor=mtp_accept_len, ) + + # Copy accepted buffer states back to buffer[0] for MTP + # Only copy when accept_len > 1 (accept_len == 1 means buffer[0] is already correct) + mask = mtp_accept_len > 1 + if mask.sum() > 0: + actual_req_idxes = model_input.b_req_idx[b_req_mtp_start_loc[mask]] + # Source: the accepted buffer (at index accept_len - 1) + src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ + actual_req_idxes, mtp_accept_len[mask] - 1 + ] + # Destination: buffer[0] for each request + dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] + # P2P copy both conv_states and ssm_states + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): + g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + src_buffer_indexes, dst_buffer_indexes + ) + verify_event = torch.cuda.Event() verify_event.record() diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index bb0e848e76..c5dd768224 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -454,6 +454,20 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): gpu_tensor=mtp_accept_len, ) + # Copy accepted buffer states back to buffer[0] for MTP + # Only copy when accept_len > 1 + mask = mtp_accept_len > 1 + if mask.sum() > 0: + actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]] + src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ + actual_req_idxes, mtp_accept_len[mask] - 1 + ] + dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): + g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + src_buffer_indexes, dst_buffer_indexes + ) + verify_event = torch.cuda.Event() verify_event.record() @@ -767,6 +781,20 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf ) all_next_token_ids.append(next_token_ids) + # Copy accepted buffer states back to buffer[0] for MTP + # Only copy when accept_len > 1 + mask = mtp_accept_len > 1 + if mask.sum() > 0: + actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]] + src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ + actual_req_idxes, mtp_accept_len[mask] - 1 + ] + dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): + 
g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + src_buffer_indexes, dst_buffer_indexes + ) + verify_event = torch.cuda.Event() verify_event.record() diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 3e97f4de3e..ed4665e725 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -68,7 +68,7 @@ def exposed_init_model(self, kvargs): self.model = ( Qwen2_5_VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() ) - elif self.model_type in ["qwen3_vl", "qwen3_vl_moe"]: + elif self.model_type in ["qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]: self.model = ( Qwen3VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() ) diff --git a/test_gsmk.py b/test_gsmk.py new file mode 100644 index 0000000000..78a5aa467f --- /dev/null +++ b/test_gsmk.py @@ -0,0 +1,241 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py +import argparse +import ast +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import numpy as np +import requests +from tqdm import tqdm + +INVALID = -9999999 + + +def read_jsonl(filename: str): + """Read a JSONL file.""" + with open(filename) as fin: + for line in fin: + if line.startswith("#"): + continue + yield json.loads(line) + + +def dump_state_text(filename: str, states: list, mode: str = "w"): + """Dump program state in a text file.""" + with open(filename, mode) as fout: + for i, s in enumerate(states): + if isinstance(s, str): + fout.write(f"==== {i} ====\n{s}\n") + else: + fout.write(f"==== {i} ====\n{str(s)}\n") + + +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + # Check if the cache file already exists + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as file, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + size = file.write(chunk) + bar.update(size) + + return filename + + +def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): + """Call LightLLM API for text generation.""" + assert url is not None + + data = { + "inputs": prompt, + "parameters": { + "temperature": temperature, + "max_new_tokens": max_tokens, + "stop_sequences": stop, + "repetition_penalty": 1.0, + "top_p": 1.0, + "top_k": 1, + }, + } + res = requests.post(url, json=data) + assert res.status_code == 200, f"API request failed with status code {res.status_code}: {res.text}" + + response_json = res.json() + if "generated_text" not in response_json: + raise ValueError(f"Invalid API response format. 
Expected 'generated_text' key, got: {response_json.keys()}") + if not isinstance(response_json["generated_text"], list) or len(response_json["generated_text"]) == 0: + raise ValueError( + "Invalid API response format. 'generated_text' should be a non-empty list, " + f"got: {response_json['generated_text']}" + ) + + pred = response_json["generated_text"][0] + return pred + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + # First try to find the answer after "####" marker (GSM8K format) + match = re.search(r"####\s*(-?\d+)", answer_str) + if match: + try: + return ast.literal_eval(match.group(1)) + except SyntaxError: + pass + # Fallback: find all numbers and take the last one + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", type=int, default=256) + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--num-shots", type=int, default=5) + parser.add_argument("--num-questions", type=int, default=200) + parser.add_argument("--result-file", type=str, default="result.jsonl") + parser.add_argument("--data-path", type=str, default="test.jsonl") + return parser.parse_args() + + +def main(args): + # LightLLM API URL + url = f"{args.host}:{args.port}/generate" + + # Read data + url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url_data) + lines = list(read_jsonl(filename)) + + # Construct prompts + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) + + # Ensure we have enough samples and avoid data leakage + # Test questions should start after few-shot examples + max_available = len(lines) - num_shots + if num_questions > max_available: + print( + "Warning: Requested {} questions, but only {} available after reserving {} for few-shot. 
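
`get_answer_value` is the scoring-critical piece of this harness: gold labels and model outputs are both reduced to a single integer before comparison. A self-contained replay of its two-stage parse (the `####` marker that GSM8K gold answers carry, then a last-number fallback for free-form model text):

```python
import ast
import re

INVALID = -9999999

def extract_answer(answer_str: str) -> int:
    # Mirrors get_answer_value above: strip thousands separators, prefer the
    # GSM8K "#### <n>" marker, then fall back to the last number found.
    answer_str = answer_str.replace(",", "")
    match = re.search(r"####\s*(-?\d+)", answer_str)
    if match:
        return ast.literal_eval(match.group(1))
    numbers = re.findall(r"\d+", answer_str)
    return ast.literal_eval(numbers[-1]) if numbers else INVALID

print(extract_answer("He pays 10 * 3 = 30 dollars.\n#### 30"))  # 30 (gold label)
print(extract_answer("The total is 1,234 dollars"))             # 1234 (model text)
```
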
" + "Using {} questions.".format(num_questions, max_available, num_shots, max_available) + ) + num_questions = max_available + + questions = [] + labels = [] + for i in range(num_shots, num_shots + num_questions): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(label != INVALID for label in labels) + + states = [None] * len(labels) + + # Run requests using thread pool + def get_one_answer(i): + answer = call_generate_lightllm( + prompt=few_shot_examples + questions[i], + temperature=0, + max_tokens=1024, + stop=["Question", "Assistant:", "<|separator|>", "Human:", "\n\nQuestion"], + url=url, + ) + states[i] = answer + + tic = time.perf_counter() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + + latency = time.perf_counter() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + + # Print results + print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + + # Dump results + dump_state_text("tmp_output_lightllm.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "gsm8k", + "backend": "lightllm", + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + args = parse_args() + main(args) From 1686d34b23115fd9fb16806b6ad94909daa0bf4f Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 21 Feb 2026 08:04:05 +0000 Subject: [PATCH 074/180] fix conv3d --- lightllm/models/qwen2_vl/qwen2_visual.py | 2 ++ .../qwen3_omni_visual.py | 2 ++ lightllm/models/qwen3_vl/qwen3_visual.py | 13 ++++++++ lightllm/server/api_models.py | 32 ++++++++++++++----- lightllm/server/api_openai.py | 8 +++++ lightllm/server/httpserver/manager.py | 8 +++++ 6 files changed, 57 insertions(+), 8 deletions(-) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 0e2af0cbb2..a29cb8758b 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -62,6 +62,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size ) + # Use channels_last_3d to enable cuDNN optimized Conv3D path + hidden_states = hidden_states.contiguous(memory_format=torch.channels_last_3d) hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) return hidden_states diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py index ffa2e19bd6..c20c227996 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -68,6 +68,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size ) + # Use channels_last_3d to enable cuDNN optimized 
Conv3D path + hidden_states = hidden_states.contiguous(memory_format=torch.channels_last_3d) hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) return hidden_states diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index 00ad6c05a7..7fc8187ddc 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -29,6 +29,9 @@ from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class Qwen3VLVisionMLP(nn.Module): @@ -68,6 +71,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size ) + # Use channels_last_3d to enable cuDNN optimized Conv3D path + hidden_states = hidden_states.contiguous(memory_format=torch.channels_last_3d) hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) return hidden_states @@ -374,7 +379,15 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) + orig_size = image_data.size pixel_values, image_grid_thw = self.processor.preprocess(image_data) + + # Debug logging for image processing + logger.debug( + f"[VISUAL_DEBUG] Image {i}: orig_size={orig_size}, " + f"pixel_values.shape={pixel_values.shape}, grid_thw={image_grid_thw}" + ) + img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index f30ecc55fe..7c7d40698c 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -115,6 +115,7 @@ class CompletionRequest(BaseModel): prompt: Union[str, List[str], List[int], List[List[int]]] suffix: Optional[str] = None max_tokens: Optional[int] = 8192 + max_completion_tokens: Optional[int] = None # OpenAI's newer parameter, alias for max_tokens temperature: Optional[float] = 1.0 top_p: Optional[float] = 1.0 n: Optional[int] = 1 @@ -169,10 +170,17 @@ def load_generation_cfg(cls, weight_dir: str): @classmethod def apply_loaded_defaults(cls, data: Any): """Apply loaded default values if field is not provided.""" - if isinstance(data, dict) and cls._loaded_defaults: - for key, value in cls._loaded_defaults.items(): - if key not in data: - data[key] = value + if isinstance(data, dict): + # Map max_completion_tokens to max_tokens if provided + # (OpenAI's newer parameter name) + if "max_completion_tokens" in data and data["max_completion_tokens"] is not None: + if "max_tokens" not in data or data["max_tokens"] is None: + data["max_tokens"] = data["max_completion_tokens"] + + if cls._loaded_defaults: + for key, value in cls._loaded_defaults.items(): + if key not in data: + data[key] = value return data @@ -187,6 +195,7 @@ class ChatCompletionRequest(BaseModel): stream_options: Optional[StreamOptions] = None stop: Optional[Union[str, List[str]]] = None max_tokens: Optional[int] = 8192 + max_completion_tokens: Optional[int] = None # OpenAI's newer parameter, alias for max_tokens presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None @@ -246,10 +255,17 @@ def 
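
The aliasing rule added here is simple precedence: `max_completion_tokens` (OpenAI's newer name) only populates `max_tokens` when the caller left `max_tokens` unset, so an explicit `max_tokens` always wins and the loaded generation-config defaults still apply last. The same logic on a plain dict:

```python
def apply_max_completion_alias(data: dict) -> dict:
    # Mirrors the validator above, outside pydantic.
    if data.get("max_completion_tokens") is not None and data.get("max_tokens") is None:
        data["max_tokens"] = data["max_completion_tokens"]
    return data

print(apply_max_completion_alias({"max_completion_tokens": 256})["max_tokens"])  # 256
print(apply_max_completion_alias(
    {"max_tokens": 64, "max_completion_tokens": 256})["max_tokens"])             # 64
```
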
load_generation_cfg(cls, weight_dir: str): @classmethod def apply_loaded_defaults(cls, data: Any): """Apply loaded default values if field is not provided.""" - if isinstance(data, dict) and cls._loaded_defaults: - for key, value in cls._loaded_defaults.items(): - if key not in data: - data[key] = value + if isinstance(data, dict): + # Map max_completion_tokens to max_tokens if provided + # (OpenAI's newer parameter name) + if "max_completion_tokens" in data and data["max_completion_tokens"] is not None: + if "max_tokens" not in data or data["max_tokens"] is None: + data["max_tokens"] = data["max_completion_tokens"] + + if cls._loaded_defaults: + for key, value in cls._loaded_defaults.items(): + if key not in data: + data[key] = value return data diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index fc14314ae3..de1423c496 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -276,6 +276,14 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req text = "".join(final_output_dict[sub_req_id]) full_text = text + # Debug logging for empty responses + if not text or len(text.strip()) == 0: + logger.warning( + f"[EMPTY_RESPONSE_DEBUG] sub_req_id={sub_req_id}, " + f"completion_tokens={completion_tokens}, finish_reason={finish_reason}, " + f"prompt_tokens={prompt_tokens}, output_chunks={len(final_output_dict[sub_req_id])}" + ) + # Handle reasoning content reasoning_text = None reasoning_parser = get_env_start_args().reasoning_parser diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index d51f88cdda..c290880c73 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -288,6 +288,14 @@ async def generate( if self.pd_mode.is_P_or_NORMAL(): await multimodal_params.verify_and_preload(request) + # Debug logging for multimodal requests + if multimodal_params and multimodal_params.images: + logger.debug( + f"[MULTIMODAL_DEBUG] req_id={group_request_id}, " + f"num_images={len(multimodal_params.images)}, " + f"max_new_tokens={sampling_params.max_new_tokens}" + ) + # 记录请求到达的相关信息 await self._log_req_header(request_headers, group_request_id) # encode From dd9b61160e28eb13af2d194ab9f5c994b262db4b Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 26 Feb 2026 07:02:44 +0000 Subject: [PATCH 075/180] [draft] qwen3.5 dense --- .../{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json | 7 + ...ARLEN=true,REVERSE=false}_NVIDIA_H200.json | 38 ++++ ...=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json | 7 + ...4,a_dtype=torch.bfloat16}_NVIDIA_H200.json | 50 ++++++ ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 48 +++++ ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 + ...=12,dtype=torch.bfloat16}_NVIDIA_H200.json | 50 ++++++ .../layer_weights/transformer_layer_weight.py | 8 +- lightllm/models/qwen3_moe/model.py | 4 +- .../layer_infer/transformer_layer_infer.py | 18 +- .../layer_weights/transformer_layer_weight.py | 168 ++++++++++++------ lightllm/models/qwen3next/model.py | 4 +- 15 files changed, 368 insertions(+), 65 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..131da59770 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..131da59770 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..131da59770 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 
0000000000..2af1b86e90 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..808ed9a7fc --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 1 + }, + "100": { + "num_warps": 8 + }, + "1024": { + "num_warps": 2 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 8 + }, + "256": { + "num_warps": 8 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 2 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..fb62cf8259 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BK": 32, + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..ad8d397d3b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 8, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "1024": { + "BLK_HEADS": 32, + "num_warps": 1 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "16384": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "2048": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 16, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json index f525d11257..55ccb24a65 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -11,6 +11,10 @@ "BLOCK_N": 128, "num_warps": 1 }, + "1536": { + "BLOCK_N": 128, + "num_warps": 1 + }, "16": { "BLOCK_N": 256, "num_warps": 4 @@ -23,10 +27,26 @@ "BLOCK_N": 128, "num_warps": 1 }, + "192": { + "BLOCK_N": 128, + "num_warps": 1 + }, "2048": { "BLOCK_N": 64, "num_warps": 2 }, + "24": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "2400": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "24576": { + "BLOCK_N": 128, + "num_warps": 1 + }, "256": { "BLOCK_N": 512, "num_warps": 2 @@ -35,18 +55,38 @@ "BLOCK_N": 128, "num_warps": 1 }, + "3072": { + "BLOCK_N": 128, + "num_warps": 1 + }, "32768": { "BLOCK_N": 128, "num_warps": 1 }, + "384": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "393216": { + "BLOCK_N": 128, + "num_warps": 1 + }, "4096": { "BLOCK_N": 64, "num_warps": 2 }, + "49152": { + "BLOCK_N": 128, + "num_warps": 1 + }, "512": { "BLOCK_N": 256, "num_warps": 4 }, + "6144": { + "BLOCK_N": 128, + "num_warps": 1 + }, "64": { "BLOCK_N": 64, "num_warps": 2 @@ -55,6 +95,10 @@ "BLOCK_N": 128, "num_warps": 1 }, + "768": { + "BLOCK_N": 256, + "num_warps": 2 + }, "8": { "BLOCK_N": 64, "num_warps": 2 @@ -66,5 +110,9 @@ "8192": { "BLOCK_N": 128, "num_warps": 2 + }, + "98304": { + "BLOCK_N": 128, + "num_warps": 1 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..df501847ec --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "5120": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..ada783ef92 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "num_stages": 2, + "num_warps": 1 + }, + "1024": { + "num_stages": 2, + "num_warps": 1 + }, + "128": { + "num_stages": 5, + "num_warps": 2 + }, + "16": { + "num_stages": 3, + "num_warps": 2 + }, + "16384": { + 
"num_stages": 3, + "num_warps": 1 + }, + "2048": { + "num_stages": 5, + "num_warps": 2 + }, + "256": { + "num_stages": 3, + "num_warps": 2 + }, + "32": { + "num_stages": 3, + "num_warps": 1 + }, + "4096": { + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "num_stages": 3, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 13ba6cbe0f..e525cb2d20 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -5,11 +5,11 @@ class Qwen3MOETransformerLayerWeight(Qwen3TransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): - self.n_routed_experts = network_config["num_experts"] + self.n_routed_experts = network_config.get("num_experts", 0) self.is_moe = ( - network_config["num_experts"] > 0 - and layer_num not in network_config["mlp_only_layers"] - and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 ) super().__init__(layer_num, data_type, network_config, quant_cfg) return diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index 10a5051276..b71d7f4878 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -25,4 +25,6 @@ def __init__(self, kvargs): def _init_custom(self): super()._init_custom() - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + # Only initialize DeepEP group for MoE models with num_experts + if "num_experts" in self.config and self.config["num_experts"] > 0: + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index cd5fd67d53..dc44c64434 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -143,14 +143,19 @@ def _ffn_core(self, input, layer_weight, is_decode=False): def _standard_ffn(self, input, infer_state, layer_weight): """Standard FFN using shared expert weights (non-MoE layers).""" + # For dense models without shared experts, return zeros (no FFN computation) + if not hasattr(layer_weight, "shared_expert_gate_up_proj") or layer_weight.shared_expert_gate_up_proj is None: + return torch.zeros_like(input) ffn2_out, _ = self._ffn_core(input, layer_weight, is_decode=not infer_state.is_prefill) return ffn2_out def _compute_shared_expert(self, input, layer_weight, is_decode=False): """Compute shared expert FFN output with gating.""" ffn2_out, input_view = self._ffn_core(input, layer_weight, is_decode=is_decode) - gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() - ffn2_out.mul_(gate) + # Dense models don't have shared_expert_gate + if layer_weight.shared_expert_gate is not None: + gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() + ffn2_out.mul_(gate) return ffn2_out, input_view def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): @@ -488,14 +493,19 @@ def _ffn_core(self, input, layer_weight, 
is_decode=False): def _standard_ffn(self, input, infer_state, layer_weight): """Standard FFN using shared expert weights (non-MoE layers).""" + # For dense models without shared experts, return zeros (no FFN computation) + if not hasattr(layer_weight, "shared_expert_gate_up_proj") or layer_weight.shared_expert_gate_up_proj is None: + return torch.zeros_like(input) ffn2_out, _ = self._ffn_core(input, layer_weight, is_decode=not infer_state.is_prefill) return ffn2_out def _compute_shared_expert(self, input, layer_weight, is_decode=False): """Compute shared expert FFN output with gating.""" ffn2_out, input_view = self._ffn_core(input, layer_weight, is_decode=is_decode) - gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() - ffn2_out.mul_(gate) + # Dense models don't have shared_expert_gate + if layer_weight.shared_expert_gate is not None: + gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() + ffn2_out.mul_(gate) return ffn2_out, input_view def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index d4e16555d9..3e72041f8a 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -47,6 +47,11 @@ def _init_weight(self): self._init_gate_shared_expert_weight() return + def _init_ffn(self): + # Qwen3Next architecture uses _init_gate_shared_expert_weight() for FFN-like component + # No standard MLP FFN weights needed for this architecture + pass + def load_hf_weights(self, weights): self._split_q_with_gate(weights) super().load_hf_weights(weights) @@ -62,41 +67,65 @@ def _split_q_with_gate(self, weights): weights[self._o_gate_weight_name] = _gate_proj def _init_gate_shared_expert_weight(self): - prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" hidden_size = self.network_config_["hidden_size"] - shared_inter = self.network_config_["shared_expert_intermediate_size"] - self.shared_expert_gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[shared_inter, shared_inter], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_gate_up_proj"), - ) - self.shared_expert_down_proj = COLMMWeight( - in_dim=shared_inter, - out_dims=[hidden_size], - weight_names=f"{prefix}.down_proj.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_down_proj"), - ) - self.shared_expert_gate = ROWMMWeight( - in_dim=hidden_size, - out_dims=[1], - weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", - data_type=self.data_type_, - bias_names=None, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) + + # Check if this is a MoE model with shared_expert or a dense model + if "shared_expert_intermediate_size" in self.network_config_: + # MoE model with shared expert + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + inter_size = self.network_config_["shared_expert_intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=inter_size, + 
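
The conditional around `shared_expert_gate` reflects a real architectural split: MoE checkpoints scale the shared expert's output by a per-token sigmoid gate, while the dense variant uses the FFN output directly. A tensor-level sketch of the gated path, with arbitrary small dimensions (all shapes illustrative):

```python
import torch
import torch.nn.functional as F

tokens, hidden, inter = 4, 64, 128
x = torch.randn(tokens, hidden)
w_gate_up = torch.randn(2 * inter, hidden)   # fused gate_proj + up_proj
w_down = torch.randn(hidden, inter)
w_expert_gate = torch.randn(1, hidden)       # present only in MoE checkpoints

gate_up = x @ w_gate_up.t()
act = F.silu(gate_up[:, :inter]) * gate_up[:, inter:]
ffn_out = act @ w_down.t()

# MoE shared expert: per-token sigmoid gate; dense models skip this step.
expert_gate = (x @ w_expert_gate.t()).sigmoid()
print((ffn_out * expert_gate).shape)  # torch.Size([4, 64])
```
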
out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + self.shared_expert_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + else: + # Dense model with standard MLP + prefix = f"model.layers.{self.layer_num_}.mlp" + inter_size = self.network_config_["intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + # No shared_expert_gate for dense models + self.shared_expert_gate = None class Qwen3NextGatedDeltaNetTransformerLayerWeight(Qwen3MOETransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): self.is_moe = ( - network_config["num_experts"] > 0 - and layer_num not in network_config["mlp_only_layers"] - and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 ) super().__init__(layer_num, data_type, network_config, quant_cfg) @@ -126,6 +155,11 @@ def _init_weight(self): self._init_ffn() self._init_gate_shared_expert_weight() + def _init_ffn(self): + # GatedDeltaNet architecture uses _init_gate_shared_expert_weight() for FFN-like component + # No standard MLP FFN weights needed for this architecture + pass + def _init_gdn_weight(self): prefix = f"model.layers.{self.layer_num_}.linear_attn" hidden_size = self.network_config_["hidden_size"] @@ -284,30 +318,54 @@ def _parse_linear_conv1d(self, weight): return new_weight def _init_gate_shared_expert_weight(self): - prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" hidden_size = self.network_config_["hidden_size"] - shared_inter = self.network_config_["shared_expert_intermediate_size"] - self.shared_expert_gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[shared_inter, shared_inter], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_gate_up_proj"), - ) - self.shared_expert_down_proj = COLMMWeight( - in_dim=shared_inter, - out_dims=[hidden_size], - weight_names=f"{prefix}.down_proj.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_down_proj"), - ) - self.shared_expert_gate = ROWMMWeight( - in_dim=hidden_size, - out_dims=[1], - weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", - data_type=self.data_type_, - bias_names=None, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) + + # Check if this is a MoE model with shared_expert or a dense model + if "shared_expert_intermediate_size" in self.network_config_: + # MoE model with shared expert + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + inter_size = 
self.network_config_["shared_expert_intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + self.shared_expert_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + else: + # Dense model with standard MLP + prefix = f"model.layers.{self.layer_num_}.mlp" + inter_size = self.network_config_["intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + # No shared_expert_gate for dense models + self.shared_expert_gate = None diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 1234a659ed..d15b357608 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -60,7 +60,9 @@ def _init_config(self): def _init_custom(self): super()._init_custom() - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + # Only initialize DeepEP group for MoE models with num_experts + if "num_experts" in self.config and self.config["num_experts"] > 0: + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 From 6a3a17c6467b85897595dc8b9dba1b2847e92f8d Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 26 Feb 2026 10:02:22 +0000 Subject: [PATCH 076/180] split dense and moe --- ...num=8,use_fp8_w8a8=false}_NVIDIA_H200.json | 38 +++++++++++++++++ ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 38 +++++++++++++++++ .../{topk_num=8}_NVIDIA_H200.json | 12 ++++++ ...orch.bfloat16,topk_num=8}_NVIDIA_H200.json | 18 ++++++++ ...M=8,dtype=torch.bfloat16}_NVIDIA_H200.json | 18 ++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 24 +++++++++++ lightllm/models/qwen35_moe/model.py | 42 ------------------- .../layer_infer/transformer_layer_infer.py | 12 +++--- 8 files changed, 154 insertions(+), 48 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json delete mode 100644 lightllm/models/qwen35_moe/model.py diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..662875ecdb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..1f8134fa64 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json index 002b842cbb..bf2afabaef 100644 --- 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json @@ -19,6 +19,14 @@ "BLOCK_SIZE": 128, "num_warps": 8 }, + "16384": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, "256": { "BLOCK_SIZE": 128, "num_warps": 8 @@ -27,6 +35,10 @@ "BLOCK_SIZE": 128, "num_warps": 8 }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, "64": { "BLOCK_SIZE": 128, "num_warps": 8 diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json index bc904bb7f8..b32622e3b1 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -29,6 +29,18 @@ "NUM_STAGE": 1, "num_warps": 2 }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, "256": { "BLOCK_DIM": 1024, "BLOCK_M": 1, @@ -41,6 +53,12 @@ "NUM_STAGE": 4, "num_warps": 4 }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, "64": { "BLOCK_DIM": 128, "BLOCK_M": 1, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..5b3e656b6d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,18 @@ +{ + "1024": { + "num_stages": 5, + "num_warps": 1 + }, + "16384": { + "num_stages": 4, + "num_warps": 2 + }, + "2048": { + "num_stages": 2, + "num_warps": 2 + }, + "4096": { + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json index e08a58baf5..0a0f01fe7a 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -23,12 +23,24 @@ "NUM_STAGES": 1, "num_warps": 1 }, + "131072": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "160": { "BLOCK_M": 1, "BLOCK_N": 256, "NUM_STAGES": 1, 
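
These JSON tables map a token-count bucket (the outer string key) to Triton launch parameters for one kernel/shape combination. The real key-selection logic lives in lightllm's autotune utilities; purely as an illustration, a hypothetical nearest-bucket lookup over a table of this shape:

```python
import json

# Same structure as the tables above; values are Triton launch parameters.
table = json.loads('{"8": {"BK": 32, "num_warps": 2}, "1024": {"BK": 64, "num_warps": 4}}')

def pick_config(table: dict, num_tokens: int) -> dict:
    # Hypothetical policy: choose the recorded bucket closest to num_tokens.
    key = min(table, key=lambda k: abs(int(k) - num_tokens))
    return table[key]

print(pick_config(table, 600))  # {'BK': 64, 'num_warps': 4}
```
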
"num_warps": 1 }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "163840": { "BLOCK_M": 1, "BLOCK_N": 128, @@ -53,6 +65,12 @@ "NUM_STAGES": 2, "num_warps": 1 }, + "32768": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "40960": { "BLOCK_M": 32, "BLOCK_N": 128, @@ -70,5 +88,11 @@ "BLOCK_N": 128, "NUM_STAGES": 2, "num_warps": 1 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/models/qwen35_moe/model.py b/lightllm/models/qwen35_moe/model.py deleted file mode 100644 index ee149f3a81..0000000000 --- a/lightllm/models/qwen35_moe/model.py +++ /dev/null @@ -1,42 +0,0 @@ -import os -import json - -from lightllm.models.qwen3_vl.model import QWen3VLTokenizer -from lightllm.models.registry import ModelRegistry -from lightllm.models.qwen3next.model import Qwen3NextTpPartModel -from lightllm.common.build_utils import repair_config -from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights - - -class QWen35Tokenizer(QWen3VLTokenizer): - def __init__(self, tokenizer=None, image_processor=None, **kwargs): - super().__init__(tokenizer, image_processor, **kwargs) - - -@ModelRegistry(["qwen3_5"], is_multimodal=True) -class Qwen35MoeTpPartModel(Qwen3NextTpPartModel): - def _init_config(self): - with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: - all_config = json.load(json_file) - self.config = all_config["text_config"] - - repair_config(self.config, same_names=["num_attention_heads", "n_head"]) - repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) - repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) - repair_config(self.config, same_names=["intermediate_size", "moe_intermediate_size"]) - - # Handle fine-tuning config if present - if self.finetune_config: - self.config["vocab_size"] = self.finetune_config.vocab_size - - def _load_hf_weights(self): - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 9eccddffc1..4f96506b14 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -21,14 +21,14 @@ class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer): def __init__(self, layer_num, network_config): - self.n_routed_experts = network_config["num_experts"] + self.n_routed_experts = network_config.get("num_experts", 0) self.is_moe = ( - network_config["num_experts"] > 0 - and layer_num not in network_config["mlp_only_layers"] - and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 ) - self.num_experts_per_tok = network_config["num_experts_per_tok"] - self.norm_topk_prob = network_config["norm_topk_prob"] + self.num_experts_per_tok = network_config.get("num_experts_per_tok", 0) + self.norm_topk_prob = network_config.get("norm_topk_prob", True) super().__init__(layer_num, 
network_config) self.head_dim_ = network_config["head_dim"] self.tp_k_head_num_ = max(self.tp_k_head_num_, 1) From e1cdfb43ce6e20ab20bcbabc058f0923685f4d60 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 26 Feb 2026 11:40:11 +0000 Subject: [PATCH 077/180] feat: add mamba_cache_ratio for automatic memory allocation - Add mamba_cache_ratio parameter (default 0.5) - Change mamba_cache_size default from 3000 to None - Implement automatic memory allocation based on ratio - Add clear error messages with solutions when memory insufficient - Maintain backward compatibility with explicit mamba_cache_size Ratio formula: mamba_memory = total_available * ratio / (1 + ratio) - ratio=0.5 -> 33% mamba, 67% KV - ratio=1.0 -> 50% mamba, 50% KV - ratio=2.0 -> 67% mamba, 33% KV --- lightllm/models/qwen3next/model.py | 83 +++++++++++++++++++- lightllm/server/api_cli.py | 17 +++- lightllm/server/core/objs/start_args_type.py | 3 +- lightllm/utils/envs_utils.py | 2 +- 4 files changed, 98 insertions(+), 7 deletions(-) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index d15b357608..205eb1dc9b 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -54,6 +54,75 @@ def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch def autotune_layers(self): return self.config["full_attention_interval"] + def _calculate_mamba_cache_size(self, start_args: StartArgs) -> int: + """Calculate mamba cache size based on available memory and mamba_cache_ratio.""" + from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory + import torch.distributed as dist + + use_ratio = self.max_total_token_num is None and start_args.mamba_cache_size is None + + world_size = dist.get_world_size() + total_memory = get_total_gpu_memory() + available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - self.mem_fraction) + + conv_kernel_size = self.config["linear_conv_kernel_dim"] + conv_dim = ( + self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads + ) // self.tp_world_size_ + + num_linear_layers = self.config["n_layer"] - (self.config["n_layer"] // self.config["full_attention_interval"]) + + conv_cell_size = ( + num_linear_layers * conv_dim * (conv_kernel_size - 1) * torch._utils._element_size(self.data_type) + ) + + ssm_dtype = torch.bfloat16 if start_args.mamba_ssm_data_type == "bfloat16" else torch.float32 + ssm_cell_size = ( + num_linear_layers + * (self.num_linear_v_heads // self.tp_world_size_) + * self.head_linear_k_dim + * self.head_linear_v_dim + * torch._utils._element_size(ssm_dtype) + ) + + total_cell_size = conv_cell_size + ssm_cell_size + + if use_ratio: + mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5 + mamba_memory_gb = available_memory * mamba_cache_ratio / (1 + mamba_cache_ratio) + else: + mamba_memory_gb = available_memory + mamba_cache_ratio = None + + mamba_cache_size = int(mamba_memory_gb * 1024 ** 3 / total_cell_size) + + if mamba_cache_size < start_args.running_max_req_size: + ratio = mamba_cache_ratio if mamba_cache_ratio is not None else 0.5 + raise ValueError( + f"Insufficient memory for mamba cache allocation!\n\n" + f"Calculated mamba_cache_size ({mamba_cache_size}) < " + f"running_max_req_size ({start_args.running_max_req_size})\n\n" + f"Memory budget:\n" + f" Available for mamba cache: {mamba_memory_gb:.2f} GB\n" + f" Memory per buffer: {total_cell_size / 1024 ** 
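
The per-request buffer cost that `_calculate_mamba_cache_size` divides by is fully determined by the model config. A worked example with assumed Qwen3Next-scale hyper-parameters (every number illustrative), using this commit's `ratio / (1 + ratio)` split of the available budget:

```python
# Assumed config values; bf16 conv states, fp32 SSM states.
n_layer, full_attention_interval = 48, 4
num_linear_layers = n_layer - n_layer // full_attention_interval   # 36
k_heads, v_heads, k_dim, v_dim = 16, 32, 128, 128
conv_kernel, tp = 4, 1
bf16, fp32 = 2, 4                                                  # bytes per element

conv_dim = (k_dim * k_heads * 2 + v_dim * v_heads) // tp           # 8192
conv_cell = num_linear_layers * conv_dim * (conv_kernel - 1) * bf16
ssm_cell = num_linear_layers * (v_heads // tp) * k_dim * v_dim * fp32
cell = conv_cell + ssm_cell
print(f"{cell / 2**20:.1f} MiB per request")                       # 73.7 MiB

available_gb, ratio = 20.0, 0.5                                    # ratio = mamba / kv here
mamba_bytes = available_gb * ratio / (1 + ratio) * 2**30           # one third of the budget
print(int(mamba_bytes // cell), "mamba cache buffers")             # 92
```
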
2:.2f} MB\n" + f" Calculated buffers: {mamba_cache_size}\n" + f" Required buffers: {start_args.running_max_req_size}\n\n" + f"Solutions:\n" + f" 1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n" + f" 2. Increase --mamba_cache_ratio from {ratio} to " + f"{start_args.running_max_req_size * (1 + ratio) / mamba_cache_size - 1:.3f} or higher\n" + f" 3. Increase --mem_fraction to leave more memory for caches\n" + ) + + logger.info( + f"Mamba cache allocation:\n" + f" Available memory: {mamba_memory_gb:.2f} GB\n" + f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n" + f" Calculated mamba_cache_size: {mamba_cache_size}" + ) + + return mamba_cache_size + def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) @@ -69,16 +138,22 @@ def _init_mem_manager(self): start_args: StartArgs = get_env_start_args() mamba_cache_size = start_args.mamba_cache_size - if mamba_cache_size is not None: - assert ( - mamba_cache_size >= start_args.running_max_req_size - ), "mamba_cache_size must be greater than running_max_req_size" self.num_linear_k_heads = self.config["linear_num_key_heads"] self.num_linear_v_heads = self.config["linear_num_value_heads"] self.head_linear_k_dim = self.config["linear_key_head_dim"] self.head_linear_v_dim = self.config["linear_value_head_dim"] + if mamba_cache_size is None: + mamba_cache_size = self._calculate_mamba_cache_size(start_args) + else: + if mamba_cache_size < start_args.running_max_req_size: + raise ValueError( + f"Explicitly set mamba_cache_size ({mamba_cache_size}) < " + f"running_max_req_size ({start_args.running_max_req_size})\n" + f"Please increase mamba_cache_size to at least {start_args.running_max_req_size}" + ) + conv_kernel_size = self.config["linear_conv_kernel_dim"] conv_dim = ( self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 4d122f615d..25365491d3 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -629,7 +629,22 @@ def make_argument_parser() -> argparse.ArgumentParser: default=False, help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""", ) - parser.add_argument("--mamba_cache_size", type=int, default=3000, help="""The size of linear attn cache. """) + parser.add_argument( + "--mamba_cache_size", + type=int, + default=None, + help="""The size of linear attn cache. If not specified, will be calculated + automatically based on mamba_cache_ratio or max_total_token_num.""", + ) + parser.add_argument( + "--mamba_cache_ratio", + type=float, + default=0.5, + help="""Ratio of available memory to allocate for mamba cache (after model + weights and dynamic memory reservation). Only effective when both + mamba_cache_size and max_total_token_num are not set. 
Default is 0.5 + (50%% of available memory for mamba cache, rest for KV cache).""", + ) parser.add_argument( "--mamba_ssm_data_type", type=str, diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index d8d2c6ff8b..0baa11383a 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -176,5 +176,6 @@ class StartArgs: enable_multimodal_audio: bool = field(default=False) # hybrid attention model (Qwen3Next) - mamba_cache_size: int = field(default=800) + mamba_cache_size: Optional[int] = field(default=None) + mamba_cache_ratio: Optional[float] = field(default=0.5) mamba_ssm_data_type: Optional[str] = field(default="float32", metadata={"choices": ["bfloat16", "float32"]}) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index cdafb88873..7a7a9be121 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -158,7 +158,7 @@ def get_kv_quant_calibration_inference_count(): @lru_cache(maxsize=None) def get_triton_autotune_level(): - return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 1)) + return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0)) g_model_init_done = False From f2e148e4646a7fe5aadefd85f187589540509c18 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 26 Feb 2026 11:49:25 +0000 Subject: [PATCH 078/180] refactor: simplify mamba_cache_ratio to direct percentage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change ratio meaning from complex formula to simple percentage: - Old: ratio = mamba / kv, mamba = total * ratio / (1+ratio) - New: ratio = mamba / total, mamba = total * ratio This makes the ratio more intuitive: - 0.3 → 30% mamba, 70% KV - 0.5 → 50% mamba, 50% KV (default) - 0.7 → 70% mamba, 30% KV Also simplifies error message recommendation formula. --- MAMBA_CACHE_USAGE.md | 53 ++++++++++++++++++++++++++++++ lightllm/models/qwen3next/model.py | 5 +-- lightllm/server/api_cli.py | 8 ++--- 3 files changed, 60 insertions(+), 6 deletions(-) create mode 100644 MAMBA_CACHE_USAGE.md diff --git a/MAMBA_CACHE_USAGE.md b/MAMBA_CACHE_USAGE.md new file mode 100644 index 0000000000..e8bebdec89 --- /dev/null +++ b/MAMBA_CACHE_USAGE.md @@ -0,0 +1,53 @@ +# Mamba Cache Ratio-Based Allocation + +## Parameters + +- `--mamba_cache_ratio ` (default: 0.5) - Percentage of cache memory for mamba +- `--mamba_cache_size ` (default: None) - Explicit buffer count (backward compatible) + +## Ratio Meaning + +`mamba_cache_ratio = mamba_memory / total_cache_memory` + +Examples: +- `0.3` → 30% mamba, 70% KV +- `0.5` → 50% mamba, 50% KV (default) +- `0.7` → 70% mamba, 30% KV + +## Usage Examples + +### Automatic (recommended) +```bash +python -m lightllm.server.api_server \ + --model_dir /path/to/qwen3next \ + --mem_fraction 0.9 +# Uses default ratio 0.5 → 50% mamba, 50% KV +``` + +### Custom ratio +```bash +# For long-context workloads (more KV cache) +python -m lightllm.server.api_server \ + --model_dir /path/to/qwen3next \ + --mamba_cache_ratio 0.3 # 30% mamba, 70% KV + +# For high-concurrency workloads (more mamba cache) +python -m lightllm.server.api_server \ + --model_dir /path/to/qwen3next \ + --mamba_cache_ratio 0.7 # 70% mamba, 30% KV +``` + +### Explicit size (backward compatible) +```bash +python -m lightllm.server.api_server \ + --model_dir /path/to/qwen3next \ + --mamba_cache_size 3000 +``` + +## Troubleshooting + +### Error: "Insufficient memory for mamba cache allocation!" 
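
The behavioral consequence of the refactor is easiest to see side by side: the same `--mamba_cache_ratio` value now buys a strictly larger mamba share, and the default 0.5 moves from a 33/67 to a 50/50 split:

```python
def mamba_share_old(ratio: float) -> float:
    # Commit 077: ratio = mamba / kv, so the mamba share is ratio / (1 + ratio).
    return ratio / (1 + ratio)

def mamba_share_new(ratio: float) -> float:
    # Commit 078: ratio = mamba / total, used directly.
    return ratio

for r in (0.3, 0.5, 0.7):
    print(f"ratio={r}: old={mamba_share_old(r):.0%}, new={mamba_share_new(r):.0%}")
# ratio=0.3: old=23%, new=30%
# ratio=0.5: old=33%, new=50%
# ratio=0.7: old=41%, new=70%
```
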
+
+**Solution 1**: Reduce `--running_max_req_size` to the calculated value or lower
+**Solution 2**: Increase `--mamba_cache_ratio` to give more memory to mamba
+**Solution 3**: Increase `--mem_fraction` to leave more memory for caches
diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py
index 205eb1dc9b..263d1c622d 100644
--- a/lightllm/models/qwen3next/model.py
+++ b/lightllm/models/qwen3next/model.py
@@ -88,8 +88,9 @@ def _calculate_mamba_cache_size(self, start_args: StartArgs) -> int:
         total_cell_size = conv_cell_size + ssm_cell_size
 
         if use_ratio:
+            # mamba_cache_ratio = mamba_memory / total_cache_memory
             mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5
-            mamba_memory_gb = available_memory * mamba_cache_ratio / (1 + mamba_cache_ratio)
+            mamba_memory_gb = available_memory * mamba_cache_ratio
         else:
             mamba_memory_gb = available_memory
             mamba_cache_ratio = None
@@ -110,7 +111,7 @@ def _calculate_mamba_cache_size(self, start_args: StartArgs) -> int:
                 f"Solutions:\n"
                 f"  1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n"
                 f"  2. Increase --mamba_cache_ratio from {ratio} to "
-                f"{start_args.running_max_req_size * (1 + ratio) / mamba_cache_size - 1:.3f} or higher\n"
+                f"{start_args.running_max_req_size / mamba_cache_size * ratio:.3f} or higher\n"
                 f"  3. Increase --mem_fraction to leave more memory for caches\n"
             )
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 25365491d3..eec9a05cf2 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -640,10 +640,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
         "--mamba_cache_ratio",
         type=float,
         default=0.5,
-        help="""Ratio of available memory to allocate for mamba cache (after model
-        weights and dynamic memory reservation). Only effective when both
-        mamba_cache_size and max_total_token_num are not set. Default is 0.5
-        (50%% of available memory for mamba cache, rest for KV cache).""",
+        help="""Ratio of mamba cache to total cache memory (mamba + KV).
+        Only effective when both mamba_cache_size and max_total_token_num are not set.
+        Default is 0.5 (50%% mamba cache, 50%% KV cache).
+ Example: 0.3 -> 30%% mamba, 70%% KV; 0.7 -> 70%% mamba, 30%% KV.""", ) parser.add_argument( "--mamba_ssm_data_type", From b4fe20123a0e757d807aac37e6c130655cb0ecd5 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 26 Feb 2026 13:04:36 +0000 Subject: [PATCH 079/180] add H100 config --- ...12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 14 +++ ...12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 14 +++ ...12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 14 +++ ...12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 12 ++ ...,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json | 38 ++++++ ...,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json | 38 ++++++ ...,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json | 38 ++++++ ...LEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...LEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...LEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json | 12 ++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 118 ++++++++++++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ .../{topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...t16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 74 +++++++++++ ...t16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 74 +++++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 74 +++++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 74 +++++++++++ 40 files changed, 1708 insertions(+) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..cc5c68eb79 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..b6e5109b62 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..511935b4cf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 128, + "num_stages": 4, + "num_warps": 4 + }, + "4": { + "BK": 128, + "BV": 128, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..cc5c68eb79 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..1038611f6a --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..4bc06d07d9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..7421097fa4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..f1159e4357 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..892c20e78d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..d831f32c4a --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..2af1b86e90 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "8": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..40cdc996b9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,12 @@ +{ + "2": { + "BV": 32, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..833062ec2f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 2 + }, + "100": { + "num_warps": 1 + }, + "1024": { + "num_warps": 1 + }, + "128": { + "num_warps": 1 + }, + "16": { + "num_warps": 2 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 1 + }, + "256": { + "num_warps": 1 + }, + "32": { + "num_warps": 1 + }, + "4096": { + "num_warps": 2 + }, + "64": { + "num_warps": 8 + }, + "8": { + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..5f2cf9465b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 8 + }, + "100": { + 
"num_warps": 8 + }, + "1024": { + "num_warps": 2 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 2 + }, + "2048": { + "num_warps": 2 + }, + "256": { + "num_warps": 2 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 2 + }, + "64": { + "num_warps": 2 + }, + "8": { + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..c8a1841674 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 8 + }, + "100": { + "num_warps": 2 + }, + "1024": { + "num_warps": 8 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 8 + }, + "256": { + "num_warps": 4 + }, + "32": { + "num_warps": 2 + }, + "4096": { + "num_warps": 8 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..a97cabf8b2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "4": { + "BK": 64, + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..786624883f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "8": { + "BK": 64, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..eaca03cf75 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,12 @@ +{ + "2": { + "BK": 64, + 
"num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 64, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..d9064e5d6a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "128": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "16": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "2048": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "4096": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..baef19d90c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "128": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16384": { + "BLK_HEADS": 32, + "num_warps": 1 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "64": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "8": { + "BLK_HEADS": 64, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..90ac24c408 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "100": { + "BLK_HEADS": 4, + "num_warps": 1 + }, + "1024": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "128": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 8, + "num_warps": 2 + }, 
+ "16384": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "2048": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 4, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 8, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 16, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..31d7a6e203 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,118 @@ +{ + "1024": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "12": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "1200": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "12288": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "128": { + "BLOCK_N": 256, + "num_warps": 8 + }, + "131072": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "1536": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "16": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "1600": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "16384": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "192": { + "BLOCK_N": 512, + "num_warps": 1 + }, + "196608": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "2048": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "24576": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "256": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "262144": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "3072": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "32768": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "384": { + "BLOCK_N": 512, + "num_warps": 2 + }, + "4096": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "49152": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "512": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "64": { + "BLOCK_N": 256, + "num_warps": 8 + }, + "65536": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "768": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "8": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "800": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "8192": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "96": { + "BLOCK_N": 512, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..864d1d3f18 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 4096, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..bcf56e01f7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 128, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ba1dc8a75d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "3072": { + "BLOCK_SIZE": 2048, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..6f109e1c6e --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "5120": { + "BLOCK_SIZE": 32768, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..0042ef8a2a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + 
"num_warps": 4 + }, + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..54a5967071 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, 
+ "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..bb78d1dd84 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + 
"NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..1552d8bf1a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..08cbfd85c3 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..13a070b8f0 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + 
"num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..169a148799 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 512, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 128, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..5022588ef5 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 2, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..4ae96d02d1 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 32, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 16 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..28c654f3d2 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 3, + "num_warps": 2 + }, + "100": { + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "num_stages": 5, + "num_warps": 1 + }, + "128": { + "num_stages": 3, + "num_warps": 2 + }, + "16": { + "num_stages": 3, + "num_warps": 2 + }, + "16384": { + "num_stages": 2, + "num_warps": 1 + }, + "2048": { + "num_stages": 4, + "num_warps": 1 + }, + "256": { + "num_stages": 3, + "num_warps": 1 + }, + "32": { + "num_stages": 3, + "num_warps": 2 + }, + "4096": { + "num_stages": 4, + "num_warps": 1 + }, + "64": { + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "num_stages": 2, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..08b0d5e5bc --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "num_stages": 3, + "num_warps": 1 + }, + "1024": { + "num_stages": 2, + "num_warps": 1 + }, + "128": { + "num_stages": 4, + "num_warps": 4 + }, + "16": { + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "num_stages": 3, + "num_warps": 2 + }, + "256": { + "num_stages": 5, + "num_warps": 2 + }, + "32": { + "num_stages": 5, + "num_warps": 8 + }, + "4096": { + "num_stages": 2, + "num_warps": 1 + }, + "64": { + "num_stages": 2, + "num_warps": 1 + }, + "8": { + "num_stages": 5, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..0d871841ed --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "num_stages": 1, + "num_warps": 1 + }, + "1024": { + "num_stages": 5, + "num_warps": 1 + }, + "128": { + "num_stages": 2, + "num_warps": 1 + }, + "16": { + "num_stages": 3, + "num_warps": 2 + }, + "16384": { + "num_stages": 4, + "num_warps": 1 + }, + "2048": { + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "num_stages": 4, + "num_warps": 1 + }, + "32": { + "num_stages": 2, + "num_warps": 1 + }, + "4096": { + "num_stages": 3, + "num_warps": 1 + }, + "64": { + "num_stages": 4, + "num_warps": 1 + }, + "8": { + "num_stages": 3, + "num_warps": 1 + } +} \ No 
newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..9f3a8dcb25 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "131072": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "32768": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..72026f01c4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "131072": { + "BLOCK_M": 64, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file From e2ce9c04e995f6f1c441c6bd8583b490ac8379da Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 06:15:24 +0000 Subject: [PATCH 080/180] refactor: align 
radix_cache_class with infer_state_class style - Replace get_radix_cache_class() classmethod with radix_cache_class class attribute in TpPartBaseModel and Qwen3NextTpPartModel - Move RadixCache/HybridRadixCache imports to module top-level - Update base_backend.py to access radix_cache_class directly - Replace alloc_buffer_for_req_triton with simpler indexed PyTorch assignment - Remove now-unused alloc_buffer_kernel.py Triton kernel - Revert LOADWORKER default to 1 and remove language_model. prefix stripping --- lightllm/common/basemodel/basemodel.py | 8 +- .../basemodel/layer_weights/hf_load_utils.py | 10 +-- .../triton_kernel/alloc_buffer_kernel.py | 80 ------------------- lightllm/common/req_manager.py | 4 +- lightllm/models/qwen3next/model.py | 7 +- .../model_infer/mode_backend/base_backend.py | 2 +- 6 files changed, 9 insertions(+), 102 deletions(-) delete mode 100644 lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index caa90462cc..1d36c72d0b 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -11,6 +11,7 @@ from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.req_manager import ReqManager @@ -53,11 +54,8 @@ class TpPartBaseModel: # infer state class infer_state_class = InferStateInfo - @classmethod - def get_radix_cache_class(cls): - from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache - - return RadixCache + # radix cache class + radix_cache_class = RadixCache def __init__(self, kvargs): self.args = get_env_start_args() diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 304b04ab44..8cf66a5ad6 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -18,14 +18,6 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay weights = {k: weights.get_tensor(k) for k in weights.keys()} else: weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") - new_weight = {} - for k, v in weights.items(): - if "language_model." 
in k: - new_weight[k[len("language_model.") :]] = v - else: - new_weight[k] = v - del weights - weights = new_weight if pre_post_layer is not None: pre_post_layer.load_hf_weights(weights) @@ -68,7 +60,7 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye transformer_layer_list=transformer_layer_list, weight_dir=weight_dir, ) # noqa - worker = int(os.environ.get("LOADWORKER", 18)) + worker = int(os.environ.get("LOADWORKER", 1)) with Pool(worker) as p: iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" diff --git a/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py b/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py deleted file mode 100644 index b6444449b1..0000000000 --- a/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def alloc_buffer_for_req_kernel( - req_index_ptr, # [num_reqs] - indices of requests to allocate buffers for - buffer_indexes_ptr, # [num_reqs * num_buffers_per_req] - buffer indices to assign (from CPU) - req_to_buffer_index_ptr, # [max_request_num + 1, num_buffers_per_req] - tensor mapping req_idx to buffer_idx - num_reqs, # number of requests to process - stride_buffer, # stride for req_to_buffer_index second dimension - NUM_BUFFERS_PER_REQ: tl.constexpr, # number of buffers per request (mtp_step + 1) - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - # Mask for valid indices - mask = offsets < num_reqs - - # Load request indices - req_indices = tl.load(req_index_ptr + offsets, mask=mask, other=0) - - # For each request, allocate NUM_BUFFERS_PER_REQ buffers - for buf_idx in tl.static_range(NUM_BUFFERS_PER_REQ): - # Load buffer index for this position - buffer_offset = offsets * NUM_BUFFERS_PER_REQ + buf_idx - buffer_indices = tl.load(buffer_indexes_ptr + buffer_offset, mask=mask, other=0) - - # Update req_to_buffer_index[req_indices, buf_idx] = buffer_indices - output_offset = req_indices * stride_buffer + buf_idx - tl.store(req_to_buffer_index_ptr + output_offset, buffer_indices, mask=mask) - - -def alloc_buffer_for_req_triton( - req_index: torch.Tensor, # [num_reqs] int32/int64 tensor on CUDA - buffer_indexes: torch.Tensor, # [num_reqs * (mtp_step + 1)] int32 tensor (can be CPU or CUDA) - req_to_buffer_index: torch.Tensor, # [max_request_num + 1, mtp_step + 1] int32 tensor on CUDA - mtp_step: int = 0, # number of additional buffers per request (default 0 for non-MTP mode) -): - num_reqs = req_index.shape[0] - num_buffers_per_req = mtp_step + 1 - - # Ensure inputs are on CUDA - if not req_index.is_cuda: - req_index = req_index.cuda() - if not buffer_indexes.is_cuda: - buffer_indexes = buffer_indexes.cuda() - - # Ensure correct dtypes - if req_index.dtype not in [torch.int32, torch.int64]: - req_index = req_index.to(torch.int32) - if buffer_indexes.dtype != torch.int32: - buffer_indexes = buffer_indexes.to(torch.int32) - - # Validate buffer_indexes size - expected_size = num_reqs * num_buffers_per_req - assert buffer_indexes.shape[0] == expected_size, ( - f"Expected {expected_size} buffer indices for {num_reqs} requests " - f"with mtp_step={mtp_step}, but got {buffer_indexes.shape[0]}" - ) - - # Get stride for the second dimension of req_to_buffer_index - stride_buffer = req_to_buffer_index.stride(0) - - # 
Launch kernel - BLOCK_SIZE = 256 - grid = (triton.cdiv(num_reqs, BLOCK_SIZE),) - - alloc_buffer_for_req_kernel[grid]( - req_index, - buffer_indexes, - req_to_buffer_index, - num_reqs, - stride_buffer, - NUM_BUFFERS_PER_REQ=num_buffers_per_req, - BLOCK_SIZE=BLOCK_SIZE, - ) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 573fe50842..bad3fa0557 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,6 +1,5 @@ import torch import collections -from lightllm.common.basemodel.triton_kernel.alloc_buffer_kernel import alloc_buffer_for_req_triton from lightllm.utils.log_utils import init_logger from .kv_cache_mem_manager import MemoryManager from typing import List, Optional @@ -268,7 +267,8 @@ def alloc_buffer_for_req(self, req_index: torch.Tensor): num_reqs = req_index.shape[0] num_buffers_per_req = self.mtp_step + 1 buffer_indexes = self.buffer_mem_manager.alloc(num_reqs * num_buffers_per_req) - alloc_buffer_for_req_triton(req_index, buffer_indexes, self.req_to_buffer_index, self.mtp_step) + # Pure PyTorch: indexed assignment is already a fused GPU kernel + self.req_to_buffer_index[req_index] = buffer_indexes.view(num_reqs, num_buffers_per_req) def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor): # 获取目标请求的所有 MTP buffer (从 buffer[0] 到 buffer[mtp_step]) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 263d1c622d..b3f0f53cac 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -20,6 +20,7 @@ from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.common.req_manager import ReqManagerForMamba from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache logger = init_logger(__name__) @@ -33,11 +34,7 @@ class Qwen3NextTpPartModel(Qwen3MOEModel): is_hybrid_attention = True # Indicates model uses hybrid (full + linear) attention use_buffer_manager = True # Indicates model needs per-request buffer management for linear attention states - @classmethod - def get_radix_cache_class(cls): - from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache - - return HybridRadixCache + radix_cache_class = HybridRadixCache def __init__(self, kvargs) -> None: self.mem_manager: Qwen3NextHybridMemManager = None diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 0ba4b9248c..57a3508e93 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -175,7 +175,7 @@ def init_model(self, kvargs): self.use_buffer_manager = getattr(self.model, "use_buffer_manager", False) - radix_cache_class = self.model.get_radix_cache_class() + radix_cache_class = self.model.radix_cache_class self.radix_cache = ( radix_cache_class( get_unique_server_name(), From b1adbf3bae3e52784ea32fb1c421172d516f5385 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 06:19:23 +0000 Subject: [PATCH 081/180] fix: add missing attention_chunk param to flashattention_nopad.py The sgl_kernel.fwd.default API requires attention_chunk before softcap. This file was missed when the parameter was added in commit a4ab210f. Also update sgl-kernel from 0.3.7.post1 to 0.3.21 which supports this API. 
--- lightllm/models/vit/triton_kernel/flashattention_nopad.py | 3 ++- requirements.txt | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index 8428e52996..b43f8f95af 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -195,7 +195,8 @@ def flash_attention_v3_fwd( False, window_size[0], window_size[1], - 0.0, + 0, # attention_chunk + 0.0, # softcap is_rotary_interleaved=False, scheduler_metadata=None, num_splits=1, diff --git a/requirements.txt b/requirements.txt index 25cdab955d..521038f719 100644 --- a/requirements.txt +++ b/requirements.txt @@ -81,7 +81,7 @@ atomics==1.0.3 easydict==1.13 hypercorn==0.18.0 flashinfer-python==0.2.4 -sgl-kernel==0.3.7.post1 +sgl-kernel==0.3.21 httpx==0.28.1 librosa==0.11.0 cuda_bindings==12.9.0 From c744ebd87691c72a0ee7efdd2bfa1110695282e5 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 06:31:05 +0000 Subject: [PATCH 082/180] refactor: clarify naming in mamba_buffer_copy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename copy_buffer_p2p → copy_mamba_buffer (indexed 1:1 slot copy) - Rename copy_buffer_broadcast → fork_mamba_buffer (1:N MTP fork) - Unify chunk offset param name (pair_idx_offset/copy_idx_offset → chunk_offset) - Rename stride_index → stride_slot to reflect the slot/cache dimension - Rename src_idx_in_batch → src_chunk_idx in fork kernel - Extract _MAX_GRID_DIM = 65535 module constant (was duplicated inline) - Add divisibility assertion before implicit // in fork autotuned wrapper - Update autotuner cache keys to match new names --- .../triton_kernel/mamba_buffer_copy.py | 133 +++++++++--------- 1 file changed, 68 insertions(+), 65 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py index 6a1d8adbd5..21301570d3 100644 --- a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py +++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py @@ -3,36 +3,38 @@ import triton.language as tl from lightllm.common.triton_utils.autotuner import autotune +_MAX_GRID_DIM = 65535 + @triton.jit -def _copy_buffer_p2p_1d_kernel( +def _copy_mamba_buffer_1d_kernel( src_buffer_ptr, dst_buffer_ptr, src_indexes_ptr, dst_indexes_ptr, - pair_idx_offset, + chunk_offset, layer_idx_offset, stride_layer, - stride_index, + stride_slot, stride_d, d_size, BLOCK_D: tl.constexpr, ): """ - Optimized kernel for 1D buffer copy. + Indexed 1:1 copy kernel for Mamba recurrent state buffers. Grid: (num_pairs, layer_num, num_blocks_d) Each program copies one block of dimension d for one (pair, layer) combination. 
""" - pair_idx = tl.program_id(0) + pair_idx_offset + pair_idx = tl.program_id(0) + chunk_offset layer_idx = tl.program_id(1) + layer_idx_offset block_d_idx = tl.program_id(2) # Cast strides to int64 to prevent overflow in pointer arithmetic stride_layer = stride_layer.to(tl.int64) - stride_index = stride_index.to(tl.int64) + stride_slot = stride_slot.to(tl.int64) - # Load source and destination indices for this pair + # Load source and destination slot indices for this pair src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) @@ -44,8 +46,8 @@ def _copy_buffer_p2p_1d_kernel( mask = d_offsets < d_size # Calculate source and destination pointers for this layer and pair - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_slot + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_slot src_ptr = base_src + d_offsets * stride_d dst_ptr = base_dst + d_offsets * stride_d @@ -56,54 +58,53 @@ def _copy_buffer_p2p_1d_kernel( @triton.jit -def _copy_buffer_broadcast_1d_kernel( +def _fork_mamba_buffer_1d_kernel( src_buffer_ptr, dst_buffer_ptr, src_indexes_ptr, dst_indexes_ptr, - copy_idx_offset, + chunk_offset, layer_idx_offset, stride_layer, - stride_index, + stride_slot, stride_d, d_size, num_dst_per_src, BLOCK_D: tl.constexpr, ): """ - Broadcast kernel for 1D buffer copy (one source to multiple destinations). + Fork kernel for Mamba recurrent state buffers: one source slot → N destination slots. + Used for MTP speculation where one parent state is copied to multiple child slots. Grid: (num_src, layer_num, num_blocks_d) """ - src_idx_in_batch = tl.program_id(0) + copy_idx_offset + src_chunk_idx = tl.program_id(0) + chunk_offset layer_idx = tl.program_id(1) + layer_idx_offset block_d_idx = tl.program_id(2) # Cast strides to int64 to prevent overflow in pointer arithmetic stride_layer = stride_layer.to(tl.int64) - stride_index = stride_index.to(tl.int64) + stride_slot = stride_slot.to(tl.int64) - # Load source index - src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) + # Load source slot index + src_idx = tl.load(src_indexes_ptr + src_chunk_idx).to(tl.int64) # Calculate offsets for this block d_start = block_d_idx * BLOCK_D d_offsets = d_start + tl.arange(0, BLOCK_D) mask = d_offsets < d_size - # Calculate source pointer - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + # Calculate source pointer and load data once + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_slot src_ptr = base_src + d_offsets * stride_d - - # Load data once data = tl.load(src_ptr, mask=mask, other=0.0) - # Broadcast to all destinations for this source + # Write to each destination slot for this source for dst_offset in range(num_dst_per_src): - dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset + dst_idx_in_batch = src_chunk_idx * num_dst_per_src + dst_offset dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_slot dst_ptr = base_dst + d_offsets * stride_d tl.store(dst_ptr, data, mask=mask) @@ -151,20 +152,20 @@ def _get_buffer_copy_run_key(src_indexes: torch.Tensor): @autotune( - 
kernel_name="mamba_buffer_copy_p2p_1d:v1", + kernel_name="mamba_buffer_copy_1d:v1", configs_gen_func=_get_buffer_copy_1d_configs, static_key_func=_get_buffer_copy_static_key, run_key_func=_get_buffer_copy_run_key, mutates_args=["dst_buffer"], ) -def _copy_buffer_p2p_1d_autotuned( +def _copy_mamba_buffer_1d_autotuned( src_buffer: torch.Tensor, dst_buffer: torch.Tensor, src_indexes: torch.Tensor, dst_indexes: torch.Tensor, run_config: dict = None, ): - """Auto-tuned 1D buffer copy.""" + """Auto-tuned indexed 1:1 copy of Mamba recurrent state buffer slots.""" num_pairs = src_indexes.shape[0] layer_num = src_buffer.shape[0] d_size = src_buffer.shape[2] @@ -180,19 +181,17 @@ def _copy_buffer_p2p_1d_autotuned( num_blocks_d = triton.cdiv(d_size, BLOCK_D) - MAX_GRID_SIZE = 65535 - - for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): - pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) + for pair_chunk_start in range(0, num_pairs, _MAX_GRID_DIM): + pair_chunk_end = min(pair_chunk_start + _MAX_GRID_DIM, num_pairs) pair_chunk_size = pair_chunk_end - pair_chunk_start - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + for layer_chunk_start in range(0, layer_num, _MAX_GRID_DIM): + layer_chunk_end = min(layer_chunk_start + _MAX_GRID_DIM, layer_num) layer_chunk_size = layer_chunk_end - layer_chunk_start grid = (pair_chunk_size, layer_chunk_size, num_blocks_d) - _copy_buffer_p2p_1d_kernel[grid]( + _copy_mamba_buffer_1d_kernel[grid]( src_buffer, dst_buffer, src_indexes, @@ -210,23 +209,26 @@ def _copy_buffer_p2p_1d_autotuned( @autotune( - kernel_name="mamba_buffer_broadcast_1d:v1", + kernel_name="mamba_buffer_fork_1d:v1", configs_gen_func=_get_buffer_copy_1d_configs, static_key_func=_get_buffer_copy_static_key, run_key_func=_get_buffer_copy_run_key, mutates_args=["dst_buffer"], ) -def _copy_buffer_broadcast_1d_autotuned( +def _fork_mamba_buffer_1d_autotuned( src_buffer: torch.Tensor, dst_buffer: torch.Tensor, src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, + dst_indexes: torch.Tensor, # flat 1D: [num_src * num_dst_per_src] run_config: dict = None, ): - """Auto-tuned 1D buffer broadcast (one src to multiple dst).""" + """Auto-tuned fork: copy each source Mamba slot to N destination slots.""" num_src = src_indexes.shape[0] layer_num = src_buffer.shape[0] d_size = src_buffer.shape[2] + assert ( + dst_indexes.shape[0] % num_src == 0 + ), f"dst_indexes length {dst_indexes.shape[0]} must be divisible by num_src {num_src}" num_dst_per_src = dst_indexes.shape[0] // num_src if run_config is None: @@ -240,19 +242,17 @@ def _copy_buffer_broadcast_1d_autotuned( num_blocks_d = triton.cdiv(d_size, BLOCK_D) - MAX_GRID_SIZE = 65535 - - for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): - src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) + for src_chunk_start in range(0, num_src, _MAX_GRID_DIM): + src_chunk_end = min(src_chunk_start + _MAX_GRID_DIM, num_src) src_chunk_size = src_chunk_end - src_chunk_start - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + for layer_chunk_start in range(0, layer_num, _MAX_GRID_DIM): + layer_chunk_end = min(layer_chunk_start + _MAX_GRID_DIM, layer_num) layer_chunk_size = layer_chunk_end - layer_chunk_start grid = (src_chunk_size, layer_chunk_size, num_blocks_d) - _copy_buffer_broadcast_1d_kernel[grid]( + _fork_mamba_buffer_1d_kernel[grid]( src_buffer, dst_buffer, src_indexes, 
@@ -285,23 +285,23 @@ def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor: return buffer.view(L, B, -1) -def copy_buffer_p2p( +def copy_mamba_buffer( src_buffer: torch.Tensor, dst_buffer: torch.Tensor, src_indexes: torch.Tensor, dst_indexes: torch.Tensor, ): """ - Copy buffers from source indices to destination indices with auto-tuning. + Indexed 1:1 copy of Mamba recurrent state buffer slots. - Supports any buffer shape [layer_num, buffer_size, ...] as long as the - trailing dimensions are contiguous (which is the default for torch.zeros). + Copies slot src_indexes[i] → dst_indexes[i] for all layers simultaneously. + Used for cache eviction/restore and normal token state management. Args: - src_buffer: Source buffer tensor [layer_num, buffer_size, ...] - dst_buffer: Destination buffer tensor [layer_num, buffer_size, ...] - src_indexes: Source buffer indices [num_pairs] - dst_indexes: Destination buffer indices [num_pairs] + src_buffer: [layer_num, num_slots, ...] + dst_buffer: [layer_num, num_slots, ...] + src_indexes: source slot indices [num_pairs] + dst_indexes: destination slot indices [num_pairs] """ assert src_buffer.shape == dst_buffer.shape assert src_indexes.shape == dst_indexes.shape @@ -309,36 +309,39 @@ def copy_buffer_p2p( src_flat = _flatten_trailing_dims(src_buffer) dst_flat = _flatten_trailing_dims(dst_buffer) - _copy_buffer_p2p_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes) + _copy_mamba_buffer_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes) -def copy_buffer_broadcast( +def fork_mamba_buffer( src_buffer: torch.Tensor, dst_buffer: torch.Tensor, src_indexes: torch.Tensor, dst_indexes: torch.Tensor, ): """ - Broadcast buffers from source indices to multiple destination indices (MTP use case). + Fork Mamba recurrent state slots: copy one source slot to N destination slots. - Each source buffer is copied to multiple destination buffers. + Used for MTP (Multi-Token Prediction) speculation, where a parent token's + recurrent state must be replicated into each speculative child slot. Args: - src_buffer: Source buffer tensor [layer_num, buffer_size, ...] - dst_buffer: Destination buffer tensor [layer_num, buffer_size, ...] - src_indexes: Source buffer indices [num_src] - dst_indexes: Destination buffer indices [num_src, num_dst_per_src] (2D tensor) + src_buffer: [layer_num, num_slots, ...] + dst_buffer: [layer_num, num_slots, ...] 
+ src_indexes: source slot indices [num_src] + dst_indexes: destination slot indices [num_src, num_dst_per_src] """ assert src_buffer.shape == dst_buffer.shape assert len(src_indexes.shape) == 1 - assert len(dst_indexes.shape) == 2, f"dst_indexes must be 2D, got shape {dst_indexes.shape}" + assert len(dst_indexes.shape) == 2, f"dst_indexes must be 2D [num_src, num_dst_per_src], got {dst_indexes.shape}" num_src = src_indexes.shape[0] - assert num_src == dst_indexes.shape[0], f"Mismatch: src_indexes {num_src} vs dst_indexes {dst_indexes.shape[0]}" + assert ( + num_src == dst_indexes.shape[0] + ), f"Mismatch: src_indexes {num_src} vs dst_indexes rows {dst_indexes.shape[0]}" - # Flatten dst_indexes for kernel + # Flatten dst_indexes to 1D for kernel; kernel reconstructs the 2D layout via num_dst_per_src dst_indexes_flat = dst_indexes.reshape(-1).contiguous() src_flat = _flatten_trailing_dims(src_buffer) dst_flat = _flatten_trailing_dims(dst_buffer) - _copy_buffer_broadcast_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes_flat) + _fork_mamba_buffer_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes_flat) From 9def6979fc3311e2d158f67d1dca7ee15abe0329 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 06:35:06 +0000 Subject: [PATCH 083/180] clean --- MAMBA_CACHE_USAGE.md | 53 ------------------- .../mamba_cache_mem_manager/cache_manager.py | 14 ++--- 2 files changed, 7 insertions(+), 60 deletions(-) delete mode 100644 MAMBA_CACHE_USAGE.md diff --git a/MAMBA_CACHE_USAGE.md b/MAMBA_CACHE_USAGE.md deleted file mode 100644 index e8bebdec89..0000000000 --- a/MAMBA_CACHE_USAGE.md +++ /dev/null @@ -1,53 +0,0 @@ -# Mamba Cache Ratio-Based Allocation - -## Parameters - -- `--mamba_cache_ratio ` (default: 0.5) - Percentage of cache memory for mamba -- `--mamba_cache_size ` (default: None) - Explicit buffer count (backward compatible) - -## Ratio Meaning - -`mamba_cache_ratio = mamba_memory / total_cache_memory` - -Examples: -- `0.3` → 30% mamba, 70% KV -- `0.5` → 50% mamba, 50% KV (default) -- `0.7` → 70% mamba, 30% KV - -## Usage Examples - -### Automatic (recommended) -```bash -python -m lightllm.server.api_server \ - --model_dir /path/to/qwen3next \ - --mem_fraction 0.9 -# Uses default ratio 0.5 → 50% mamba, 50% KV -``` - -### Custom ratio -```bash -# For long-context workloads (more KV cache) -python -m lightllm.server.api_server \ - --model_dir /path/to/qwen3next \ - --mamba_cache_ratio 0.3 # 30% mamba, 70% KV - -# For high-concurrency workloads (more mamba cache) -python -m lightllm.server.api_server \ - --model_dir /path/to/qwen3next \ - --mamba_cache_ratio 0.7 # 70% mamba, 30% KV -``` - -### Explicit size (backward compatible) -```bash -python -m lightllm.server.api_server \ - --model_dir /path/to/qwen3next \ - --mamba_cache_size 3000 -``` - -## Troubleshooting - -### Error: "Insufficient memory for mamba cache allocation!" 
- -**Solution 1**: Reduce `--running_max_req_size` to calculated value or lower -**Solution 2**: Increase `--mamba_cache_ratio` to give more memory to mamba -**Solution 3**: Increase `--mem_fraction` to leave more memory for caches diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py index 272a999bb1..9b0933f22f 100644 --- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -6,7 +6,7 @@ from lightllm.utils.dist_utils import get_current_rank_in_node from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from lightllm.common.allocator_utils import TokenAllocator -from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import copy_buffer_p2p, copy_buffer_broadcast +from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import copy_mamba_buffer, fork_mamba_buffer from lightllm.utils.log_utils import init_logger from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt @@ -57,29 +57,29 @@ def get_mamba_cache(self, layer_idx: int): return conv_state, ssm_state def copy_buffer_p2p(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor): - copy_buffer_p2p( + copy_mamba_buffer( self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes ) - copy_buffer_p2p( + copy_mamba_buffer( self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes ) def copy_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): - copy_buffer_broadcast( + fork_mamba_buffer( self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_index, dst_buffer_indexes ) - copy_buffer_broadcast( + fork_mamba_buffer( self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes ) def copy_ssm_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): """ - Broadcast ONLY SSM states (not conv states) from source indices to destination indices. + Fork ONLY SSM states (not conv states) from source indices to destination indices. This is used for MTP mode where each buffer maintains its own independent conv state, but SSM states need to be synchronized. 
""" - copy_buffer_broadcast( + fork_mamba_buffer( self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes ) From 2b3deb80c0c6b50c1b14aab37591d8ef6741dbc7 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 07:36:35 +0000 Subject: [PATCH 084/180] fix --- lightllm/common/req_manager.py | 3 ++- lightllm/server/api_models.py | 32 ++++++++------------------------ lightllm/server/api_openai.py | 21 +++------------------ 3 files changed, 13 insertions(+), 43 deletions(-) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index bad3fa0557..3a5e048fb9 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -267,7 +267,8 @@ def alloc_buffer_for_req(self, req_index: torch.Tensor): num_reqs = req_index.shape[0] num_buffers_per_req = self.mtp_step + 1 buffer_indexes = self.buffer_mem_manager.alloc(num_reqs * num_buffers_per_req) - # Pure PyTorch: indexed assignment is already a fused GPU kernel + if not buffer_indexes.is_cuda: + buffer_indexes = buffer_indexes.cuda() self.req_to_buffer_index[req_index] = buffer_indexes.view(num_reqs, num_buffers_per_req) def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor): diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 7c7d40698c..f30ecc55fe 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -115,7 +115,6 @@ class CompletionRequest(BaseModel): prompt: Union[str, List[str], List[int], List[List[int]]] suffix: Optional[str] = None max_tokens: Optional[int] = 8192 - max_completion_tokens: Optional[int] = None # OpenAI's newer parameter, alias for max_tokens temperature: Optional[float] = 1.0 top_p: Optional[float] = 1.0 n: Optional[int] = 1 @@ -170,17 +169,10 @@ def load_generation_cfg(cls, weight_dir: str): @classmethod def apply_loaded_defaults(cls, data: Any): """Apply loaded default values if field is not provided.""" - if isinstance(data, dict): - # Map max_completion_tokens to max_tokens if provided - # (OpenAI's newer parameter name) - if "max_completion_tokens" in data and data["max_completion_tokens"] is not None: - if "max_tokens" not in data or data["max_tokens"] is None: - data["max_tokens"] = data["max_completion_tokens"] - - if cls._loaded_defaults: - for key, value in cls._loaded_defaults.items(): - if key not in data: - data[key] = value + if isinstance(data, dict) and cls._loaded_defaults: + for key, value in cls._loaded_defaults.items(): + if key not in data: + data[key] = value return data @@ -195,7 +187,6 @@ class ChatCompletionRequest(BaseModel): stream_options: Optional[StreamOptions] = None stop: Optional[Union[str, List[str]]] = None max_tokens: Optional[int] = 8192 - max_completion_tokens: Optional[int] = None # OpenAI's newer parameter, alias for max_tokens presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None @@ -255,17 +246,10 @@ def load_generation_cfg(cls, weight_dir: str): @classmethod def apply_loaded_defaults(cls, data: Any): """Apply loaded default values if field is not provided.""" - if isinstance(data, dict): - # Map max_completion_tokens to max_tokens if provided - # (OpenAI's newer parameter name) - if "max_completion_tokens" in data and data["max_completion_tokens"] is not None: - if "max_tokens" not in data or data["max_tokens"] is None: - data["max_tokens"] = data["max_completion_tokens"] - - if cls._loaded_defaults: - for key, value in 
cls._loaded_defaults.items(): - if key not in data: - data[key] = value + if isinstance(data, dict) and cls._loaded_defaults: + for key, value in cls._loaded_defaults.items(): + if key not in data: + data[key] = value return data diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index de1423c496..33f342822f 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -184,16 +184,9 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")} ) else: - # Treat as local file path - if os.path.isfile(img): - with open(img, "rb") as f: - multimodal_params_dict["images"].append( - {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")} - ) - else: - raise ValueError( - "Unrecognized image input. Supports local path, http url, base64, and PIL.Image." - ) + raise ValueError( + "Unrecognized image input. Supports local path, http url, base64, and PIL.Image." + ) tools = None if request.tools and request.tool_choice != "none": @@ -276,14 +269,6 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req text = "".join(final_output_dict[sub_req_id]) full_text = text - # Debug logging for empty responses - if not text or len(text.strip()) == 0: - logger.warning( - f"[EMPTY_RESPONSE_DEBUG] sub_req_id={sub_req_id}, " - f"completion_tokens={completion_tokens}, finish_reason={finish_reason}, " - f"prompt_tokens={prompt_tokens}, output_chunks={len(final_output_dict[sub_req_id])}" - ) - # Handle reasoning content reasoning_text = None reasoning_parser = get_env_start_args().reasoning_parser From 61f894533629fd65ae74df8a65c6fc3c0135cf54 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 10:23:02 +0000 Subject: [PATCH 085/180] clean --- lightllm/models/qwen3_5/model.py | 13 ++- lightllm/models/qwen3next/buffer_pool.py | 83 ------------------- .../layer_infer/shared_expert_mixin.py | 7 +- lightllm/server/api_cli.py | 6 +- 4 files changed, 12 insertions(+), 97 deletions(-) delete mode 100644 lightllm/models/qwen3next/buffer_pool.py diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py index fdbccdf787..2f7413bc87 100644 --- a/lightllm/models/qwen3_5/model.py +++ b/lightllm/models/qwen3_5/model.py @@ -1,9 +1,5 @@ import os import json -import time -import gc -from safetensors import safe_open -from tqdm import tqdm from lightllm.models.registry import ModelRegistry from lightllm.models.qwen3next.model import Qwen3NextTpPartModel from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( @@ -11,8 +7,12 @@ Qwen35NextGatedDeltaNetTransformerLayerWeight, ) from lightllm.models.qwen3_vl.model import QWen3VLTokenizer -from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer -from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import ( + Qwen3VLMultimodalPreLayerInfer, +) +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import ( + Qwen3VLPreAndPostLayerWeight, +) from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import ( Qwen35FullAttentionTransformerLayerInfer, Qwen35GatedDeltaNetTransformerLayerInfer, @@ -20,7 +20,6 @@ from lightllm.models.qwen3_5.infer_struct import Qwen35InferStateInfo from lightllm.common.build_utils import repair_config from lightllm.utils.log_utils import init_logger -import 
lightllm.utils.petrel_helper as utils logger = init_logger(__name__) diff --git a/lightllm/models/qwen3next/buffer_pool.py b/lightllm/models/qwen3next/buffer_pool.py deleted file mode 100644 index 42c4bcafc7..0000000000 --- a/lightllm/models/qwen3next/buffer_pool.py +++ /dev/null @@ -1,83 +0,0 @@ -# lightllm/models/qwen3next/buffer_pool.py -import torch -from typing import Dict, Tuple - - -class Qwen3NextBufferPool: - """ - Buffer pool for Qwen3Next inference to reduce allocations. - - NOT thread-safe. Each GPU worker process should have its own pool instance. - - Manages reusable buffers for: - - Attention norm outputs - - FFN norm outputs - - FFN intermediate activations - - GDN intermediate tensors - """ - - def __init__(self, enable_stats: bool = False, max_buffers: int = 64): - self._buffers: Dict[Tuple[tuple, torch.dtype, torch.device], torch.Tensor] = {} - self._in_use: set = set() - self._max_buffers = max_buffers - self._access_order: list = [] # Track LRU order - self._enable_stats = enable_stats - self._stats = {"hits": 0, "misses": 0, "peak_buffers": 0, "evictions": 0} if enable_stats else None - - def get_buffer( - self, - shape: Tuple[int, ...], - dtype: torch.dtype, - device: torch.device, - ) -> torch.Tensor: - """Get a buffer from the pool or allocate a new one.""" - key = (shape, dtype, device) - - # Check if we have a matching buffer not in use - if key in self._buffers and key not in self._in_use: - self._in_use.add(key) - # Update LRU order - if key in self._access_order: - self._access_order.remove(key) - self._access_order.append(key) - if self._enable_stats: - self._stats["hits"] += 1 - return self._buffers[key] - - # Evict oldest unused buffer if at capacity - if len(self._buffers) >= self._max_buffers: - self._evict_one() - - # Allocate new buffer - buffer = torch.empty(shape, dtype=dtype, device=device) - self._buffers[key] = buffer - self._in_use.add(key) - self._access_order.append(key) - if self._enable_stats: - self._stats["misses"] += 1 - self._stats["peak_buffers"] = max(self._stats["peak_buffers"], len(self._buffers)) - return buffer - - def _evict_one(self): - """Evict oldest unused buffer (LRU).""" - for key in self._access_order: - if key not in self._in_use and key in self._buffers: - del self._buffers[key] - self._access_order.remove(key) - if self._enable_stats: - self._stats["evictions"] += 1 - return - - def release_all(self): - """Release all buffers back to the pool (call after forward pass).""" - self._in_use.clear() - - def clear(self): - """Clear all buffers (call when changing batch size significantly).""" - self._buffers.clear() - self._in_use.clear() - self._access_order.clear() - - def get_stats(self): - """Return buffer pool statistics (if enabled).""" - return self._stats.copy() if self._stats else None diff --git a/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py b/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py index 2da106dbb2..be9000fcad 100644 --- a/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py +++ b/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py @@ -32,12 +32,7 @@ def _ffn_core(self, input, layer_weight): """Core FFN computation: gate_up -> silu_and_mul -> down.""" input = input.view(-1, self.embed_dim_) up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) - - if hasattr(self, "buffer_pool") and self.buffer_pool: - ffn1_out = self.buffer_pool.get_buffer((input.size(0), up_gate_out.size(1) // 2), input.dtype, input.device) - else: - ffn1_out = 
self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) - + ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) silu_and_mul_fwd(up_gate_out, ffn1_out) ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) return ffn2_out, input diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index eec9a05cf2..47111f76bc 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -638,7 +638,11 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--mamba_cache_ratio", - type=float, + type=lambda v: float(v) + if 0.0 <= (_ := float(v)) <= 1.0 + else (_ for _ in ()).throw( + argparse.ArgumentTypeError(f"--mamba_cache_ratio must be between 0.0 and 1.0, got {v}") + ), default=0.5, help="""Ratio of mamba cache to total cache memory (mamba + KV). Only effective when both mamba_cache_size and max_total_token_num are not set. From f7280a30c26e7a61708b2ab3f7eef5e35cd048f6 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 13:37:03 +0000 Subject: [PATCH 086/180] split --- lightllm/models/__init__.py | 6 +-- lightllm/models/qwen3_5/__init__.py | 7 ++- lightllm/models/qwen3_5/model.py | 30 ------------ lightllm/models/qwen3_5_moe/__init__.py | 0 .../qwen3_5_moe/layer_infer/__init__.py | 0 .../qwen3_5_moe/layer_weights/__init__.py | 0 lightllm/models/qwen3_5_moe/model.py | 48 +++++++++++++++++++ 7 files changed, 53 insertions(+), 38 deletions(-) create mode 100644 lightllm/models/qwen3_5_moe/__init__.py create mode 100644 lightllm/models/qwen3_5_moe/layer_infer/__init__.py create mode 100644 lightllm/models/qwen3_5_moe/layer_weights/__init__.py create mode 100644 lightllm/models/qwen3_5_moe/model.py diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index af13e34cd9..ad040cdf25 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -40,8 +40,6 @@ ) from lightllm.models.gpt_oss.model import GptOssTpPartModel from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel -from lightllm.models.qwen3_5.model import ( - Qwen3_5TpPartModel, - Qwen3_5MOETpPartModel, -) +from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel +from lightllm.models.qwen3_5_moe.model import Qwen3_5MOETpPartModel from .registry import get_model, get_model_class diff --git a/lightllm/models/qwen3_5/__init__.py b/lightllm/models/qwen3_5/__init__.py index 47667a92d5..56a41a228a 100644 --- a/lightllm/models/qwen3_5/__init__.py +++ b/lightllm/models/qwen3_5/__init__.py @@ -1,17 +1,16 @@ """ -Qwen3.5 Multimodal Model Module +Qwen3.5 Multimodal Model Module (Dense Variant) -Provides Qwen3.5 multimodal models with hybrid attention and vision-language support. +Provides Qwen3.5 dense multimodal model with hybrid attention and vision-language support. +For MoE variant, see qwen3_5_moe module. 
""" from .model import ( Qwen3_5TpPartModel, - Qwen3_5MOETpPartModel, QWen3_5Tokenizer, ) __all__ = [ "Qwen3_5TpPartModel", - "Qwen3_5MOETpPartModel", "QWen3_5Tokenizer", ] diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py index 2f7413bc87..3d093b3939 100644 --- a/lightllm/models/qwen3_5/model.py +++ b/lightllm/models/qwen3_5/model.py @@ -196,33 +196,3 @@ def _init_infer_layer(self): ) for i in range(self.config["n_layer"]) ] - - -@ModelRegistry(["qwen3_5_moe"], is_multimodal=True) -class Qwen3_5MOETpPartModel(Qwen3_5TpPartModel): - """ - Qwen3.5-MoE Multimodal Model (Mixture of Experts Variant) - - Extends Qwen3.5 with sparse expert routing: - - Same hybrid attention architecture as Qwen3.5 - - MoE layers replace dense MLP layers - - Expert routing handled by Qwen3NextSparseMoeBlock (inherited) - - The MoE variant is automatically configured by inheriting from - Qwen3NextTpPartModel, which inherits from Qwen3MOEModel. - - No additional configuration needed - MoE support is built-in. - """ - - def __init__(self, kvargs): - """ - Initialize Qwen3.5-MoE model. - - Args: - kvargs: Dictionary containing: - - weight_dir: Path to model weights - - max_total_token_num: Maximum total tokens - - Additional model configuration - """ - super().__init__(kvargs) - logger.info("Initialized Qwen3.5-MoE multimodal model with expert routing") diff --git a/lightllm/models/qwen3_5_moe/__init__.py b/lightllm/models/qwen3_5_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_moe/layer_infer/__init__.py b/lightllm/models/qwen3_5_moe/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_moe/layer_weights/__init__.py b/lightllm/models/qwen3_5_moe/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_moe/model.py b/lightllm/models/qwen3_5_moe/model.py new file mode 100644 index 0000000000..069992bb37 --- /dev/null +++ b/lightllm/models/qwen3_5_moe/model.py @@ -0,0 +1,48 @@ +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel +from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager + +logger = init_logger(__name__) + + +@ModelRegistry("qwen3_5_moe", is_multimodal=True) +class Qwen3_5MOETpPartModel(Qwen3_5TpPartModel): + """ + Qwen3.5-MoE Multimodal Model (Mixture of Experts Variant) + + Extends Qwen3.5 with sparse expert routing: + - Same hybrid attention architecture as Qwen3.5 + - MoE layers replace dense MLP layers + - Expert routing handled by inherited MoE infrastructure + + This model combines: + - Hybrid attention from Qwen3Next (Gated Delta Networks + Full Attention) + - Multimodal capabilities from Qwen3VL (image/video processing) + - MoE sparse routing for efficient scaling + """ + + def __init__(self, kvargs): + """ + Initialize Qwen3.5-MoE model. + + Args: + kvargs: Dictionary containing: + - weight_dir: Path to model weights + - max_total_token_num: Maximum total tokens + - Additional model configuration + """ + super().__init__(kvargs) + logger.info("Initialized Qwen3.5-MoE multimodal model with expert routing") + + def _init_custom(self): + """ + Initialize MoE-specific components. + + Sets up DeepEP communication group for expert parallelism + when the model has experts configured. 
+ """ + super()._init_custom() + # Initialize DeepEP group for MoE models with num_experts + if "num_experts" in self.config and self.config["num_experts"] > 0: + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) From c05838e6bac1eded65a94e3207415198b85ba491 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 28 Feb 2026 08:32:34 +0000 Subject: [PATCH 087/180] fix: lazy-initialize SHM name constants to avoid import-time crash KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME and MAMBA_CACHE_CAN_USE_NUM_SHM_NAME were evaluated at module import time, calling get_unique_server_name() before set_unique_server_name(args) had been called. This caused: TypeError: argument of type 'NoneType' is not iterable Replace module-level constants with _get_*_shm_name() functions so the env var lookup is deferred until the values are actually needed (inside __init__ / class methods, after server startup). --- lightllm/common/kv_cache_mem_manager/mem_manager.py | 8 +++++--- lightllm/common/mamba_cache_mem_manager/cache_manager.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 4f37a1db89..e671eac01d 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -25,7 +25,9 @@ logger = init_logger(__name__) -KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_kv_cache_token_can_use_num" + +def _get_kvcache_shm_name(): + return f"{get_unique_server_name()}_kv_cache_token_can_use_num" class MemoryManager(TokenAllocator): @@ -39,7 +41,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False # profile the max total token num if the size is None self.profile_size(mem_fraction) - super().__init__(self.size, f"{KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") + super().__init__(self.size, f"{_get_kvcache_shm_name()}_{get_current_rank_in_node()}") self._init_buffers( self.size, @@ -440,7 +442,7 @@ def __init__(self) -> None: # 兼容多机 dp size=1 纯 tp 模式的情况 self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 self.shared_tp_can_use_token_nums = [ - SharedInt(f"{KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME}_{rank_in_node}") + SharedInt(f"{_get_kvcache_shm_name()}_{rank_in_node}") for rank_in_node in range(0, self.node_world_size, self.dp_world_size) ] diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py index 9b0933f22f..58b51670c4 100644 --- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -12,7 +12,9 @@ logger = init_logger(__name__) -MAMBA_CACHE_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_mamba_cache_can_use_num" + +def _get_mamba_cache_shm_name(): + return f"{get_unique_server_name()}_mamba_cache_can_use_num" class LayerCache: @@ -38,7 +40,7 @@ def __init__( ssm_state_dtype: torch.dtype, ssm_state_shape: Tuple[int, ...], ): - super().__init__(size, f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") + super().__init__(size, f"{_get_mamba_cache_shm_name()}_{get_current_rank_in_node()}") self.conv_state_cache = LayerCache(size, conv_state_dtype, conv_state_shape, layer_num) self.ssm_state_cache = LayerCache(size, ssm_state_dtype, ssm_state_shape, layer_num) self.HOLD_BUFFER_INDEX = size @@ -119,7 +121,7 @@ def __init__(self) -> None: # 兼容多机 dp size=1 纯 tp 模式的情况 
self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 self.shared_tp_can_use_token_nums = [ - SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{rank_in_node}") + SharedInt(f"{_get_mamba_cache_shm_name()}_{rank_in_node}") for rank_in_node in range(0, self.node_world_size, self.dp_world_size) ] From 243c6a022668d9163542be2e6ee61821a84c900a Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 28 Feb 2026 10:14:10 +0000 Subject: [PATCH 088/180] fix: revert weight slicing and rmsnorm precision regressions Two critical bugs introduced by merge origin/qwen3.5_clean: 1. mm_slicer.py: Weight slicing dimension changed from shape[0]/shape[1] to shape[-2]/shape[-1], which corrupts 2D tensor weights during TP. For 2D weights [out_dim, in_dim], slicing dim -2 instead of dim 0 produces incorrect results with tensor parallelism. 2. rmsnorm.py: Precision regression in Triton kernel: - Weight loaded without float32 conversion - x_hat computed then cast to bfloat16 before multiplication - This loses precision during normalization, degrading accuracy Both issues caused intermittent garbage output and accuracy drop (0.88 -> 0.65) for qwen3.5_moe model. --- .../basemodel/triton_kernel/norm/rmsnorm.py | 41 ++++++++++--------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py index 6988cc4113..ca8f9a1c81 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py @@ -4,7 +4,7 @@ import triton.language as tl import os -rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "4")) +rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8")) @triton.jit @@ -36,12 +36,12 @@ def _rms_norm_fwd_fused( for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N - w = tl.load(W + cols, mask=mask) + w = tl.load(W + cols, mask=mask).to(tl.float32) x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) - x_hat = (x * rstd).to(tl.bfloat16) - y = x_hat * w.to(tl.bfloat16) + x_hat = x * rstd + y = x_hat * w # Write output - tl.store(Y + cols * y_stride1, y, mask=mask) + tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None): @@ -79,19 +79,22 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) return y -# def rmsnorm_forward(hidden_states, weight, eps, out=None): -# input_dtype = hidden_states.dtype -# hidden_states = hidden_states.to(torch.float32) -# variance = hidden_states.pow(2).mean(-1, keepdim=True) -# hidden_states = hidden_states * torch.rsqrt(variance + eps) -# out = weight * hidden_states.to(input_dtype) -# return out +def torch_rms_norm(x, weight, eps): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight -def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - print(f"norm weight dtype:{self.weight.dtype}") - return self.weight * hidden_states.to(input_dtype) +def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + # forward pass + y_tri = 
rmsnorm_forward(x, weight, eps) + y_ref = torch_rms_norm(x.to(torch.float32), weight.to(torch.float32), eps).to(dtype) + + # compare + print("type:", y_tri.dtype, y_ref.dtype) + print("max delta:", torch.max(torch.abs(y_tri - y_ref))) + assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) + return From 711e30c14e9f3956394417b5ebd288b75d5dd546 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 28 Feb 2026 10:16:41 +0000 Subject: [PATCH 089/180] fix --- lightllm/models/llama/model.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 20d6cad743..cc1dc28178 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -129,10 +129,9 @@ def _init_to_get_rotary(self, default_base=10000): except: pass - full_inv_freq = 1.0 / ( + inv_freq = 1.0 / ( base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) ) - inv_freq = full_inv_freq[::2] # for neo t = ( torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) / rope_scaling_factor @@ -170,10 +169,9 @@ def _init_to_get_hw_rotary(self, default_base=10000): except: pass - full_inv_freq = 1.0 / ( + inv_freq = 1.0 / ( base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) ) - inv_freq = full_inv_freq[::2] t = ( torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) / rope_scaling_factor From 7734c21181a8477b0535fd68942a590d905ebbc5 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 19 Feb 2026 14:37:20 +0000 Subject: [PATCH 090/180] feat: add Qwen3Next linear attention model support Implement comprehensive support for Qwen3Next model with linear attention mechanism: Model Features: - Implement linear attention with MTP (Multi-Token Prediction) capability - Add custom Triton kernels for gated delta networks (GDN) operations - Support chunked operations for efficient attention computation - Add specialized buffer pool and memory managers for linear attention Triton Kernels: - Add causal_conv1d for efficient convolution operations - Implement chunk-based operations (chunk_o, chunk_delta_h, chunk_scaled_dot_kkt) - Add gated delta network kernels (fused_gdn_gating, gdn_decode_mtp) - Implement fused normalization (gemma_rmsnorm, gated_rmsnorm) Infrastructure: - Add hybrid radix cache for efficient memory management - Implement mamba cache manager for state management - Add allocator utilities for buffer management - Add parameter weight abstraction for flexible weight handling - Update model registration and API endpoints Performance Optimizations: - Add H200 autotune configurations for all Triton kernels - Optimize memory allocation with custom kernels - Support chunked prefill and decode backends This implementation enables efficient inference for models with linear attention mechanisms, providing significant speedup for long sequence lengths. 
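For intuition, the recurrence the chunked GDN kernels target can be written as a
naive per-token loop. The sketch below is illustrative only and is not taken from
this patch: the function name gated_delta_rule_reference, the single-head shapes,
and the alpha/beta gating terms are assumptions chosen to mirror the description
above; the actual implementation runs blockwise in Triton.

    import torch

    def gated_delta_rule_reference(q, k, v, alpha, beta):
        # Naive per-step reference for a gated delta-rule recurrence.
        # Hypothetical single-head shapes: q, k: [T, K]; v: [T, V];
        # alpha, beta: [T] gating terms with values in (0, 1).
        T, K = k.shape
        S = torch.zeros(K, v.shape[1], dtype=torch.float32)  # recurrent state
        outs = []
        for t in range(T):
            kt, vt = k[t].float(), v[t].float()
            S = alpha[t] * S                                # gated decay of old state
            S = S + beta[t] * torch.outer(kt, vt - kt @ S)  # rank-1 delta correction
            outs.append(q[t].float() @ S)                   # read out with the query
        return torch.stack(outs)

The chunk-based kernels listed above (chunk_scaled_dot_kkt, chunk_delta_h,
chunk_o) presumably evaluate this same recurrence blockwise, which is what lets
long prefills parallelize over chunk boundaries instead of stepping token by
token.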
--- lightllm/common/allocator_utils.py | 98 ++ lightllm/common/basemodel/basemodel.py | 6 + .../transformer_layer_infer_template.py | 49 +- .../basemodel/layer_weights/hf_load_utils.py | 10 +- .../layer_weights/meta_weights/__init__.py | 1 + .../meta_weights/parameter_weight.py | 83 + .../triton_kernel/alloc_buffer_kernel.py | 80 + .../triton_kernel/mamba_buffer_copy.py | 961 ++++++++++++ .../kv_cache_mem_manager/mem_manager.py | 108 +- .../mamba_cache_mem_manager/cache_manager.py | 188 +++ lightllm/common/req_manager.py | 46 + .../{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json | 14 + .../{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json | 14 + .../{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json | 14 + .../{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json | 7 + .../{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json | 12 + ...ARLEN=true,REVERSE=false}_NVIDIA_H200.json | 38 + ...ARLEN=true,REVERSE=false}_NVIDIA_H200.json | 38 + ...=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json | 7 + ...H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json | 12 + ...6,a_dtype=torch.bfloat16}_NVIDIA_H200.json | 50 + ...8,a_dtype=torch.bfloat16}_NVIDIA_H200.json | 50 + ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 70 + ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 + ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 + ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 + ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ ...um=10,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ ...um=10,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ ...um=10,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++ .../{topk_num=10}_NVIDIA_H200.json | 50 + ...rch.bfloat16,topk_num=10}_NVIDIA_H200.json | 74 + ...rch.bfloat16,topk_num=10}_NVIDIA_H200.json | 74 + ...M=4,dtype=torch.bfloat16}_NVIDIA_H200.json | 50 + ...M=4,dtype=torch.bfloat16}_NVIDIA_H200.json | 50 + ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 74 + ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 74 + lightllm/models/__init__.py | 6 + lightllm/models/qwen3next/__init__.py | 3 + lightllm/models/qwen3next/buffer_pool.py | 83 + lightllm/models/qwen3next/infer_struct.py | 62 + .../qwen3next/layer_infer/post_layer_infer.py | 12 + .../layer_infer/shared_expert_mixin.py | 101 ++ .../layer_infer/transformer_layer_infer.py | 1067 +++++++++++++ .../layer_weights/transformer_layer_weight.py | 313 ++++ lightllm/models/qwen3next/mem_manager.py | 72 + lightllm/models/qwen3next/model.py | 157 ++ .../qwen3next/triton_kernel/causal_conv1d.py | 122 ++ .../qwen3next/triton_kernel/fla/__init__.py | 11 + .../triton_kernel/fla/ops/__init__.py | 15 + .../qwen3next/triton_kernel/fla/ops/chunk.py | 224 +++ .../triton_kernel/fla/ops/chunk_delta_h.py | 324 ++++ .../triton_kernel/fla/ops/chunk_o.py | 205 +++ .../fla/ops/chunk_scaled_dot_kkt.py | 180 +++ .../qwen3next/triton_kernel/fla/ops/cumsum.py | 306 ++++ .../triton_kernel/fla/ops/fused_recurrent.py | 492 ++++++ .../qwen3next/triton_kernel/fla/ops/index.py | 30 + .../qwen3next/triton_kernel/fla/ops/l2norm.py | 173 +++ .../qwen3next/triton_kernel/fla/ops/op.py | 65 + .../triton_kernel/fla/ops/solve_tril.py | 462 ++++++ .../qwen3next/triton_kernel/fla/ops/utils.py | 179 +++ .../triton_kernel/fla/ops/wy_fast.py | 145 ++ .../triton_kernel/fused_add_gemma_rmsnorm.py | 186 +++ .../triton_kernel/fused_gdn_gating.py | 87 ++ 
.../triton_kernel/fused_qkv_gating.py | 163 ++ .../triton_kernel/fused_split_copy.py | 400 +++++ .../qwen3next/triton_kernel/gated_rmsnorm.py | 174 +++ .../qwen3next/triton_kernel/gdn_decode_mtp.py | 1333 +++++++++++++++++ .../qwen3next/triton_kernel/gemma_rmsnorm.py | 141 ++ lightllm/models/qwen3next_mtp/__init__.py | 3 + .../qwen3next_mtp/layer_infer/__init__.py | 0 .../layer_infer/post_layer_infer.py | 16 + .../layer_infer/pre_layer_infer.py | 68 + .../layer_infer/transformer_layer_infer.py | 30 + .../qwen3next_mtp/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 47 + .../layer_weights/transformer_layer_weight.py | 141 ++ lightllm/models/qwen3next_mtp/model.py | 101 ++ lightllm/server/api_cli.py | 20 +- lightllm/server/api_openai.py | 20 +- lightllm/server/api_start.py | 3 +- lightllm/server/core/objs/start_args_type.py | 24 +- .../dynamic_prompt/hybrid_radix_cache.py | 206 +++ lightllm/server/tokenizer.py | 8 + lightllm/utils/config_utils.py | 16 + lightllm/utils/envs_utils.py | 2 +- 91 files changed, 10981 insertions(+), 124 deletions(-) create mode 100644 lightllm/common/allocator_utils.py create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py create mode 100644 lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py create mode 100644 lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py create mode 100644 lightllm/common/mamba_cache_mem_manager/cache_manager.py create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/models/qwen3next/__init__.py create mode 100644 lightllm/models/qwen3next/buffer_pool.py create mode 100644 lightllm/models/qwen3next/infer_struct.py create mode 100644 lightllm/models/qwen3next/layer_infer/post_layer_infer.py create mode 100644 lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py create mode 100644 lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/qwen3next/mem_manager.py create mode 100644 lightllm/models/qwen3next/model.py create mode 100644 lightllm/models/qwen3next/triton_kernel/causal_conv1d.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/__init__.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/index.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/op.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fused_split_copy.py create mode 100644 lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py create mode 100644 lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py create mode 100644 lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py create mode 100644 lightllm/models/qwen3next_mtp/__init__.py create mode 100644 lightllm/models/qwen3next_mtp/layer_infer/__init__.py create mode 100644 lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py create mode 100644 lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py create mode 100644 lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/qwen3next_mtp/layer_weights/__init__.py create mode 100644 lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/qwen3next_mtp/model.py create mode 100644 lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py diff --git a/lightllm/common/allocator_utils.py b/lightllm/common/allocator_utils.py new file mode 100644 
index 0000000000..803ed0a715 --- /dev/null +++ b/lightllm/common/allocator_utils.py @@ -0,0 +1,98 @@ +from typing import List, Union + +import torch + +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class TokenAllocator: + def __init__(self, size, shared_can_use_token_num_name: str): + self.size = size + + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._mem_state_return = torch.arange( + 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._return_start = 0 + self.mark_start = 0 + self.mark_end = self.size + + self.can_use_mem_size = self.size + + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + self.shared_can_use_token_num = SharedInt(shared_can_use_token_num_name) + + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.HOLD_TOKEN_MEMINDEX = self.size + + def alloc(self, need_size) -> torch.Tensor: + if need_size > self.mark_end - self.mark_start: + logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") + assert False, "error alloc state" + + start = self.mark_start + end = self.mark_start + need_size + self.mark_start += need_size + + self.can_use_mem_size -= need_size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + # 利用缓冲区返回,避免异步情况下的内存竞争 + if self._return_start + need_size > self._mem_state_return.shape[0]: + self._return_start = 0 + ans = self._mem_state_return[self._return_start : self._return_start + need_size] + ans.copy_(self.mem_state[start:end]) + self._return_start += need_size + return ans + + def free(self, free_index: Union[torch.Tensor, List[int]]): + """_summary_ + + Args: + free_index (torch.Tensor): _description_ + """ + end = self.mark_start + start = self.mark_start - len(free_index) + assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" + + if isinstance(free_index, list): + free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device) + self.mem_state[start:end] = free_index_tensor + else: + # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 + self.mem_state[start:end] = free_index + + self.mark_start -= len(free_index) + + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + if self.can_use_mem_size == len(self.mem_state): + logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") + return + + def free_all(self): + self.can_use_mem_size = len(self.mem_state) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) + self.mark_start = 0 + self.mark_end = len(self.mem_state) + + def resize_mem(self, new_size): + """ + just for test code + """ + self.size = new_size + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_start = 0 + self.mark_end = self.size + self.can_use_mem_size = self.size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + return diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 5c1d2b8712..caa90462cc 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -53,6 +53,12 @@ class TpPartBaseModel: # infer state class infer_state_class = 
InferStateInfo + @classmethod + def get_radix_cache_class(cls): + from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache + + return RadixCache + def __init__(self, kvargs): self.args = get_env_start_args() self.run_mode = kvargs["run_mode"] diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 9153349c5d..646f998642 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -62,20 +62,21 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: raise Exception("need to impl") - def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) - input1 = None + def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_wrapper_run( q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight ) - q = None o = self._get_o(o, infer_state, layer_weight) if self.tp_world_size_ > 1: all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.context_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -87,39 +88,42 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) - input1 = None + def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) if self.tp_world_size_ > 1: all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.token_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) - input1 = None if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - input1 = 
self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight) - input1 = None + def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_wrapper_run( q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight ) - q = None o = self._tpsp_get_o(o, infer_state, layer_weight) + return o + + def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -129,14 +133,17 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight) - input1 = None + def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._tpsp_get_o(o, infer_state, layer_weight) + return o + + def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.tpsp_token_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 8cf66a5ad6..304b04ab44 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -18,6 +18,14 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay weights = {k: weights.get_tensor(k) for k in weights.keys()} else: weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") + new_weight = {} + for k, v in weights.items(): + if "language_model." 
in k: + new_weight[k[len("language_model.") :]] = v + else: + new_weight[k] = v + del weights + weights = new_weight if pre_post_layer is not None: pre_post_layer.load_hf_weights(weights) @@ -60,7 +68,7 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye transformer_layer_list=transformer_layer_list, weight_dir=weight_dir, ) # noqa - worker = int(os.environ.get("LOADWORKER", 1)) + worker = int(os.environ.get("LOADWORKER", 18)) with Pool(worker) as p: iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index edf7fe21b9..fe77ca669c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -11,3 +11,4 @@ from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight from .fused_moe.fused_moe_weight import FusedMoeWeight +from .parameter_weight import ParameterWeight, TpParameterWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py new file mode 100644 index 0000000000..0afb0ecab2 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py @@ -0,0 +1,83 @@ +import torch +from typing import Dict, Optional, Tuple +from .base_weight import BaseWeightTpl + + +class ParameterWeight(BaseWeightTpl): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + weight_shape: Optional[Tuple[int, ...]], + bias_name: Optional[str] = None, + bias_shape: Optional[Tuple[int, ...]] = None, + ): + super().__init__() + self.weight_name = weight_name + self.bias_name = bias_name + self.data_type_ = data_type + self.weight_shape = weight_shape + self.bias_shape = bias_shape + self.weight: Optional[torch.Tensor] = None + self.bias: Optional[torch.Tensor] = None + if weight_shape is not None: + self._create_weight() + + def _create_weight(self): + if self.weight_shape is not None: + self.weight = torch.empty(*self.weight_shape, dtype=self.data_type_, device=self.device_id_) + self.weight.load_ok = False + if self.bias_name is not None and self.bias_shape is not None: + self.bias = torch.empty(*self.bias_shape, dtype=self.data_type_, device=self.device_id_) + self.bias.load_ok = False + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + if self.weight_name in weights: + t_weight = weights[self.weight_name] + self.weight.copy_(t_weight.to(self.data_type_)) + self.weight.load_ok = True + if self.bias_name is not None and self.bias_name in weights: + t_bias = weights[self.bias_name] + self.bias.copy_(t_bias.to(self.data_type_)) + self.bias.load_ok = True + + def verify_load(self) -> bool: + if self.weight is not None and not getattr(self.weight, "load_ok", False): + return False + if self.bias is not None and not getattr(self.bias, "load_ok", False): + return False + return True + + +class TpParameterWeight(ParameterWeight): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_shape: Optional[Tuple[int, ...]] = None, + bias_shape: Optional[Tuple[int, ...]] = None, + ): + self.split_n_embed = split_n_embed + # Calculate TP-split shapes 
if full shapes are provided + tp_weight_shape = None + tp_bias_shape = None + if weight_shape is not None: + tp_weight_shape = (split_n_embed,) + weight_shape[1:] + if bias_shape is not None: + tp_bias_shape = (split_n_embed,) + bias_shape[1:] + super().__init__(weight_name, data_type, tp_weight_shape, bias_name, tp_bias_shape) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + start = self.split_n_embed * self.tp_rank_ + end = self.split_n_embed * (self.tp_rank_ + 1) + + if self.weight_name in weights: + t_weight = weights[self.weight_name][start:end] + self.weight.copy_(t_weight.to(self.data_type_)) + self.weight.load_ok = True + if self.bias_name is not None and self.bias_name in weights: + t_bias = weights[self.bias_name][start:end] + self.bias.copy_(t_bias.to(self.data_type_)) + self.bias.load_ok = True diff --git a/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py b/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py new file mode 100644 index 0000000000..b6444449b1 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py @@ -0,0 +1,80 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def alloc_buffer_for_req_kernel( + req_index_ptr, # [num_reqs] - indices of requests to allocate buffers for + buffer_indexes_ptr, # [num_reqs * num_buffers_per_req] - buffer indices to assign (from CPU) + req_to_buffer_index_ptr, # [max_request_num + 1, num_buffers_per_req] - tensor mapping req_idx to buffer_idx + num_reqs, # number of requests to process + stride_buffer, # stride for req_to_buffer_index second dimension + NUM_BUFFERS_PER_REQ: tl.constexpr, # number of buffers per request (mtp_step + 1) + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Mask for valid indices + mask = offsets < num_reqs + + # Load request indices + req_indices = tl.load(req_index_ptr + offsets, mask=mask, other=0) + + # For each request, allocate NUM_BUFFERS_PER_REQ buffers + for buf_idx in tl.static_range(NUM_BUFFERS_PER_REQ): + # Load buffer index for this position + buffer_offset = offsets * NUM_BUFFERS_PER_REQ + buf_idx + buffer_indices = tl.load(buffer_indexes_ptr + buffer_offset, mask=mask, other=0) + + # Update req_to_buffer_index[req_indices, buf_idx] = buffer_indices + output_offset = req_indices * stride_buffer + buf_idx + tl.store(req_to_buffer_index_ptr + output_offset, buffer_indices, mask=mask) + + +def alloc_buffer_for_req_triton( + req_index: torch.Tensor, # [num_reqs] int32/int64 tensor on CUDA + buffer_indexes: torch.Tensor, # [num_reqs * (mtp_step + 1)] int32 tensor (can be CPU or CUDA) + req_to_buffer_index: torch.Tensor, # [max_request_num + 1, mtp_step + 1] int32 tensor on CUDA + mtp_step: int = 0, # number of additional buffers per request (default 0 for non-MTP mode) +): + num_reqs = req_index.shape[0] + num_buffers_per_req = mtp_step + 1 + + # Ensure inputs are on CUDA + if not req_index.is_cuda: + req_index = req_index.cuda() + if not buffer_indexes.is_cuda: + buffer_indexes = buffer_indexes.cuda() + + # Ensure correct dtypes + if req_index.dtype not in [torch.int32, torch.int64]: + req_index = req_index.to(torch.int32) + if buffer_indexes.dtype != torch.int32: + buffer_indexes = buffer_indexes.to(torch.int32) + + # Validate buffer_indexes size + expected_size = num_reqs * num_buffers_per_req + assert buffer_indexes.shape[0] == expected_size, ( + f"Expected {expected_size} buffer indices for 
{num_reqs} requests " + f"with mtp_step={mtp_step}, but got {buffer_indexes.shape[0]}" + ) + + # Get stride for the second dimension of req_to_buffer_index + stride_buffer = req_to_buffer_index.stride(0) + + # Launch kernel + BLOCK_SIZE = 256 + grid = (triton.cdiv(num_reqs, BLOCK_SIZE),) + + alloc_buffer_for_req_kernel[grid]( + req_index, + buffer_indexes, + req_to_buffer_index, + num_reqs, + stride_buffer, + NUM_BUFFERS_PER_REQ=num_buffers_per_req, + BLOCK_SIZE=BLOCK_SIZE, + ) diff --git a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py new file mode 100644 index 0000000000..b4a91f7861 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py @@ -0,0 +1,961 @@ +""" +Optimized Mamba Buffer Copy Kernels with Autotune Support + +This module provides auto-tuned Triton kernels for efficient buffer copying operations +in Mamba-style models, including support for MTP (Multi-Token Prediction) buffer broadcasting. +""" + +import torch +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _copy_buffer_p2p_1d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + pair_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d, + d_size, + BLOCK_D: tl.constexpr, +): + """ + Optimized kernel for 1D buffer copy. + + Grid: (num_pairs, layer_num, num_blocks_d) + Each program copies one block of dimension d for one (pair, layer) combination. + """ + pair_idx = tl.program_id(0) + pair_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_d_idx = tl.program_id(2) + + # Load source and destination indices for this pair + src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) + dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) + + # Calculate offsets for this block + d_start = block_d_idx * BLOCK_D + d_offsets = d_start + tl.arange(0, BLOCK_D) + + # Create mask for valid indices + mask = d_offsets < d_size + + # Calculate source and destination pointers for this layer and pair + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + + src_ptr = base_src + d_offsets * stride_d + dst_ptr = base_dst + d_offsets * stride_d + + # Load and store + data = tl.load(src_ptr, mask=mask, other=0.0) + tl.store(dst_ptr, data, mask=mask) + + +@triton.jit +def _copy_buffer_p2p_2d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + pair_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + d1_size, + d2_size, + num_blocks_d2, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, +): + """ + Kernel to copy 2D buffer from source indices to destination indices. + + Grid: (num_pairs, layer_num, num_blocks_d1 * num_blocks_d2) + Each program copies one 2D block for one (pair, layer) combination. 
+ """ + pair_idx = tl.program_id(0) + pair_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx into d1 and d2 block indices + block_d1_idx = block_idx // num_blocks_d2 + block_d2_idx = block_idx % num_blocks_d2 + + # Load source and destination indices + src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) + dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) + + # Calculate offsets for this block + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + + # Create mask for valid indices + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + mask = d1_mask[:, None] & d2_mask[None, :] + + # Calculate base pointers + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + + # Calculate full offsets + offsets = d1_offsets[:, None] * stride_d1 + d2_offsets[None, :] * stride_d2 + + # Load and store + data = tl.load(base_src + offsets, mask=mask, other=0.0) + tl.store(base_dst + offsets, data, mask=mask) + + +@triton.jit +def _copy_buffer_broadcast_1d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + copy_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d, + d_size, + num_dst_per_src, + BLOCK_D: tl.constexpr, +): + """ + Broadcast kernel for 1D buffer copy (one source to multiple destinations). + + Grid: (num_src, layer_num, num_blocks_d) + """ + src_idx_in_batch = tl.program_id(0) + copy_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_d_idx = tl.program_id(2) + + # Load source index + src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) + + # Calculate offsets for this block + d_start = block_d_idx * BLOCK_D + d_offsets = d_start + tl.arange(0, BLOCK_D) + mask = d_offsets < d_size + + # Calculate source pointer + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + src_ptr = base_src + d_offsets * stride_d + + # Load data once + data = tl.load(src_ptr, mask=mask, other=0.0) + + # Broadcast to all destinations for this source + for dst_offset in range(num_dst_per_src): + dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset + dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) + + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + dst_ptr = base_dst + d_offsets * stride_d + + tl.store(dst_ptr, data, mask=mask) + + +@triton.jit +def _copy_buffer_broadcast_2d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + copy_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + d1_size, + d2_size, + num_blocks_d2, + num_dst_per_src, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, +): + """ + Broadcast kernel for 2D buffer copy (one source to multiple destinations). 
+ + Grid: (num_src, layer_num, num_blocks_d1 * num_blocks_d2) + """ + src_idx_in_batch = tl.program_id(0) + copy_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx + block_d1_idx = block_idx // num_blocks_d2 + block_d2_idx = block_idx % num_blocks_d2 + + # Load source index + src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) + + # Calculate offsets + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + mask = d1_mask[:, None] & d2_mask[None, :] + + # Calculate source pointer and load data once + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + offsets = d1_offsets[:, None] * stride_d1 + d2_offsets[None, :] * stride_d2 + data = tl.load(base_src + offsets, mask=mask, other=0.0) + + # Broadcast to all destinations + for dst_offset in range(num_dst_per_src): + dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset + dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) + + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + tl.store(base_dst + offsets, data, mask=mask) + + +@triton.jit +def _copy_buffer_p2p_3d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + pair_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + stride_d3, + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, + BLOCK_D3: tl.constexpr, +): + """ + Optimized kernel for 3D data buffer copy (5D tensor: layer, buffer, d1, d2, d3). + + Grid: (num_pairs, layer_num, num_blocks_d1 * num_blocks_d2 * num_blocks_d3) + Each program copies one 3D block for one (pair, layer) combination. 
+ """ + pair_idx = tl.program_id(0) + pair_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx into d1, d2, d3 block indices + block_d1_idx = block_idx // (num_blocks_d2 * num_blocks_d3) + temp = block_idx % (num_blocks_d2 * num_blocks_d3) + block_d2_idx = temp // num_blocks_d3 + block_d3_idx = temp % num_blocks_d3 + + # Load source and destination indices for this pair + src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) + dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) + + # Calculate offsets for this block + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + d3_start = block_d3_idx * BLOCK_D3 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + d3_offsets = d3_start + tl.arange(0, BLOCK_D3) + + # Create masks for valid indices + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + d3_mask = d3_offsets < d3_size + + # 3D mask: [BLOCK_D1, BLOCK_D2, BLOCK_D3] + mask = d1_mask[:, None, None] & d2_mask[None, :, None] & d3_mask[None, None, :] + + # Calculate base pointers + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + + # Calculate full 3D offsets + offsets = ( + d1_offsets[:, None, None] * stride_d1 + + d2_offsets[None, :, None] * stride_d2 + + d3_offsets[None, None, :] * stride_d3 + ) + + # Load and store + data = tl.load(base_src + offsets, mask=mask, other=0.0) + tl.store(base_dst + offsets, data, mask=mask) + + +@triton.jit +def _copy_buffer_broadcast_3d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + copy_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + stride_d3, + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + num_dst_per_src, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, + BLOCK_D3: tl.constexpr, +): + """ + Broadcast kernel for 3D data buffer copy (5D tensor: layer, buffer, d1, d2, d3). + + Grid: (num_src, layer_num, num_blocks_d1 * num_blocks_d2 * num_blocks_d3) + Each program loads once from source and broadcasts to all destinations. 
+ """ + src_idx_in_batch = tl.program_id(0) + copy_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx into d1, d2, d3 block indices + block_d1_idx = block_idx // (num_blocks_d2 * num_blocks_d3) + temp = block_idx % (num_blocks_d2 * num_blocks_d3) + block_d2_idx = temp // num_blocks_d3 + block_d3_idx = temp % num_blocks_d3 + + # Load source index + src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) + + # Calculate offsets for this block + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + d3_start = block_d3_idx * BLOCK_D3 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + d3_offsets = d3_start + tl.arange(0, BLOCK_D3) + + # Create masks + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + d3_mask = d3_offsets < d3_size + + mask = d1_mask[:, None, None] & d2_mask[None, :, None] & d3_mask[None, None, :] + + # Calculate source pointer and load data once + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + + offsets = ( + d1_offsets[:, None, None] * stride_d1 + + d2_offsets[None, :, None] * stride_d2 + + d3_offsets[None, None, :] * stride_d3 + ) + + data = tl.load(base_src + offsets, mask=mask, other=0.0) + + # Broadcast to all destinations for this source + for dst_offset in range(num_dst_per_src): + dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset + dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) + + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + tl.store(base_dst + offsets, data, mask=mask) + + +# ==================== Config Generation Functions ==================== + + +def _get_buffer_copy_1d_configs(): + """Generate candidate configurations for 1D buffer copy.""" + configs = [] + for block_d in [32, 64, 128, 256, 512, 1024]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3, 4]: + configs.append( + { + "BLOCK_D": block_d, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_buffer_copy_2d_configs(): + """Generate candidate configurations for 2D buffer copy.""" + configs = [] + for block_d1 in [16, 32, 64, 128]: + for block_d2 in [16, 32, 64, 128, 256]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3, 4]: + configs.append( + { + "BLOCK_D1": block_d1, + "BLOCK_D2": block_d2, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_buffer_copy_3d_configs(): + """Generate candidate configurations for 3D buffer copy (5D tensor).""" + configs = [] + for block_d1 in [8, 16, 32]: + for block_d2 in [8, 16, 32, 64]: + for block_d3 in [8, 16, 32, 64, 128]: + for num_warps in [4, 8]: + for num_stages in [2, 3]: + # Skip configs that are too large for shared memory + if block_d1 * block_d2 * block_d3 > 32768: + continue + configs.append( + { + "BLOCK_D1": block_d1, + "BLOCK_D2": block_d2, + "BLOCK_D3": block_d3, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +# ==================== Static and Run Key Functions ==================== + + +def _get_buffer_copy_static_key(src_buffer: torch.Tensor): + """Static key based on buffer shape and dtype.""" + shape = src_buffer.shape + return { + "ndim": len(shape), + "layer_num": shape[0], + "d_sizes": str(shape[2:]), # Dimension sizes + "dtype": str(src_buffer.dtype), + } + + +def _get_buffer_copy_run_key(src_indexes: torch.Tensor): + """Run key based on number of copy 
pairs.""" + return src_indexes.shape[0] + + +# ==================== Auto-tuned Buffer Copy Functions ==================== + + +@autotune( + kernel_name="mamba_buffer_copy_p2p_1d:v1", + configs_gen_func=_get_buffer_copy_1d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_p2p_1d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 1D buffer copy.""" + num_pairs = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d_size = src_buffer.shape[2] + + if run_config is None: + # Default config if autotune is disabled + BLOCK_D = triton.next_power_of_2(min(d_size, 256)) + num_warps = 4 if BLOCK_D > 256 else 2 + num_stages = 2 + else: + BLOCK_D = run_config["BLOCK_D"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d = triton.cdiv(d_size, BLOCK_D) + + MAX_GRID_SIZE = 65535 + + for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): + pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) + pair_chunk_size = pair_chunk_end - pair_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (pair_chunk_size, layer_chunk_size, num_blocks_d) + + _copy_buffer_p2p_1d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + pair_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + d_size, + BLOCK_D=BLOCK_D, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_copy_p2p_2d:v1", + configs_gen_func=_get_buffer_copy_2d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_p2p_2d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 2D buffer copy.""" + num_pairs = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + + if run_config is None: + # Default config if autotune is disabled + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 64)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 128)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + BLOCK_D2 = run_config["BLOCK_D2"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_total = num_blocks_d1 * num_blocks_d2 + + MAX_GRID_SIZE = 65535 + + for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): + pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) + pair_chunk_size = pair_chunk_end - pair_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (pair_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_p2p_2d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + pair_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + 
src_buffer.stride(3), + d1_size, + d2_size, + num_blocks_d2, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_broadcast_1d:v1", + configs_gen_func=_get_buffer_copy_1d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_broadcast_1d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 1D buffer broadcast (one src to multiple dst).""" + num_src = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d_size = src_buffer.shape[2] + num_dst_per_src = dst_indexes.shape[0] // num_src + + if run_config is None: + BLOCK_D = triton.next_power_of_2(min(d_size, 256)) + num_warps = 4 if BLOCK_D > 256 else 2 + num_stages = 2 + else: + BLOCK_D = run_config["BLOCK_D"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d = triton.cdiv(d_size, BLOCK_D) + + MAX_GRID_SIZE = 65535 + + for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): + src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) + src_chunk_size = src_chunk_end - src_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (src_chunk_size, layer_chunk_size, num_blocks_d) + + _copy_buffer_broadcast_1d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + src_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + d_size, + num_dst_per_src, + BLOCK_D=BLOCK_D, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_broadcast_2d:v1", + configs_gen_func=_get_buffer_copy_2d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_broadcast_2d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 2D buffer broadcast (one src to multiple dst).""" + num_src = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + num_dst_per_src = dst_indexes.shape[0] // num_src + + if run_config is None: + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 64)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 128)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + BLOCK_D2 = run_config["BLOCK_D2"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_total = num_blocks_d1 * num_blocks_d2 + + MAX_GRID_SIZE = 65535 + + for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): + src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) + src_chunk_size = src_chunk_end - src_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (src_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_broadcast_2d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, 
+ dst_indexes, + src_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + src_buffer.stride(3), + d1_size, + d2_size, + num_blocks_d2, + num_dst_per_src, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_copy_p2p_3d:v1", + configs_gen_func=_get_buffer_copy_3d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_p2p_3d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 3D data buffer copy (5D tensor).""" + num_pairs = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + d3_size = src_buffer.shape[4] + + if run_config is None: + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 16)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 32)) + BLOCK_D3 = triton.next_power_of_2(min(d3_size, 64)) + num_warps = 4 if BLOCK_D1 * BLOCK_D2 * BLOCK_D3 > 4096 else 8 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + BLOCK_D2 = run_config["BLOCK_D2"] + BLOCK_D3 = run_config["BLOCK_D3"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_d3 = triton.cdiv(d3_size, BLOCK_D3) + num_blocks_total = num_blocks_d1 * num_blocks_d2 * num_blocks_d3 + + MAX_GRID_SIZE = 65535 + + for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): + pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) + pair_chunk_size = pair_chunk_end - pair_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (pair_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_p2p_3d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + pair_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + src_buffer.stride(3), + src_buffer.stride(4), + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + BLOCK_D3=BLOCK_D3, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_broadcast_3d:v1", + configs_gen_func=_get_buffer_copy_3d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_broadcast_3d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 3D data buffer broadcast (5D tensor, one src to multiple dst).""" + num_src = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + d3_size = src_buffer.shape[4] + num_dst_per_src = dst_indexes.shape[0] // num_src + + if run_config is None: + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 16)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 32)) + BLOCK_D3 = triton.next_power_of_2(min(d3_size, 64)) + num_warps = 4 if BLOCK_D1 * BLOCK_D2 * BLOCK_D3 > 4096 else 8 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + 
BLOCK_D2 = run_config["BLOCK_D2"] + BLOCK_D3 = run_config["BLOCK_D3"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_d3 = triton.cdiv(d3_size, BLOCK_D3) + num_blocks_total = num_blocks_d1 * num_blocks_d2 * num_blocks_d3 + + MAX_GRID_SIZE = 65535 + + for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): + src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) + src_chunk_size = src_chunk_end - src_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (src_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_broadcast_3d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + src_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + src_buffer.stride(3), + src_buffer.stride(4), + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + num_dst_per_src, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + BLOCK_D3=BLOCK_D3, + num_warps=num_warps, + num_stages=num_stages, + ) + + +# ==================== Unified Interface ==================== + + +def copy_buffer_p2p( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, +): + """ + Copy buffers from source indices to destination indices with auto-tuning. + + Supports 3D (conv states), 4D (standard buffers), and 5D (SSM states) buffers. + + Args: + src_buffer: Source buffer tensor [layer_num, buffer_size, ...] + dst_buffer: Destination buffer tensor [layer_num, buffer_size, ...] + src_indexes: Source buffer indices [num_pairs] + dst_indexes: Destination buffer indices [num_pairs] + """ + assert src_buffer.shape == dst_buffer.shape + assert src_indexes.shape == dst_indexes.shape + assert len(src_indexes.shape) == 1 + + if len(src_buffer.shape) == 3: + # 1D case: (layer_num, buffer_size, d) + _copy_buffer_p2p_1d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) + + elif len(src_buffer.shape) == 4: + # 2D case: (layer_num, buffer_size, d1, d2) + _copy_buffer_p2p_2d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) + + elif len(src_buffer.shape) == 5: + # 5D case: (layer_num, buffer_size, d1, d2, d3) - Use Triton kernel for zero extra memory + _copy_buffer_p2p_3d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) + + else: + raise ValueError(f"Unsupported buffer shape: {src_buffer.shape}") + + +def copy_buffer_broadcast( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, +): + """ + Broadcast buffers from source indices to multiple destination indices (MTP use case). + + Each source buffer is copied to multiple destination buffers. + + Args: + src_buffer: Source buffer tensor [layer_num, buffer_size, ...] + dst_buffer: Destination buffer tensor [layer_num, buffer_size, ...] 
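+            (must match src_buffer's shape; this is asserted at the top of the function)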
+ src_indexes: Source buffer indices [num_src] + dst_indexes: Destination buffer indices [num_src, num_dst_per_src] (2D tensor) + """ + assert src_buffer.shape == dst_buffer.shape + assert len(src_indexes.shape) == 1 + assert len(dst_indexes.shape) == 2, f"dst_indexes must be 2D, got shape {dst_indexes.shape}" + + num_src = src_indexes.shape[0] + + assert num_src == dst_indexes.shape[0], f"Mismatch: src_indexes {num_src} vs dst_indexes {dst_indexes.shape[0]}" + + # Flatten dst_indexes for kernel + dst_indexes_flat = dst_indexes.reshape(-1).contiguous() + + if len(src_buffer.shape) == 3: + # 1D case + _copy_buffer_broadcast_1d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) + + elif len(src_buffer.shape) == 4: + # 2D case + _copy_buffer_broadcast_2d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) + + elif len(src_buffer.shape) == 5: + # 5D case: (layer_num, buffer_size, d1, d2, d3) - Use Triton kernel for zero extra memory + _copy_buffer_broadcast_3d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) + + else: + raise ValueError(f"Unsupported buffer shape: {src_buffer.shape}") diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 1203cbdec7..8d6fb48c28 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -18,14 +18,17 @@ from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.shm_utils import create_or_link_shm +from lightllm.common.allocator_utils import TokenAllocator from multiprocessing.reduction import ForkingPickler from filelock import FileLock logger = init_logger(__name__) +KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_kv_cache_token_can_use_num" -class MemoryManager: + +class MemoryManager(TokenAllocator): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): self.size = size self.head_num = head_num @@ -36,27 +39,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False # profile the max total token num if the size is None self.profile_size(mem_fraction) - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._mem_state_return = torch.arange( - 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._return_start = 0 - self.mark_start = 0 - self.mark_end = self.size - - self.can_use_mem_size = self.size + super().__init__(self.size, f"{KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from lightllm.utils.envs_utils import get_unique_server_name - - rank_in_node = get_current_rank_in_node() - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) - - self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, @@ -64,7 +48,6 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False head_dim, layer_num, ) - self.HOLD_TOKEN_MEMINDEX = self.size def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): """ @@ -341,59 +324,13 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to def _free_buffers(self): self.kv_buffer = 
None - def alloc(self, need_size) -> torch.Tensor: - if need_size > self.mark_end - self.mark_start: - logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") - assert False, "error alloc state" - - start = self.mark_start - end = self.mark_start + need_size - self.mark_start += need_size - - self.can_use_mem_size -= need_size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - # 利用缓冲区返回,避免异步情况下的内存竞争 - if self._return_start + need_size > self._mem_state_return.shape[0]: - self._return_start = 0 - ans = self._mem_state_return[self._return_start : self._return_start + need_size] - ans.copy_(self.mem_state[start:end]) - self._return_start += need_size - return ans - - def free(self, free_index: Union[torch.Tensor, List[int]]): - """_summary_ - - Args: - free_index (torch.Tensor): _description_ - """ - - end = self.mark_start - start = self.mark_start - len(free_index) - assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" - - if isinstance(free_index, list): - self.mem_state.numpy()[start:end] = free_index - else: - # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 - self.mem_state[start:end] = free_index - - self.mark_start -= len(free_index) - - self.can_use_mem_size += len(free_index) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - if self.can_use_mem_size == len(self.mem_state): - logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") - return + def get_index_kv_buffer(self, index): + return {"kv_buffer": self.kv_buffer[:, index]} - def free_all(self): - self.can_use_mem_size = len(self.mem_state) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) - self.mark_start = 0 - self.mark_end = len(self.mem_state) + def load_index_kv_buffer(self, index, load_tensor_dict): + self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) + # 重写resize_mem方法,添加_free_buffers和_init_buffers调用 def resize_mem(self, new_size): """ just for test code @@ -404,24 +341,13 @@ def resize_mem(self, new_size): head_dim = self.head_dim layer_num = self.layer_num - self.size = new_size - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self.mark_start = 0 - self.mark_end = self.size - self.can_use_mem_size = self.size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) + # 调用父类的resize_mem + super().resize_mem(new_size) + self._free_buffers() self._init_buffers(size, dtype, head_num, head_dim, layer_num) return - def get_index_kv_buffer(self, index): - return {"kv_buffer": self.kv_buffer[:, index]} - - def load_index_kv_buffer(self, index, load_tensor_dict): - self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) - def copy_kv_from_other_dp_ranks( self, mem_managers: List["MemoryManager"], @@ -513,12 +439,12 @@ def __init__(self) -> None: self.dp_world_size = self.global_world_size // args.dp # 兼容多机 dp size=1 纯 tp 模式的情况 self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 - self.shared_tp_infos = [ - SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}") + self.shared_tp_can_use_token_nums = [ + SharedInt(f"{KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME}_{rank_in_node}") for rank_in_node in range(0, self.node_world_size, self.dp_world_size) ] def get_unrefed_token_num(self, dp_rank_in_node: int): if self.is_multinode_tp: - return self.shared_tp_infos[0].get_value() - return self.shared_tp_infos[dp_rank_in_node].get_value() + 
return self.shared_tp_can_use_token_nums[0].get_value()
+        return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value()
diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py
new file mode 100644
index 0000000000..348b14192c
--- /dev/null
+++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py
@@ -0,0 +1,188 @@
+from typing import List, Tuple, Union
+
+import torch
+import numpy as np
+
+from lightllm.utils.dist_utils import get_current_rank_in_node
+from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
+from lightllm.common.allocator_utils import TokenAllocator
+from lightllm.utils.log_utils import init_logger
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+
+logger = init_logger(__name__)
+
+MAMBA_CACHE_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_mamba_cache_can_use_num"
+
+
+class LayerCache:
+    def __init__(self, size: int, dtype: torch.dtype, shape: Tuple[int, ...], layer_num: int):
+        self.size = size
+        self.dtype = dtype
+        self.shape = shape
+        self.layer_num = layer_num
+
+        self.buffer = torch.zeros((self.layer_num, size + 1, *shape), dtype=dtype, device="cuda")
+
+    def get_cell_size(self):
+        return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype)
+
+
+class MambaCacheManager(TokenAllocator):
+    def __init__(
+        self,
+        size: int,
+        layer_num: int,
+        conv_state_dtype: torch.dtype,
+        conv_state_shape: Tuple[int, ...],
+        ssm_state_dtype: torch.dtype,
+        ssm_state_shape: Tuple[int, ...],
+    ):
+        super().__init__(size, f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}")
+        self.conv_state_cache = LayerCache(size, conv_state_dtype, conv_state_shape, layer_num)
+        self.ssm_state_cache = LayerCache(size, ssm_state_dtype, ssm_state_shape, layer_num)
+        self.HOLD_BUFFER_INDEX = size
+
+        logger.warning(
+            f"Linear attention state cache size: {size}\n"
+            f"Conv state uses "
+            f"{self.conv_state_cache.get_cell_size() * size / 1024 ** 3} GB of memory.\n"
+            f"SSM state uses "
+            f"{self.ssm_state_cache.get_cell_size() * size / 1024 ** 3} GB of memory.\n"
+        )
+
+    def get_mamba_cache(self, layer_idx: int):
+        conv_state = self.conv_state_cache.buffer[layer_idx]
+        ssm_state = self.ssm_state_cache.buffer[layer_idx]
+        return conv_state, ssm_state
+
+    def copy_buffer_p2p(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor):
+        """
+        Copy state buffers from source indices to destination indices.
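+        Both conv states and SSM states are copied, for all layers at once.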
+
+        Args:
+            src_buffer_indexes: Source buffer indices (1D tensor)
+            dst_buffer_indexes: Destination buffer indices (1D tensor)
+        """
+        assert src_buffer_indexes.dim() == 1
+        assert dst_buffer_indexes.dim() == 1
+        assert src_buffer_indexes.shape[0] == dst_buffer_indexes.shape[0]
+
+        # Nothing to do for empty index tensors
+        if src_buffer_indexes.numel() == 0:
+            return
+
+        # Validate indices are within the valid range [0, size] (the buffer dim is size + 1)
+        max_valid_idx = self.size  # HOLD_BUFFER_INDEX = size is valid
+        src_max = src_buffer_indexes.max().item()
+        src_min = src_buffer_indexes.min().item()
+        dst_max = dst_buffer_indexes.max().item()
+        dst_min = dst_buffer_indexes.min().item()
+
+        if src_min < 0 or src_max > max_valid_idx or dst_min < 0 or dst_max > max_valid_idx:
+            logger.error(
+                f"Invalid buffer indices: src=[{src_min}, {src_max}], dst=[{dst_min}, {dst_max}], "
+                f"valid range=[0, {max_valid_idx}], conv shape={self.conv_state_cache.buffer.shape}, "
+                f"ssm shape={self.ssm_state_cache.buffer.shape}"
+            )
+            raise ValueError("Invalid buffer indices for copy_buffer_p2p")
+
+        # Use PyTorch advanced indexing for the buffer copy (safer than Triton for complex shapes).
+        # The buffer shape is [layer_num, buffer_size, *shape];
+        # all layers are copied for the given buffer indices.
+        src_idx = src_buffer_indexes.long()
+        dst_idx = dst_buffer_indexes.long()
+
+        # Copy conv_state: [layer_num, buffer_size, d1, d2]
+        self.conv_state_cache.buffer[:, dst_idx, ...] = self.conv_state_cache.buffer[:, src_idx, ...]
+
+        # Copy ssm_state: [layer_num, buffer_size, d1, d2, d3]
+        self.ssm_state_cache.buffer[:, dst_idx, ...] = self.ssm_state_cache.buffer[:, src_idx, ...]
+        return
+
+    def copy_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor):
+        assert src_buffer_index.dim() == 1
+        assert dst_buffer_indexes.dim() == 2
+        assert src_buffer_index.shape[0] == dst_buffer_indexes.shape[0]
+
+        # Use PyTorch advanced indexing for the broadcast copy
+        # src_buffer_index: [num_src]
+        # dst_buffer_indexes: [num_src, num_dst_per_src]
+        src_idx = src_buffer_index.long()
+        dst_idx = dst_buffer_indexes.long()
+
+        # Broadcast each source to all its destinations:
+        # for each (src, dst_group), copy buffer[src] to buffer[dst1], buffer[dst2], ...
+        num_src, num_dst_per_src = dst_idx.shape
+        for i in range(num_src):
+            src = src_idx[i : i + 1]  # Keep as 1D tensor with 1 element
+            dsts = dst_idx[i, :]  # 1D tensor with num_dst_per_src elements
+            # Copy conv_state
+            self.conv_state_cache.buffer[:, dsts, ...] = self.conv_state_cache.buffer[:, src, ...]
+            # Copy ssm_state
+            self.ssm_state_cache.buffer[:, dsts, ...] = self.ssm_state_cache.buffer[:, src, ...]
+        return
+
+    def copy_ssm_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor):
+        """
+        Broadcast ONLY SSM states (not conv states) from source indices to destination indices.
+
+        This is used for MTP mode, where each buffer maintains its own independent conv state
+        but SSM states need to be synchronized.
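+
+        Args:
+            src_buffer_index: Source buffer indices [num_src] (1D tensor)
+            dst_buffer_indexes: Destination buffer indices [num_src, num_dst_per_src] (2D tensor)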
+        """
+        assert src_buffer_index.dim() == 1
+        assert dst_buffer_indexes.dim() == 2
+        assert src_buffer_index.shape[0] == dst_buffer_indexes.shape[0]
+
+        # Use PyTorch advanced indexing for the SSM-only broadcast copy
+        src_idx = src_buffer_index.long()
+        dst_idx = dst_buffer_indexes.long()
+
+        # Broadcast each source to all its destinations (SSM only)
+        num_src = dst_idx.shape[0]
+        for i in range(num_src):
+            src = src_idx[i : i + 1]
+            dsts = dst_idx[i, :]
+            # Only copy ssm_state, NOT conv_state
+            self.ssm_state_cache.buffer[:, dsts, ...] = self.ssm_state_cache.buffer[:, src, ...]
+        return
+
+    def free(self, free_index: Union[torch.Tensor, List[int]]):
+        """
+        Free the allocated cache buffers and clear them.
+
+        Args:
+            free_index: Buffer indices to free (tensor or list of ints)
+        """
+        # Convert to a tensor if needed for indexing
+        if isinstance(free_index, list):
+            free_index_tensor = torch.tensor(free_index, dtype=torch.long, device="cuda")
+        else:
+            free_index_tensor = free_index.to(device="cuda", dtype=torch.long)
+
+        # Clear the buffers for the freed indices
+        # Shape: [layer_num, buffer_index, *shape]
+        self.conv_state_cache.buffer[:, free_index_tensor, ...] = 0
+        self.ssm_state_cache.buffer[:, free_index_tensor, ...] = 0
+
+        # Call the parent's free method to update the allocator state
+        super().free(free_index)
+        return
+
+
+class ReadOnlyStaticsMambaCacheManager:
+    """
+    Read-only access to cache usage statistics.
+    """
+
+    def __init__(self) -> None:
+        args = get_env_start_args()
+        self.global_world_size = args.tp
+        self.node_world_size = args.tp // args.nnodes
+        self.dp_world_size = self.global_world_size // args.dp
+        # Compatibility with the multi-node, dp size = 1, pure TP mode
+        self.is_multinode_tp = args.dp == 1 and args.nnodes > 1
+        self.shared_tp_can_use_token_nums = [
+            SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{rank_in_node}")
+            for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
+        ]
+
+    def get_unrefed_token_num(self, dp_rank_in_node: int):
+        if self.is_multinode_tp:
+            return self.shared_tp_can_use_token_nums[0].get_value()
+        return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value()
diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py
index 33bdca4475..573fe50842 100644
--- a/lightllm/common/req_manager.py
+++ b/lightllm/common/req_manager.py
@@ -1,8 +1,10 @@
 import torch
 import collections
+from lightllm.common.basemodel.triton_kernel.alloc_buffer_kernel import alloc_buffer_for_req_triton
 from lightllm.utils.log_utils import init_logger
 from .kv_cache_mem_manager import MemoryManager
 from typing import List, Optional
+
 from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter
 from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
 from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
@@ -93,6 +95,18 @@ def free_all(self):
         self.req_list = _ReqLinkedList(self.max_request_num)
         return
 
+    def alloc_buffer_for_req(self, req_index: torch.Tensor):
+        """Allocate buffers for requests. No-op for standard models without linear attention."""
+        pass
+
+    def free_buffer(self, free_buffer_indexes):
+        """Free buffer memory. No-op for standard models without linear attention."""
+        pass
+
+    def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor):
+        """Copy buffer state between requests.
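+        Overridden by ReqManagerForMamba below to broadcast linear-attention state buffers.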
+        No-op for standard models without linear attention."""
+        pass
+
 
 class ReqSamplingParamsManager:
     """
@@ -232,3 +246,35 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List):
             p_token_counts_tensor.cuda(non_blocking=True),
             p_cumsum_seq_len_tensor.cuda(non_blocking=True),
         )
+
+
+class ReqManagerForMamba(ReqManager):
+    def __init__(self, max_request_num, max_sequence_length, mem_manager):
+        from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager
+
+        super().__init__(max_request_num, max_sequence_length, mem_manager)
+        self.mtp_step = get_env_start_args().mtp_step
+        self.buffer_mem_manager: MambaCacheManager = self.mem_manager.mamba_cache_mem_manager
+        self.req_to_buffer_index = torch.zeros(
+            (self.max_request_num + 1, self.mtp_step + 1), dtype=torch.int32, device="cuda"
+        )
+        self.req_to_buffer_index[self.HOLD_REQUEST_ID, :] = self.buffer_mem_manager.HOLD_BUFFER_INDEX
+
+    def free_buffer(self, free_buffer_indexes: List[int]):
+        self.buffer_mem_manager.free(free_buffer_indexes)
+        return
+
+    def alloc_buffer_for_req(self, req_index: torch.Tensor):
+        num_reqs = req_index.shape[0]
+        num_buffers_per_req = self.mtp_step + 1
+        buffer_indexes = self.buffer_mem_manager.alloc(num_reqs * num_buffers_per_req)
+        alloc_buffer_for_req_triton(req_index, buffer_indexes, self.req_to_buffer_index, self.mtp_step)
+
+    def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor):
+        # Gather all MTP buffers of the target requests (buffer[0] through buffer[mtp_step])
+        mtp_range = torch.arange(0, self.mtp_step + 1, dtype=torch.int32, device="cuda")
+        all_mtp_buffers = self.req_to_buffer_index[tgt_req_index[:, None], mtp_range[None, :]]
+
+        # Broadcast the shared buffer to every MTP step
+        self.buffer_mem_manager.copy_buffer_broadcast(src_buffer_index, all_mtp_buffers)
+        return
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..5d9216c2ea
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,8 @@
+{
+    "8": {
+        "BK": 128,
+        "BV": 128,
+        "num_stages": 2,
+        "num_warps": 4
+    }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..338af08a1d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,14 @@
+{
+    "2": {
+        "BK": 128,
+        "BV": 64,
+        "num_stages": 3,
+        "num_warps": 2
+    },
+    "4": {
+        "BK": 128,
+        "BV": 64,
+        "num_stages": 3,
+        "num_warps": 4
+    }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..131da59770
--- /dev/null
+++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..4bc06d07d9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..f1159e4357 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..c8fa422e0c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..2af1b86e90 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..40cdc996b9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,12 @@ +{ + "2": { + "BV": 32, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at 
end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..a40eda35d4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 8 + }, + "100": { + "num_warps": 1 + }, + "1024": { + "num_warps": 8 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 1 + }, + "256": { + "num_warps": 1 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 1 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..5b08208be2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 1 + }, + "100": { + "num_warps": 8 + }, + "1024": { + "num_warps": 1 + }, + "128": { + "num_warps": 1 + }, + "16": { + "num_warps": 1 + }, + "16384": { + "num_warps": 2 + }, + "2048": { + "num_warps": 8 + }, + "256": { + "num_warps": 8 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 8 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..27e4804a61 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BK": 64, + "num_stages": 2, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..7749b3601f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,12 @@ +{ + "2": { + "BK": 64, + "num_stages": 2, + "num_warps": 4 + }, + "4": { + "BK": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..49c4dc63d1 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "128": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "16": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..907575d960 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "32": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 64, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..f525d11257 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,70 @@ +{ + "1024": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "128": { + "BLOCK_N": 512, + "num_warps": 2 + }, + "131072": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "16": { + "BLOCK_N": 256, + "num_warps": 4 + }, + "1600": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "16384": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "2048": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "256": { + "BLOCK_N": 512, + "num_warps": 2 + }, + "262144": { + "BLOCK_N": 128, + 
"num_warps": 1 + }, + "32768": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "4096": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "512": { + "BLOCK_N": 256, + "num_warps": 4 + }, + "64": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "65536": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "8": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "800": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "8192": { + "BLOCK_N": 128, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..198a196dfb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..537c7a90eb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 512, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..9a6dcb6fbf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "4096": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..e5a383f23f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "163840": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "320": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "40960": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..56c79e3a43 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "163840": { + 
"BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "320": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "40960": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..4843ed8ccf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + 
"num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..3c0e605b00 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..d82ca44a21 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "163840": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "320": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "40960": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..96eabffc42 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json new file mode 100644 index 0000000000..07e5e6875f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16384": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 512, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json new file mode 100644 index 0000000000..ff4632955f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + 
"1": { + "BLOCK_DIM": 256, + "BLOCK_M": 4, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 16 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 64, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json new file mode 100644 index 0000000000..89ab51ff8c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 4, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "2048": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "32": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..f4d29554da --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 1, + "num_warps": 8 + }, + "100": { + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "num_stages": 3, + "num_warps": 4 + 
}, + "128": { + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "num_stages": 5, + "num_warps": 2 + }, + "16384": { + "num_stages": 1, + "num_warps": 2 + }, + "2048": { + "num_stages": 2, + "num_warps": 2 + }, + "256": { + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "num_stages": 5, + "num_warps": 2 + }, + "4096": { + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "num_stages": 5, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..8605a91680 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 2, + "num_warps": 1 + }, + "100": { + "num_stages": 4, + "num_warps": 4 + }, + "1024": { + "num_stages": 3, + "num_warps": 2 + }, + "128": { + "num_stages": 2, + "num_warps": 1 + }, + "16": { + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "num_stages": 3, + "num_warps": 2 + }, + "2048": { + "num_stages": 3, + "num_warps": 2 + }, + "256": { + "num_stages": 5, + "num_warps": 2 + }, + "32": { + "num_stages": 2, + "num_warps": 1 + }, + "4096": { + "num_stages": 3, + "num_warps": 2 + }, + "64": { + "num_stages": 2, + "num_warps": 1 + }, + "8": { + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..12993b0231 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "10": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "1000": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "10240": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1280": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "163840": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "20480": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2560": { + "BLOCK_M": 8, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "320": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "40960": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "80": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..e08a58baf5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "10": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1000": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "10240": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1280": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "160": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "163840": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "20480": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2560": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "320": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "40960": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "80": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 32ccbe8337..af13e34cd9 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -7,6 +7,8 @@ from lightllm.models.qwen2.model import Qwen2TpPartModel from lightllm.models.qwen3.model import Qwen3TpPartModel from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel +from lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel from lightllm.models.internlm.model import InternlmTpPartModel from lightllm.models.stablelm.model import StablelmTpPartModel from lightllm.models.internlm2.model import Internlm2TpPartModel @@ -38,4 +40,8 @@ ) from lightllm.models.gpt_oss.model import GptOssTpPartModel from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel +from lightllm.models.qwen3_5.model import ( + Qwen3_5TpPartModel, + Qwen3_5MOETpPartModel, +) from .registry import get_model, get_model_class diff --git a/lightllm/models/qwen3next/__init__.py b/lightllm/models/qwen3next/__init__.py new file mode 100644 index 0000000000..a9d22c6643 --- /dev/null +++ b/lightllm/models/qwen3next/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel + +__all__ = ["Qwen3NextTpPartModel"] diff --git a/lightllm/models/qwen3next/buffer_pool.py b/lightllm/models/qwen3next/buffer_pool.py new file mode 100644 index 0000000000..42c4bcafc7 --- /dev/null +++ b/lightllm/models/qwen3next/buffer_pool.py @@ -0,0 +1,83 @@ +# lightllm/models/qwen3next/buffer_pool.py +import torch +from typing import Dict, Tuple + + +class Qwen3NextBufferPool: + """ + Buffer pool for Qwen3Next inference to reduce allocations. + + NOT thread-safe. Each GPU worker process should have its own pool instance. 
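+
+    Illustrative usage (a minimal sketch; shape, dtype and device are examples):
+
+        pool = Qwen3NextBufferPool(enable_stats=True)
+        buf = pool.get_buffer((8, 2048), torch.bfloat16, torch.device("cuda"))
+        # ... write into buf during the forward pass ...
+        pool.release_all()  # afterwards every buffer becomes reusable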
+ + Manages reusable buffers for: + - Attention norm outputs + - FFN norm outputs + - FFN intermediate activations + - GDN intermediate tensors + """ + + def __init__(self, enable_stats: bool = False, max_buffers: int = 64): + self._buffers: Dict[Tuple[tuple, torch.dtype, torch.device], torch.Tensor] = {} + self._in_use: set = set() + self._max_buffers = max_buffers + self._access_order: list = [] # Track LRU order + self._enable_stats = enable_stats + self._stats = {"hits": 0, "misses": 0, "peak_buffers": 0, "evictions": 0} if enable_stats else None + + def get_buffer( + self, + shape: Tuple[int, ...], + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + """Get a buffer from the pool or allocate a new one.""" + key = (shape, dtype, device) + + # Check if we have a matching buffer not in use + if key in self._buffers and key not in self._in_use: + self._in_use.add(key) + # Update LRU order + if key in self._access_order: + self._access_order.remove(key) + self._access_order.append(key) + if self._enable_stats: + self._stats["hits"] += 1 + return self._buffers[key] + + # Evict oldest unused buffer if at capacity + if len(self._buffers) >= self._max_buffers: + self._evict_one() + + # Allocate new buffer + buffer = torch.empty(shape, dtype=dtype, device=device) + self._buffers[key] = buffer + self._in_use.add(key) + self._access_order.append(key) + if self._enable_stats: + self._stats["misses"] += 1 + self._stats["peak_buffers"] = max(self._stats["peak_buffers"], len(self._buffers)) + return buffer + + def _evict_one(self): + """Evict oldest unused buffer (LRU).""" + for key in self._access_order: + if key not in self._in_use and key in self._buffers: + del self._buffers[key] + self._access_order.remove(key) + if self._enable_stats: + self._stats["evictions"] += 1 + return + + def release_all(self): + """Release all buffers back to the pool (call after forward pass).""" + self._in_use.clear() + + def clear(self): + """Clear all buffers (call when changing batch size significantly).""" + self._buffers.clear() + self._in_use.clear() + self._access_order.clear() + + def get_stats(self): + """Return buffer pool statistics (if enabled).""" + return self._stats.copy() if self._stats else None diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py new file mode 100644 index 0000000000..2883534a93 --- /dev/null +++ b/lightllm/models/qwen3next/infer_struct.py @@ -0,0 +1,62 @@ +import torch +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen3NextInferStateInfo(LlamaInferStateInfo): + """ + Inference state for Qwen3Next with: + - gate_value attribute for output gating in full attention layers + - MTP-aware batching for multi-token prediction + - Custom buffer management for hybrid attention (full + linear) + """ + + def __init__(self): + super().__init__() + # For output gating in full attention layers + self.gate_value = None + # MTP-aware attributes + self.b_att_seq_len = None + self.att_batch_size = None + self.real_req_idx = None + self.mtp_buffer_idx_list = None + self.b_buffer_idx = None + + def init_some_extra_state(self, model): + """Initialize Qwen3Next-specific state""" + super().init_some_extra_state(model) + + args_mtp_step = get_env_start_args().mtp_step + mtp_size = args_mtp_step + 1 + + if self.is_prefill: + # Prefill: Standard initialization + self.b_att_seq_len = self.b_seq_len + self.b_buffer_idx = 
model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() + else: + # Decode: MTP-aware handling + # In MTP mode, each request has (mtp_step + 1) tokens + # att_batch_size is the number of unique requests + self.att_batch_size = self.batch_size // mtp_size + + # Use only the sequence lengths for the last token of each MTP group + if args_mtp_step > 0: + self.b_att_seq_len = self.b_seq_len[args_mtp_step::mtp_size].contiguous() + self.real_req_idx = self.b_req_idx[args_mtp_step::mtp_size] + else: + self.b_att_seq_len = self.b_seq_len + self.real_req_idx = self.b_req_idx + + # Buffer indices for Mamba cache (conv and SSM states) + self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.real_req_idx, :].flatten().contiguous() + + # Create per-step buffer indices for MTP + if args_mtp_step > 0: + buffer_idx_list = [] + for step_id in range(mtp_size): + buffer_idx_list.append(self.b_buffer_idx[step_id::mtp_size].tolist()) + self.mtp_buffer_idx_list = torch.tensor( + buffer_idx_list, dtype=torch.int32, device=self.b_buffer_idx.device + ) + + return diff --git a/lightllm/models/qwen3next/layer_infer/post_layer_infer.py b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py new file mode 100644 index 0000000000..9dcab4e6fc --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py @@ -0,0 +1,12 @@ +import torch + +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + +class Qwen3NextPostLayerInfer(LlamaPostLayerInfer): + def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.final_norm_weight_.weight, self.eps_, out=out) + return out diff --git a/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py b/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py new file mode 100644 index 0000000000..2da106dbb2 --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py @@ -0,0 +1,101 @@ +# lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py +import torch.nn.functional as F +from functools import partial +from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd +import os + + +class SharedExpertFFNMixin: + """ + Mixin providing shared expert + MoE FFN implementations. + + Used by both full attention and GDN layers in Qwen3Next. 
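+
+    Schematically (mirroring _compute_shared_expert and _ffn_with_shared_expert_tp
+    below), the MoE variant computes:
+
+        out = sigmoid(shared_expert_gate(x)) * shared_ffn(x) + moe_ffn(x)
+
+    where shared_ffn is gate_up -> silu_and_mul -> down over the shared expert weights.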
+ + Requirements: + - Class must have: embed_dim_, tp_world_size_, alloc_tensor() + - Class must have MoE config: is_moe, n_routed_experts, num_experts_per_tok, norm_topk_prob + """ + + def _bind_ffn(self): + """Bind FFN implementation based on MoE configuration.""" + if self.is_moe: + moe_mode = os.environ.get("MOE_MODE", "TP") + if moe_mode == "EP": + self._ffn = partial(SharedExpertFFNMixin._ffn_with_shared_expert_ep, self) + else: + self._ffn = partial(SharedExpertFFNMixin._ffn_with_shared_expert_tp, self) + else: + self._ffn = partial(SharedExpertFFNMixin._standard_ffn, self) + return + + def _ffn_core(self, input, layer_weight): + """Core FFN computation: gate_up -> silu_and_mul -> down.""" + input = input.view(-1, self.embed_dim_) + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) + + if hasattr(self, "buffer_pool") and self.buffer_pool: + ffn1_out = self.buffer_pool.get_buffer((input.size(0), up_gate_out.size(1) // 2), input.dtype, input.device) + else: + ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) + + silu_and_mul_fwd(up_gate_out, ffn1_out) + ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) + return ffn2_out, input + + def _standard_ffn(self, input, infer_state, layer_weight): + """Standard FFN using shared expert weights (non-MoE layers).""" + ffn2_out, _ = self._ffn_core(input, layer_weight) + return ffn2_out + + def _compute_shared_expert(self, input, layer_weight): + """Compute shared expert FFN output with gating.""" + ffn2_out, input_view = self._ffn_core(input, layer_weight) + return F.sigmoid(layer_weight.shared_expert_gate.mm(input_view)) * ffn2_out, input_view + + def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): + """FFN with shared expert + MoE (tensor parallelism mode).""" + shared_expert_out, input = self._compute_shared_expert(input, layer_weight) + moe_out = self._moe_ffn(input, infer_state, layer_weight) + return shared_expert_out + moe_out + + def _ffn_with_shared_expert_ep(self, input, infer_state, layer_weight): + """FFN with shared expert + MoE (expert parallelism mode).""" + shared_expert_out, input = self._compute_shared_expert(input, layer_weight) + moe_out = self._moe_ffn_edp(input, infer_state, layer_weight) + return shared_expert_out + moe_out + + def _moe_ffn(self, input, infer_state, layer_weight): + """MoE FFN with tensor parallelism.""" + hidden_states = input.view(-1, self.embed_dim_) + num_tokens, hidden_dim = hidden_states.shape + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + ) + return hidden_states.view(num_tokens, hidden_dim) + + def _moe_ffn_edp(self, input, infer_state, layer_weight): + """MoE FFN with expert parallelism.""" + hidden_states = input + token_num, hidden_dim = hidden_states.shape + + router_logits = layer_weight.moe_gate.mm(hidden_states) + ep_output = layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + is_prefill=infer_state.is_prefill, + ) + + ep_output = ep_output.view(token_num, hidden_dim) + return ep_output diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py 
b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..cd5fd67d53 --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -0,0 +1,1067 @@ +import os +import torch + +import torch.distributed as dist +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextFullAttentionTransformerLayerWeight, + Qwen3NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo +from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl +from lightllm.utils.log_utils import init_logger +from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager +from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd +from typing import Tuple +from lightllm.models.qwen3next.triton_kernel.gated_rmsnorm import gated_rmsnorm_forward +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.gdn_decode_mtp import ( + copy_conv_states, + copy_ssm_states, + copy_states_fused, +) +from lightllm.distributed import all_reduce +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward +from lightllm.models.qwen3next.triton_kernel.fused_add_gemma_rmsnorm import fused_add_gemma_rmsnorm +from lightllm.models.qwen3next.triton_kernel.fused_split_copy import fused_split_copy_qkvzba, fused_split_copy_qkv +from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type +from functools import partial + +logger = init_logger(__name__) + + +class GemmaRMSNormMixin: + """ + Mixin providing Gemma-style RMSNorm implementations. + + Requirements: + - Class must have: eps_, alloc_tensor() + """ + + def _gemma_norm_with_pool(self, input, norm_weight): + """Apply Gemma RMSNorm.""" + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, norm_weight, self.eps_, out=out) + return out + + +class Qwen3NextFullAttentionBaseLayerInfer(GemmaRMSNormMixin, LlamaTransformerLayerInfer): + """ + Base class for Qwen3Next full attention layers. + Contains shared logic for both standard full attention and MTP layers. 
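+
+    A layer is treated as MoE when num_experts > 0, the layer is not listed in
+    mlp_only_layers, and (layer_num + 1) is a multiple of decoder_sparse_step,
+    so decoder_sparse_step=1 makes every non-excluded layer MoE.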
+ """ + + def __init__(self, layer_num, network_config): + # Store Qwen3Next specific configs before calling super().__init__ + self.partial_rotary_factor = network_config.get("partial_rotary_factor", 1.0) + self.n_routed_experts = network_config.get("num_experts", 0) + self.is_moe = ( + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 + ) + self.num_experts_per_tok = network_config.get("num_experts_per_tok", 1) + self.norm_topk_prob = network_config.get("norm_topk_prob", False) + + super().__init__(layer_num, network_config) + # Override head_dim which may be different in Qwen3Next + self.head_dim_ = network_config.get( + "head_dim", network_config["hidden_size"] // network_config["num_attention_heads"] + ) + + # Pre-allocated decode buffers (mirrors GDN layer pattern) + start_args = get_env_start_args() + self._decode_buffers = {} + self._graph_max_batch_size = start_args.graph_max_batch_size + + # Pre-compute dims for decode buffer pre-allocation + self.shared_inter_size = network_config.get("shared_expert_intermediate_size", 0) + self.tp_gate_up_dim = 2 * self.shared_inter_size // self.tp_world_size_ if self.shared_inter_size > 0 else 0 + self.tp_q_gate_dim = (self.tp_q_head_num_ + self.tp_o_head_num_) * self.head_dim_ + self.tp_kv_dim = (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_ + + return + + def _get_decode_buffer(self, name, max_shape, dtype, device): + """Get or create a pre-allocated buffer for the decode path.""" + key = (name, dtype, device if isinstance(device, str) else str(device)) + if key not in self._decode_buffers: + self._decode_buffers[key] = torch.empty(max_shape, dtype=dtype, device=device) + return self._decode_buffers[key] + + def _bind_func(self): + super()._bind_func() + self._bind_ffn() + return + + def _bind_norm(self): + """Use Gemma-style RMSNorm""" + self._att_norm = partial(Qwen3NextFullAttentionBaseLayerInfer._att_norm_impl, self) + self._ffn_norm = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_norm_impl, self) + return + + def _bind_ffn(self): + """Bind FFN implementation based on MoE configuration.""" + if self.is_moe: + moe_mode = os.environ.get("MOE_MODE", "TP") + if moe_mode == "EP": + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_ep, self) + else: + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_tp, self) + else: + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._standard_ffn, self) + return + + def _ffn_core(self, input, layer_weight, is_decode=False): + """Core FFN computation: gate_up -> silu_and_mul -> down.""" + input = input.view(-1, self.embed_dim_) + if is_decode and self.tp_gate_up_dim > 0: + up_gate_buf = self._get_decode_buffer( + "up_gate_out", + (self._graph_max_batch_size, self.tp_gate_up_dim), + input.dtype, + input.device, + )[: input.size(0)] + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input, out=up_gate_buf) + else: + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) + inter_dim = up_gate_out.size(1) // 2 + if is_decode: + ffn1_out = self._get_decode_buffer( + "ffn1_out", (self._graph_max_batch_size, inter_dim), input.dtype, input.device + )[: input.size(0)] + else: + ffn1_out = self.alloc_tensor((input.size(0), inter_dim), input.dtype) + silu_and_mul_fwd(up_gate_out, ffn1_out) + ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) + return ffn2_out, input + + def 
_standard_ffn(self, input, infer_state, layer_weight): + """Standard FFN using shared expert weights (non-MoE layers).""" + ffn2_out, _ = self._ffn_core(input, layer_weight, is_decode=not infer_state.is_prefill) + return ffn2_out + + def _compute_shared_expert(self, input, layer_weight, is_decode=False): + """Compute shared expert FFN output with gating.""" + ffn2_out, input_view = self._ffn_core(input, layer_weight, is_decode=is_decode) + gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() + ffn2_out.mul_(gate) + return ffn2_out, input_view + + def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): + """FFN with shared expert + MoE (tensor parallelism mode).""" + shared_expert_out, input = self._compute_shared_expert( + input, layer_weight, is_decode=not infer_state.is_prefill + ) + moe_out = self._moe_ffn(input, infer_state, layer_weight) + moe_out.add_(shared_expert_out) + return moe_out + + def _ffn_with_shared_expert_ep(self, input, infer_state, layer_weight): + """FFN with shared expert + MoE (expert parallelism mode).""" + shared_expert_out, input = self._compute_shared_expert( + input, layer_weight, is_decode=not infer_state.is_prefill + ) + moe_out = self._moe_ffn_edp(input, infer_state, layer_weight) + moe_out.add_(shared_expert_out) + return moe_out + + def _moe_ffn(self, input, infer_state, layer_weight): + """MoE FFN with tensor parallelism.""" + hidden_states = input.view(-1, self.embed_dim_) + num_tokens, hidden_dim = hidden_states.shape + if not infer_state.is_prefill: + router_buf = self._get_decode_buffer( + "router_logits", + (self._graph_max_batch_size, self.n_routed_experts), + hidden_states.dtype, + hidden_states.device, + )[:num_tokens] + router_logits = layer_weight.moe_gate.mm(hidden_states, out=router_buf) + else: + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + ) + return hidden_states.view(num_tokens, hidden_dim) + + def _moe_ffn_edp(self, input, infer_state, layer_weight): + """MoE FFN with expert parallelism.""" + hidden_states = input + token_num, hidden_dim = hidden_states.shape + router_logits = layer_weight.moe_gate.mm(hidden_states) + ep_output = layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + is_prefill=infer_state.is_prefill, + ) + ep_output = ep_output.view(token_num, hidden_dim) + return ep_output + + def _att_norm_impl( + self, + input, + _infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> torch.Tensor: + return self._gemma_norm_with_pool(input, layer_weight.att_norm_weight_.weight) + + def _ffn_norm_impl( + self, + input, + _infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> torch.Tensor: + return self._gemma_norm_with_pool(input, layer_weight.ffn_norm_weight_.weight) + + def _get_qkv( + self, + input: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + QKV projection with output gating, Q/K normalization, and partial rotary embedding. 
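+
+        The fused q_gate_proj GEMM emits [q | gate] side by side: the first
+        tp_q_head_num_ * head_dim_ columns are q, and the remaining columns are
+        passed through sigmoid and stashed on infer_state.gate_value for _get_o.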
+ """ + input = input.view(-1, self.embed_dim_) + # Single fused GEMM for both Q and output gate projections + if not infer_state.is_prefill: + q_gate_buf = self._get_decode_buffer( + "q_gate_out", + (self._graph_max_batch_size, self.tp_q_gate_dim), + input.dtype, + input.device, + )[: input.size(0)] + q_gate = layer_weight.q_gate_proj.mm(input, out=q_gate_buf) + kv_buf = self._get_decode_buffer( + "kv_out", + (self._graph_max_batch_size, self.tp_kv_dim), + input.dtype, + input.device, + )[: input.size(0)] + kv_out = layer_weight.kv_proj.mm(input, out=kv_buf) + else: + q_gate = layer_weight.q_gate_proj.mm(input) + kv_out = layer_weight.kv_proj.mm(input) + q_dim = self.tp_q_head_num_ * self.head_dim_ + q = q_gate[:, :q_dim].contiguous() + # In-place sigmoid saves one allocation (gate_value is consumed once in _get_o) + infer_state.gate_value = q_gate[:, q_dim:].sigmoid_() + cache_kv = kv_out.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + # Q normalization (in-place via out=input) + gemma_rmsnorm_forward( + q.view(-1, self.head_dim_), + layer_weight.q_norm_weight_.weight, + eps=self.eps_, + out=q.view(-1, self.head_dim_), + ) + + # K normalization + k_input = cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]) + if not infer_state.is_prefill: + k_normed = self._get_decode_buffer( + "k_norm_out", + (self._graph_max_batch_size * self.tp_k_head_num_, cache_kv.shape[-1]), + k_input.dtype, + k_input.device, + )[: k_input.shape[0]] + gemma_rmsnorm_forward(k_input, layer_weight.k_norm_weight_.weight, eps=self.eps_, out=k_normed) + else: + k_normed = gemma_rmsnorm_forward(k_input, layer_weight.k_norm_weight_.weight, eps=self.eps_) + cache_kv[:, : self.tp_k_head_num_, :] = k_normed.view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) + + # Rotary embedding with partial rotation support + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + partial_rotary_factor=self.partial_rotary_factor, + ) + return q, cache_kv + + def _get_o( + self, + input, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> torch.Tensor: + """Output projection with gating (in-place multiply to save one allocation).""" + input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) + input.mul_(infer_state.gate_value) + infer_state.gate_value = None + o_tensor = layer_weight.o_proj.mm(input) + return o_tensor + + def token_forward(self, input_embdings, infer_state, layer_weight): + """Override token_forward to use pre-allocated decode buffers and fused kernels.""" + max_tokens = self._graph_max_batch_size + input1 = self._get_decode_buffer( + "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device + )[: input_embdings.shape[0]] + gemma_rmsnorm_forward(input_embdings, layer_weight.att_norm_weight_.weight, self.eps_, out=input1) + + o = self.token_attention_forward(input1, infer_state, layer_weight) + + # Fused residual add + FFN norm: saves 1 kernel launch + 1 read of input_embdings + input1 = self._get_decode_buffer( + "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device + )[: input_embdings.shape[0]] + fused_add_gemma_rmsnorm( + input_embdings, + o.view(-1, self.embed_dim_), + layer_weight.ffn_norm_weight_.weight, + self.eps_, + out=input1, + ) + o = None + + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ 
> 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + +class Qwen3NextFullAttentionTransformerLayerInfer(Qwen3NextFullAttentionBaseLayerInfer): + """ + Full attention layer for Qwen3Next that uses the abstracted attention backend. + Inherits from Qwen3NextFullAttentionBaseLayerInfer to get shared Qwen3Next logic. + """ + + pass + + +class Qwen3NextGatedDeltaNetTransformerLayerInfer(GemmaRMSNormMixin, TransformerLayerInferTpl): + """ + Linear attention (Gated Delta Networks) layer for Qwen3Next. + """ + + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + self.network_config_ = network_config + + # MoE configuration + self.n_routed_experts = network_config.get("num_experts", 0) + self.is_moe = ( + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 + ) + self.num_experts_per_tok = network_config.get("num_experts_per_tok", 1) + self.norm_topk_prob = network_config.get("norm_topk_prob", False) + self.shared_inter_size = network_config.get("shared_expert_intermediate_size", 0) + + # Standard layer dimensions + self.eps_ = network_config["rms_norm_eps"] + self.embed_dim_ = network_config["hidden_size"] + + # Linear attention specific dimensions + self.num_v_heads = network_config["linear_num_value_heads"] + self.num_k_heads = network_config["linear_num_key_heads"] + self.head_k_dim = network_config["linear_key_head_dim"] + self.head_v_dim = network_config["linear_value_head_dim"] + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_kernel_dim = network_config["linear_conv_kernel_dim"] + self.activation = network_config["hidden_act"] + + # Tensor parallelism dimensions + self.tp_qkvz_dim = (self.key_dim * 2 + self.value_dim * 2) // self.tp_world_size_ + self.tp_ba_dim = (self.num_v_heads * 2) // self.tp_world_size_ + self.tp_num_k_heads = self.num_k_heads // self.tp_world_size_ + self.tp_num_v_heads = self.num_v_heads // self.tp_world_size_ + self.tp_key_dim = self.key_dim // self.tp_world_size_ + self.tp_value_dim = self.value_dim // self.tp_world_size_ + + # Template required dimensions (not used for GDN but required by interface) + self.tp_q_head_num_ = self.tp_num_k_heads + self.tp_k_head_num_ = self.tp_num_k_heads + self.tp_v_head_num_ = self.tp_num_v_heads + self.tp_o_head_num_ = self.tp_num_v_heads + self.head_dim_ = self.head_v_dim + + assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads" + self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads + + # MTP configuration + self.mtp_step = get_env_start_args().mtp_step + self.mtp_size = self.mtp_step + 1 + + # SSM state dtype optimization + ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + start_args = get_env_start_args() + self.ssm_state_dtype = ssm_dtype_dict.get(start_args.mamba_ssm_data_type, torch.bfloat16) + + # Pre-compute whether dtype conversion is needed + # GDN kernel output dtype is self.data_type + # Conversion needed only if SSM state uses different dtype + self.needs_ssm_dtype_conversion = get_llm_data_type() != self.ssm_state_dtype + + # Pre-allocated decode buffers to avoid repeated allocation during CUDA graph replay. 
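+        # CUDA graph replay re-executes kernels against fixed device addresses, so
+        # decode-path tensors must live in stable storage rather than in fresh
+        # torch.empty allocations made on every step.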
+ # Buffers are lazily allocated on first decode call, sized to graph_max_batch_size. + self._decode_buffers = {} + self._graph_max_batch_size = start_args.graph_max_batch_size + + # Pre-compute FFN dims for decode buffer pre-allocation + self.tp_gate_up_dim = 2 * self.shared_inter_size // self.tp_world_size_ if self.shared_inter_size > 0 else 0 + + self._bind_func() + return + + def _get_decode_buffer(self, name, max_shape, dtype, device): + """Get or create a pre-allocated buffer for the decode path. + + On first call, allocates a buffer at max_shape. On subsequent calls, + returns the same buffer (caller should slice to actual batch size). + """ + key = (name, dtype, device if isinstance(device, str) else str(device)) + if key not in self._decode_buffers: + self._decode_buffers[key] = torch.empty(max_shape, dtype=dtype, device=device) + return self._decode_buffers[key] + + def _bind_func(self): + """Bind layer-specific implementations""" + self._bind_norm() + self._bind_ffn() + return + + def _bind_norm(self): + """Use Gemma-style RMSNorm""" + self._att_norm = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._att_norm_impl, self) + self._ffn_norm = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._ffn_norm_impl, self) + return + + def _bind_ffn(self): + """Bind FFN implementation based on MoE configuration.""" + if self.is_moe: + moe_mode = os.environ.get("MOE_MODE", "TP") + if moe_mode == "EP": + self._ffn = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._ffn_with_shared_expert_ep, self) + else: + self._ffn = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._ffn_with_shared_expert_tp, self) + else: + self._ffn = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._standard_ffn, self) + return + + def _ffn_core(self, input, layer_weight, is_decode=False): + """Core FFN computation: gate_up -> silu_and_mul -> down.""" + input = input.view(-1, self.embed_dim_) + if is_decode and self.tp_gate_up_dim > 0: + up_gate_buf = self._get_decode_buffer( + "up_gate_out", + (self._graph_max_batch_size * self.mtp_size, self.tp_gate_up_dim), + input.dtype, + input.device, + )[: input.size(0)] + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input, out=up_gate_buf) + else: + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) + inter_dim = up_gate_out.size(1) // 2 + if is_decode: + ffn1_out = self._get_decode_buffer( + "ffn1_out", (self._graph_max_batch_size, inter_dim), input.dtype, input.device + )[: input.size(0)] + else: + ffn1_out = self.alloc_tensor((input.size(0), inter_dim), input.dtype) + silu_and_mul_fwd(up_gate_out, ffn1_out) + ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) + return ffn2_out, input + + def _standard_ffn(self, input, infer_state, layer_weight): + """Standard FFN using shared expert weights (non-MoE layers).""" + ffn2_out, _ = self._ffn_core(input, layer_weight, is_decode=not infer_state.is_prefill) + return ffn2_out + + def _compute_shared_expert(self, input, layer_weight, is_decode=False): + """Compute shared expert FFN output with gating.""" + ffn2_out, input_view = self._ffn_core(input, layer_weight, is_decode=is_decode) + gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() + ffn2_out.mul_(gate) + return ffn2_out, input_view + + def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): + """FFN with shared expert + MoE (tensor parallelism mode).""" + shared_expert_out, input = self._compute_shared_expert( + input, layer_weight, is_decode=not infer_state.is_prefill + ) + moe_out = 
self._moe_ffn(input, infer_state, layer_weight) + moe_out.add_(shared_expert_out) + return moe_out + + def _ffn_with_shared_expert_ep(self, input, infer_state, layer_weight): + """FFN with shared expert + MoE (expert parallelism mode).""" + shared_expert_out, input = self._compute_shared_expert( + input, layer_weight, is_decode=not infer_state.is_prefill + ) + moe_out = self._moe_ffn_edp(input, infer_state, layer_weight) + moe_out.add_(shared_expert_out) + return moe_out + + def _moe_ffn(self, input, infer_state, layer_weight): + """MoE FFN with tensor parallelism.""" + hidden_states = input.view(-1, self.embed_dim_) + num_tokens, hidden_dim = hidden_states.shape + if not infer_state.is_prefill: + router_buf = self._get_decode_buffer( + "router_logits", + (self._graph_max_batch_size * self.mtp_size, self.n_routed_experts), + hidden_states.dtype, + hidden_states.device, + )[:num_tokens] + router_logits = layer_weight.moe_gate.mm(hidden_states, out=router_buf) + else: + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + ) + return hidden_states.view(num_tokens, hidden_dim) + + def _moe_ffn_edp(self, input, infer_state, layer_weight): + """MoE FFN with expert parallelism.""" + hidden_states = input + token_num, hidden_dim = hidden_states.shape + router_logits = layer_weight.moe_gate.mm(hidden_states) + ep_output = layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + is_prefill=infer_state.is_prefill, + ) + ep_output = ep_output.view(token_num, hidden_dim) + return ep_output + + def _att_norm_impl( + self, + input, + _infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + return self._gemma_norm_with_pool(input, layer_weight.att_norm_weight_.weight) + + def _ffn_norm_impl( + self, + input, + _infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + return self._gemma_norm_with_pool(input, layer_weight.ffn_norm_weight_.weight) + + def _get_qkv( + self, + _input: torch.Tensor, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Not used by GDN - QKV projection handled in gdn_forward. + + GDN uses a fused projection that includes z, b, a parameters + in addition to q, k, v, so the standard template flow doesn't apply. + This method exists to satisfy the template interface. + """ + pass # Implementation in gdn_forward + + def _tpsp_get_qkv( + self, + _input: torch.Tensor, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """TPSP mode not implemented for GDN layers.""" + pass # No TPSP support planned + + def _get_o( + self, + _input, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """ + Not used by GDN - output projection handled in gdn_forward. + + Output computation is fused with GDN recurrence in gdn_forward. 
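+
+        (context_forward/token_forward are overridden to call _gdn_layer_forward
+        directly, so the template flow never invokes this stub.)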
+ """ + pass # Implementation in gdn_forward + + def _tpsp_get_o( + self, + _input, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """TPSP mode not implemented for GDN layers.""" + pass # No TPSP support planned + + def _context_attention_kernel( + self, + _q: torch.Tensor, + _kv: torch.Tensor, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """Not used by GDN - attention computed in gdn_forward.""" + pass # Implementation in gdn_forward + + def _token_attention_kernel( + self, + _q: torch.Tensor, + _infer_state: Qwen3NextInferStateInfo, + _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """Not used by GDN - attention computed in gdn_forward.""" + pass # Implementation in gdn_forward + + def _gdn_layer_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + is_prefill: bool, + ): + """Unified forward for both prefill and decode in GDN layers.""" + # Attention + GDN processing + if is_prefill: + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + else: + # Decode: use pre-allocated buffer to avoid alloc_tensor overhead + max_tokens = self._graph_max_batch_size * self.mtp_size + input1 = self._get_decode_buffer( + "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device + )[: input_embdings.shape[0]] + gemma_rmsnorm_forward(input_embdings, layer_weight.att_norm_weight_.weight, self.eps_, out=input1) + + gdn_out = self.gdn_forward(input1, infer_state, layer_weight, is_prefill=is_prefill) + if self.tp_world_size_ > 1: + all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + + # FFN + if is_prefill: + input_embdings.add_(gdn_out.view(-1, self.embed_dim_)) + gdn_out = None + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + else: + # Decode: fused residual add + FFN norm saves 1 kernel + 1 read of input_embdings + input1 = self._get_decode_buffer( + "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device + )[: input_embdings.shape[0]] + fused_add_gemma_rmsnorm( + input_embdings, + gdn_out.view(-1, self.embed_dim_), + layer_weight.ffn_norm_weight_.weight, + self.eps_, + out=input1, + ) + gdn_out = None + + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + def context_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Override context_forward to use GDN logic instead of standard attention flow.""" + return self._gdn_layer_forward(input_embdings, infer_state, layer_weight, is_prefill=True) + + def token_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Override token_forward to use GDN logic instead of standard attention flow.""" + return self._gdn_layer_forward(input_embdings, infer_state, layer_weight, is_prefill=False) + + def overlap_tpsp_token_forward( + self, + input_embdings, + input_embdings1, + infer_state: Qwen3NextInferStateInfo, + infer_state1: 
Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Microbatch overlap for decode: process two half-batches sequentially. + Enables --enable_decode_microbatch_overlap for GDN layers.""" + input_embdings = self.token_forward(input_embdings, infer_state, layer_weight) + input_embdings1 = self.token_forward(input_embdings1, infer_state1, layer_weight) + return input_embdings, input_embdings1 + + def overlap_tpsp_context_forward( + self, + input_embdings, + input_embdings1, + infer_state: Qwen3NextInferStateInfo, + infer_state1: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Microbatch overlap for context: process two half-batches sequentially.""" + input_embdings = self.context_forward(input_embdings, infer_state, layer_weight) + input_embdings1 = self.context_forward(input_embdings1, infer_state1, layer_weight) + return input_embdings, input_embdings1 + + # ==================== GDN Helper Methods ==================== + + def _fix_query_key_value_ba_ordering(self, mixed_qkvzba, is_decode=False): + """ + Extract q, k, v, z, b, a from the MM output. + + After weight rearrangement at load time, the MM output is already in grouped layout: + [all_q | all_k | all_v | all_z | all_b | all_a] + so this is just simple slicing — no split+reshape+cat needed. + + Note: + Decode fast-path fused split-copy kernels are intentionally avoided here. + The explicit contiguous slicing path is slower but is more robust and + matches the reference behavior used in vLLM. + """ + qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim + z_end = qkv_dim + self.tp_value_dim + b_end = z_end + self.tp_num_v_heads + + if is_decode: + mixed_qkv = mixed_qkvzba[:, :qkv_dim].contiguous() + z = mixed_qkvzba[:, qkv_dim:z_end].contiguous().view(-1, self.tp_num_v_heads, self.head_v_dim) + b = mixed_qkvzba[:, z_end:b_end].contiguous() + a = mixed_qkvzba[:, b_end:].contiguous() + else: + mixed_qkv = mixed_qkvzba[:, :qkv_dim] + # .reshape() handles non-contiguous slices by copying when needed (unlike .view()) + z = mixed_qkvzba[:, qkv_dim:z_end].reshape(-1, self.tp_num_v_heads, self.head_v_dim) + # b and a must be contiguous: fused_gdn_gating_kernel uses raw pointer arithmetic + # (off = i_b * NUM_HEADS + head_off) that assumes contiguous layout. + # Non-contiguous slices have stride[0]=total_dim, causing wrong reads for i_b > 0. 
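+            # e.g. with parent row stride S, logical element (i_b=1, head_off=0)
+            # lives S elements past the slice's data pointer, but the kernel would
+            # read element NUM_HEADS instead -- hence the .contiguous() copies below.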
+ b = mixed_qkvzba[:, z_end:b_end].contiguous() + a = mixed_qkvzba[:, b_end:].contiguous() + + return mixed_qkv, z, b, a + + def _rearrange_mixed_qkv(self, mixed_qkv, decode=False): + if mixed_qkv is None: + return None, None, None + if decode: + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + batch_size = mixed_qkv.shape[0] + query = query.contiguous().view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) + key = key.contiguous().view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) + value = value.contiguous().view(batch_size, 1, self.tp_num_v_heads, self.head_v_dim) + return query, key, value + else: + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + seq_len = query.shape[0] + query = query.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) + key = key.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) + value = value.view(1, seq_len, self.tp_num_v_heads, self.head_v_dim) + return query, key, value + + def context_attention_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True) + return gdn_out + + def token_attention_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=False) + return gdn_out + + def _gdn_prefill_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Prefill kernel for GDN forward pass.""" + # Conv1D processing + mixed_qkv = mixed_qkv.transpose(0, 1) + out_tensor = causal_conv1d_fn( + mixed_qkv, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.bias, + query_start_loc=infer_state.b1_cu_q_seq_len, + cache_indices=infer_state.b_buffer_idx, + has_initial_state=infer_state.b_ready_cache_len > 0, + conv_states=conv_states, + activation=self.activation, + ) + mixed_qkv = out_tensor.transpose(0, 1) + + # Recurrent processing + query, key, value = self._rearrange_mixed_qkv(mixed_qkv) + initial_state = ssm_states[infer_state.b_buffer_idx] + # g and beta have shape (total_tokens, num_heads), need to unsqueeze to get (1, total_tokens, num_heads) + core_attn_out, last_recurrent_state = chunk_gated_delta_rule( + q=query, + k=key, + v=value, + g=g.unsqueeze(0), + beta=beta.unsqueeze(0), + initial_state=initial_state, + output_final_state=True, + cu_seqlens=infer_state.b1_cu_q_seq_len, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + # Use pre-computed dtype conversion flag to avoid runtime check + if self.needs_ssm_dtype_conversion: + ssm_states[infer_state.b_buffer_idx] = last_recurrent_state.to(self.ssm_state_dtype, copy=False) + else: + ssm_states[infer_state.b_buffer_idx] = last_recurrent_state + return core_attn_out + + def _gdn_decode_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Decode kernel for GDN forward pass (single-token, non-MTP mode). 
+ Uses fused gating: g/beta computed inline in the recurrent kernel.""" + # Conv1D processing — mixed_qkv is pre-copied to contiguous buffer + # by _fix_query_key_value_ba_ordering (causal_conv1d_update requires contiguous input) + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.bias, + activation=self.activation, + conv_state_indices=infer_state.b_buffer_idx, + ) + + # Recurrent processing with fused gating + # FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally + query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) + core_attn_out, _ = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + initial_state=ssm_states, + inplace_final_state=True, + ssm_state_indices=infer_state.b_buffer_idx, + use_qk_l2norm_in_kernel=True, + A_log=layer_weight.linear_A_log.weight, + dt_bias=layer_weight.linear_dt_bias.weight, + a_raw=a, + b_raw=b, + ) + return core_attn_out + + def _gdn_decode_mtp_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """ + Optimized decode kernel for GDN forward pass (MTP mode with multiple steps). + + Key optimizations: + 1. Uses pre-allocated work buffer to avoid per-step .contiguous() allocations + 2. Uses optimized flat Triton kernels for state copying + 3. Direct slice assignment for output instead of .copy_() + + Note: Sequential processing is required because each MTP step depends on + the previous step's final state (both conv and SSM states). + """ + total_tokens = mixed_qkv.shape[0] + batch_size = total_tokens // self.mtp_size + + # Pre-allocate output tensor + core_attn_out = torch.empty( + (total_tokens, 1, self.tp_num_v_heads, self.head_v_dim), + dtype=mixed_qkv.dtype, + device=mixed_qkv.device, + ) + + # Pre-allocate work buffer for conv1d input (avoids per-step .contiguous()) + qkv_work_buffer = torch.empty( + (batch_size, mixed_qkv.shape[-1]), + dtype=mixed_qkv.dtype, + device=mixed_qkv.device, + ) + + # Process each MTP step sequentially (required due to state dependencies) + for step_idx in range(self.mtp_size): + cur_buffer_idx = infer_state.mtp_buffer_idx_list[step_idx] + + # ========== Conv1D processing ========== + # Copy strided data to contiguous work buffer + qkv_work_buffer.copy_(mixed_qkv[step_idx :: self.mtp_size]) + + # causal_conv1d_update operates in-place on contiguous input + causal_conv1d_update( + qkv_work_buffer, + conv_states, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.bias, + activation=self.activation, + conv_state_indices=cur_buffer_idx, + ) + + # ========== Recurrent processing ========== + query_i, key_i, value_i = self._rearrange_mixed_qkv(qkv_work_buffer, decode=True) + g_i = g[step_idx :: self.mtp_size].unsqueeze(1) + beta_i = beta[step_idx :: self.mtp_size].unsqueeze(1) + + core_attn_out_i, _ = fused_recurrent_gated_delta_rule( + q=query_i, + k=key_i, + v=value_i, + g=g_i, + beta=beta_i, + initial_state=ssm_states, + inplace_final_state=True, + ssm_state_indices=cur_buffer_idx, + use_qk_l2norm_in_kernel=True, + ) + + # Direct slice assignment (no .copy_() needed) + core_attn_out[step_idx :: self.mtp_size] = core_attn_out_i + + # ========== State propagation to next step ========== + if step_idx < self.mtp_step: + next_buffer_idx = 
infer_state.mtp_buffer_idx_list[step_idx + 1] + if conv_states.is_contiguous() and ssm_states.is_contiguous(): + copy_states_fused(conv_states, ssm_states, cur_buffer_idx, next_buffer_idx) + else: + copy_conv_states(conv_states, cur_buffer_idx, next_buffer_idx) + copy_ssm_states(ssm_states, cur_buffer_idx, next_buffer_idx) + + return core_attn_out + + def gdn_forward( + self, + input: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + is_prefill: bool, + ): + assert isinstance(infer_state.mem_manager, Qwen3NextHybridMemManager) + + # Common preprocessing + input = input.view(-1, self.embed_dim_) + conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_) + + if not is_prefill: + # Decode: pre-allocate GEMM output to avoid cache tensor manager overhead + in_proj_out_dim = self.tp_qkvz_dim + self.tp_ba_dim + in_proj_out = self._get_decode_buffer( + "in_proj_out", + (self._graph_max_batch_size * self.mtp_size, in_proj_out_dim), + input.dtype, + input.device, + )[: input.shape[0]] + mixed_qkvzba = layer_weight.linear_in_proj.mm(input, out=in_proj_out) + else: + mixed_qkvzba = layer_weight.linear_in_proj.mm(input) + # mixed_qkv is now returned pre-concatenated (no torch.cat needed) + mixed_qkv, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba, is_decode=not is_prefill) + + # Dispatch to appropriate kernel + if is_prefill: + # Prefill: compute g/beta upfront (chunk kernel doesn't support fused gating) + g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) + core_attn_out = self._gdn_prefill_kernel( + mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight + ) + elif self.mtp_step == 0: + # Decode (non-MTP): fuse gating into recurrent kernel to save 2 kernel launches + core_attn_out = self._gdn_decode_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight) + else: + # Decode (MTP): compute g/beta upfront (multiple recurrent calls per step) + g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) + core_attn_out = self._gdn_decode_mtp_kernel( + mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight + ) + + # Common postprocessing + num_tokens = z.shape[0] # batch (decode) or total_tokens (prefill/MTP) + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + if not is_prefill: + # Decode: use pre-allocated buffer for norm output to avoid alloc_tensor + max_decode_tokens = self._graph_max_batch_size * self.mtp_size + flat_size = max_decode_tokens * self.tp_num_v_heads + norm_out = self._get_decode_buffer( + "gdn_norm_out", + (flat_size, self.head_v_dim), + core_attn_out.dtype, + core_attn_out.device, + )[: core_attn_out.shape[0]] + else: + norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device) + gated_rmsnorm_forward( + core_attn_out, + layer_weight.linear_norm.weight, + None, # RMSNormWeight has no bias + self.eps_, + z, + out=norm_out, + ) + # Merge head and value dims in a single view: (num_tokens * HV, V) → (num_tokens, HV * V) + core_attn_out = norm_out.view(num_tokens, -1) + + output = layer_weight.linear_out_proj.mm(core_attn_out) + # Note: all_reduce is handled by context_forward/token_forward callers + return output diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py 
b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..d4e16555d9 --- /dev/null +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -0,0 +1,313 @@ +import torch +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + RMSNormWeight, + TpParameterWeight, + KVROWNMMWeight, +) + + +class Qwen3NextFullAttentionTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _init_qkv(self): + # Override parent's QKVROWNMMWeight which requires kv_head_num % tp == 0. + # Qwen3-Next has very few KV heads (e.g., 2) so we use separate q + kv weights. + # KVROWNMMWeight handles the kv_head_num < tp_world_size case via repeating. + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + # Define o_gate weight name here (used by _split_q_with_gate during load) + self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" + # Fused Q + gate projection: single GEMM outputs [q, gate] concatenated + self.q_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim, q_out_dim], + weight_names=[self._q_weight_name, self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("q_proj"), + ) + self.kv_proj = KVROWNMMWeight( + in_dim=in_dim, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("kv_proj"), + ) + + def _init_weight(self): + super()._init_weight() + # Additional architecture (o_gate is now fused into q_gate_proj in _init_qkv) + self._init_gate_shared_expert_weight() + return + + def load_hf_weights(self, weights): + self._split_q_with_gate(weights) + super().load_hf_weights(weights) + + def _split_q_with_gate(self, weights): + if self._q_weight_name in weights: + weight = weights[self._q_weight_name] + num_heads = self.q_head_num_ + weight = weight.view(num_heads * 2, self.head_dim, -1) + _q_proj = weight[0::2].reshape(-1, weight.shape[-1]) + _gate_proj = weight[1::2].reshape(-1, weight.shape[-1]) + weights[self._q_weight_name] = _q_proj + weights[self._o_gate_weight_name] = _gate_proj + + def _init_gate_shared_expert_weight(self): + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + hidden_size = self.network_config_["hidden_size"] + shared_inter = self.network_config_["shared_expert_intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[shared_inter, shared_inter], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=shared_inter, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + self.shared_expert_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + 
quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + + +class Qwen3NextGatedDeltaNetTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + self.is_moe = ( + network_config["num_experts"] > 0 + and layer_num not in network_config["mlp_only_layers"] + and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 + ) + super().__init__(layer_num, data_type, network_config, quant_cfg) + + def _parse_config(self): + super()._parse_config() + self.linear_num_v_heads = self.network_config_["linear_num_value_heads"] + self.linear_num_k_heads = self.network_config_["linear_num_key_heads"] + self.linear_k_head_dim = self.network_config_["linear_key_head_dim"] + self.linear_v_head_dim = self.network_config_["linear_value_head_dim"] + + def _init_weight(self): + hidden_size = self.network_config_["hidden_size"] + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self._init_gdn_weight() + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + if self.is_moe: + self._init_moe() + else: + self._init_ffn() + self._init_gate_shared_expert_weight() + + def _init_gdn_weight(self): + prefix = f"model.layers.{self.layer_num_}.linear_attn" + hidden_size = self.network_config_["hidden_size"] + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + conv1d_channels = qk_dim + qk_dim + v_dim # q + k + v concatenated + kernel_size = self.network_config_.get("linear_conv_kernel_dim", 4) + + # Conv1d weight: after _preprocess_weight, shape is [channels, kernel_size]. + # ROWMMWeight row-slices out_dims (rows), matching TP split of channels dim. + # causal_conv1d_fn expects weight shape (dim, width) = (channels_per_tp, kernel_size). 
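+        # Illustrative sketch (hypothetical sizes, not read from any config):
+        # with qk_dim = 1024, v_dim = 2048, conv1d_channels = 1024 + 1024 + 2048
+        # = 4096 and kernel_size = 4, a rank in tp_world_size = 4 row-slices a
+        # contiguous [1024, 4] block; _parse_linear_conv1d has already regrouped
+        # the rows so that block holds exactly that rank's q, k and v channels.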
+ self.linear_conv1d = ROWMMWeight( + in_dim=kernel_size, + out_dims=[conv1d_channels], + weight_names=f"{prefix}.conv1d.weight", + data_type=self.data_type_, + quant_method=None, + ) + + # in_proj_qkvz: q(qk_dim) + k(qk_dim) + v(v_dim) + z(v_dim) + # in_proj_ba: beta(num_v_heads) + alpha(num_v_heads) — per-head scalars + qkvz_dim = qk_dim + qk_dim + v_dim + v_dim + ba_dim = self.linear_num_v_heads + self.linear_num_v_heads + self.linear_in_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[qkvz_dim, ba_dim], + weight_names=[f"{prefix}.in_proj_qkvz.weight", f"{prefix}.in_proj_ba.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("in_proj_weight"), + ) + + self.linear_out_proj = COLMMWeight( + in_dim=v_dim, + out_dims=[hidden_size], + weight_names=f"{prefix}.out_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("out_proj_weight"), + ) + + split_n_embed = self.linear_num_v_heads // self.tp_world_size_ + self.linear_dt_bias = TpParameterWeight( + weight_name=f"{prefix}.dt_bias", + data_type=torch.float32, + split_n_embed=split_n_embed, + bias_name=None, + weight_shape=(self.linear_num_v_heads,), # Full shape before TP split + bias_shape=None, + ) + + self.linear_A_log = TpParameterWeight( + weight_name=f"{prefix}.A_log", + data_type=torch.float32, + split_n_embed=split_n_embed, + bias_name=None, + weight_shape=(self.linear_num_v_heads,), # Full shape before TP split + bias_shape=None, + ) + + # Norm is applied per-head across head_dim, not across all heads + linear_norm_dim = self.linear_v_head_dim + self.linear_norm = RMSNormWeight( + dim=linear_norm_dim, + weight_name=f"{prefix}.norm.weight", + data_type=self.data_type_, + ) + + def load_hf_weights(self, weights): + self._preprocess_weight(weights) + return super().load_hf_weights(weights) + + def _preprocess_weight(self, weights): + linear_conv1d_weight_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.weight" + linear_conv1d_bias_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.bias" + if linear_conv1d_weight_name in weights: + # squeeze [channels, 1, kernel] -> [channels, kernel], then rearrange for TP + # Result shape: [channels, kernel_size] — matches causal_conv1d_fn's (dim, width) + weights[linear_conv1d_weight_name] = self._parse_linear_conv1d( + weights[linear_conv1d_weight_name].squeeze(1) + ) + if linear_conv1d_bias_name in weights: + weights[linear_conv1d_bias_name] = self._parse_linear_conv1d(weights[linear_conv1d_bias_name]) + self._rearrange_gdn_in_proj_weights(weights) + + def _rearrange_gdn_in_proj_weights(self, weights): + """Rearrange in_proj_qkvz and in_proj_ba weight rows from interleaved per-k-head layout + to TP-aware grouped layout so that after ROWMMWeight's row-slicing, each rank's + MM output is already [q_chunk, k_chunk, v_chunk, z_chunk, b_chunk, a_chunk]. + + This eliminates the expensive split+reshape+cat in _fix_query_key_value_ba_ordering + at inference time, replacing it with simple slicing. + + The key challenge is that ROWMMWeight slices each weight as a contiguous row chunk + (rows [start:end]). So we arrange the rows such that each TP chunk contains + the grouped layout for that rank: + 1. Deinterleave from per-k-head groups into per-component tensors + 2. Chunk each component by TP + 3. Reassemble as [q_tp0, k_tp0, v_tp0, z_tp0, q_tp1, k_tp1, ...] so row-slicing + gives each rank [q_chunk, k_chunk, v_chunk, z_chunk]. + Same pattern as _parse_linear_conv1d uses for conv1d weights. 
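+
+        Illustrative example (hypothetical sizes, not the real model config):
+        with num_k_heads = 4, num_v_per_k = 1 and tp = 2, the HF row order
+        [q0 k0 v0 z0 | q1 k1 v1 z1 | q2 k2 v2 z2 | q3 k3 v3 z3] becomes
+        [q0 q1 k0 k1 v0 v1 z0 z1 | q2 q3 k2 k3 v2 v3 z2 z3], so each rank's
+        contiguous row slice is already ordered [q, k, v, z].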
+ """ + num_k = self.linear_num_k_heads + k_dim = self.linear_k_head_dim + v_dim = self.linear_v_head_dim + num_v_per_k = self.linear_num_v_heads // num_k + tp = self.tp_world_size_ + + # Rearrange in_proj_qkvz + qkvz_name = f"model.layers.{self.layer_num_}.linear_attn.in_proj_qkvz.weight" + if qkvz_name in weights: + w = weights[qkvz_name] + hidden = w.shape[-1] + # Each k-head group: q(k_dim) + k(k_dim) + v(num_v_per_k * v_dim) + z(num_v_per_k * v_dim) rows + group_size = k_dim + k_dim + num_v_per_k * v_dim + num_v_per_k * v_dim + w = w.view(num_k, group_size, hidden) + v_block = num_v_per_k * v_dim + all_q = w[:, :k_dim, :].reshape(-1, hidden) # [total_q_dim, H] + all_k = w[:, k_dim : 2 * k_dim, :].reshape(-1, hidden) # [total_k_dim, H] + all_v = w[:, 2 * k_dim : 2 * k_dim + v_block, :].reshape(-1, hidden) # [total_v_dim, H] + all_z = w[:, 2 * k_dim + v_block :, :].reshape(-1, hidden) # [total_v_dim, H] + # Chunk each component by TP, interleave so row-slicing gives grouped layout per rank + q_chunks = all_q.chunk(tp, dim=0) + k_chunks = all_k.chunk(tp, dim=0) + v_chunks = all_v.chunk(tp, dim=0) + z_chunks = all_z.chunk(tp, dim=0) + weights[qkvz_name] = torch.cat( + [torch.cat([q_chunks[i], k_chunks[i], v_chunks[i], z_chunks[i]], dim=0) for i in range(tp)], + dim=0, + ) + + # Rearrange in_proj_ba + ba_name = f"model.layers.{self.layer_num_}.linear_attn.in_proj_ba.weight" + if ba_name in weights: + w = weights[ba_name] + hidden = w.shape[-1] + group_size = 2 * num_v_per_k + w = w.view(num_k, group_size, hidden) + all_b = w[:, :num_v_per_k, :].reshape(-1, hidden) # [total_num_v, H] + all_a = w[:, num_v_per_k:, :].reshape(-1, hidden) # [total_num_v, H] + b_chunks = all_b.chunk(tp, dim=0) + a_chunks = all_a.chunk(tp, dim=0) + weights[ba_name] = torch.cat( + [torch.cat([b_chunks[i], a_chunks[i]], dim=0) for i in range(tp)], + dim=0, + ) + + def _parse_linear_conv1d(self, weight): + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + q_bias, k_bias, v_bias = torch.split(weight, [qk_dim, qk_dim, v_dim], dim=0) + q_splits = q_bias.chunk(self.tp_world_size_, dim=0) + k_splits = k_bias.chunk(self.tp_world_size_, dim=0) + v_splits = v_bias.chunk(self.tp_world_size_, dim=0) + new_weight = torch.cat( + [torch.cat([q_splits[i], k_splits[i], v_splits[i]], dim=0) for i in range(self.tp_world_size_)], dim=0 + ) + return new_weight + + def _init_gate_shared_expert_weight(self): + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + hidden_size = self.network_config_["hidden_size"] + shared_inter = self.network_config_["shared_expert_intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[shared_inter, shared_inter], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=shared_inter, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + self.shared_expert_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py new 
file mode 100644 index 0000000000..7ac7149a06 --- /dev/null +++ b/lightllm/models/qwen3next/mem_manager.py @@ -0,0 +1,72 @@ +import torch +from typing import Tuple +from lightllm.utils.log_utils import init_logger +from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager +from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager + +logger = init_logger(__name__) + + +class Qwen3NextHybridMemManager(MemoryManager): + def __init__( + self, + full_attn_cache_size, + linear_attn_cache_size, + dtype, + num_kv_heads, + head_dim, + layer_num, + mtp_layer_num, + full_attention_interval: int, + conv_state_dtype: torch.dtype, + conv_state_shape: Tuple[int, ...], + ssm_state_dtype: torch.dtype, + ssm_state_shape: Tuple[int, ...], + max_req_num: int, + always_copy=False, + mem_fraction=0.9, + ): + + self.full_attention_interval = full_attention_interval + assert layer_num % full_attention_interval == 0 + self.layer_num = layer_num + self.mtp_layer_num = mtp_layer_num + self.full_attn_layer_num = layer_num // full_attention_interval + self.linear_attn_layer_num = layer_num - self.full_attn_layer_num + + self.mamba_cache_mem_manager = MambaCacheManager( + linear_attn_cache_size, + self.linear_attn_layer_num, + conv_state_dtype, + conv_state_shape, + ssm_state_dtype, + ssm_state_shape, + ) + + super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction) + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + # KV buffer layout: [None, None, None, kv_cache, None, None, None, kv_cache, ..., + # None, kv_cache, mtp_kv_cache, mtp_kv_cache] + # Only full attention layers and MTP layers have KV cache. + self.kv_buffer = [None for _ in range(self.layer_num)] + for layer_id in range(self.full_attn_layer_num): + self.kv_buffer[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty( + (size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda" + ) + for _ in range(self.mtp_layer_num): + self.kv_buffer.append(torch.empty((size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda")) + + def free_all(self): + super().free_all() + self.mamba_cache_mem_manager.free_all() + return + + def get_cell_size(self): + # Only full attention layers and MTP layers have KV cache + kv_cache_layer_num = self.full_attn_layer_num + self.mtp_layer_num + return 2 * self.head_num * self.head_dim * kv_cache_layer_num * torch._utils._element_size(self.dtype) + + def get_mamba_cache(self, layer_idx: int): + layer_idx_in_linear = layer_idx - (layer_idx // self.full_attention_interval) + return self.mamba_cache_mem_manager.get_mamba_cache(layer_idx_in_linear) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py new file mode 100644 index 0000000000..1234a659ed --- /dev/null +++ b/lightllm/models/qwen3next/model.py @@ -0,0 +1,157 @@ +import torch +from typing import Optional +import triton +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextFullAttentionTransformerLayerWeight, + Qwen3NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( + Qwen3NextFullAttentionTransformerLayerInfer, + Qwen3NextGatedDeltaNetTransformerLayerInfer, +) +from lightllm.models.qwen3next.layer_infer.post_layer_infer import Qwen3NextPostLayerInfer +from lightllm.models.qwen3next.infer_struct import 
Qwen3NextInferStateInfo +from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager +from lightllm.server.core.objs.start_args_type import StartArgs +from lightllm.common.req_manager import ReqManagerForMamba +from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights + +logger = init_logger(__name__) + + +@ModelRegistry("qwen3_next") +class Qwen3NextTpPartModel(Qwen3MOEModel): + + post_layer_infer_class = Qwen3NextPostLayerInfer + infer_state_class = Qwen3NextInferStateInfo + + is_hybrid_attention = True # Indicates model uses hybrid (full + linear) attention + use_buffer_manager = True # Indicates model needs per-request buffer management for linear attention states + + @classmethod + def get_radix_cache_class(cls): + from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache + + return HybridRadixCache + + def __init__(self, kvargs) -> None: + self.mem_manager: Qwen3NextHybridMemManager = None + + def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch.Tensor: + return torch.empty(size, device="cuda", dtype=torch.int8) + + # Set Triton allocator for TMA descriptors + # This is required for kernels in qwen3next/triton_kernel/fla/ops/solve_tril.py + triton.set_allocator(_triton_allocator) + logger.info("Triton allocator set for Qwen3Next model") + super().__init__(kvargs) + + def autotune_layers(self): + return self.config["full_attention_interval"] + + def _init_config(self): + super()._init_config() + self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) + + def _init_custom(self): + super()._init_custom() + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + + def _init_mem_manager(self): + assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + + start_args: StartArgs = get_env_start_args() + mamba_cache_size = start_args.mamba_cache_size + if mamba_cache_size is not None: + assert ( + mamba_cache_size >= start_args.running_max_req_size + ), "mamba_cache_size must be greater than running_max_req_size" + + self.num_linear_k_heads = self.config["linear_num_key_heads"] + self.num_linear_v_heads = self.config["linear_num_value_heads"] + self.head_linear_k_dim = self.config["linear_key_head_dim"] + self.head_linear_v_dim = self.config["linear_value_head_dim"] + + conv_kernel_size = self.config["linear_conv_kernel_dim"] + conv_dim = ( + self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads + ) + + ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + if start_args.mamba_ssm_data_type not in ssm_dtype_dict: + raise ValueError( + f"Invalid mamba_ssm_data_type: {start_args.mamba_ssm_data_type}." 
+ f" Must be one of {list(ssm_dtype_dict.keys())}" + ) + + self.mem_manager = Qwen3NextHybridMemManager( + full_attn_cache_size=self.max_total_token_num, + linear_attn_cache_size=mamba_cache_size, + dtype=self.data_type, + num_kv_heads=self.num_kv_heads, + head_dim=self.config["head_dim"], + layer_num=self.config["n_layer"], + mtp_layer_num=start_args.mtp_step, + full_attention_interval=self.config["full_attention_interval"], + conv_state_dtype=self.data_type, + conv_state_shape=(conv_dim // self.tp_world_size_, conv_kernel_size - 1), + ssm_state_dtype=ssm_dtype_dict[start_args.mamba_ssm_data_type], + ssm_state_shape=( + self.num_linear_v_heads // self.tp_world_size_, + self.head_linear_k_dim, + self.head_linear_v_dim, + ), + max_req_num=self.max_req_num, + mem_fraction=self.mem_fraction, + ) + + def _init_req_manager(self): + create_max_seq_len = 0 + + if self.batch_max_tokens is not None: + create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens) + if self.max_seq_length is not None: + create_max_seq_len = max(create_max_seq_len, self.max_seq_length) + + self.req_manager = ReqManagerForMamba(self.max_req_num, create_max_seq_len, self.mem_manager) + + def _init_weights(self): + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) + num_full_attention_layers = self.config["full_attention_interval"] + self.trans_layers_weight = [ + ( + Qwen3NextFullAttentionTransformerLayerWeight( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + if (i + 1) % num_full_attention_layers == 0 + else Qwen3NextGatedDeltaNetTransformerLayerWeight( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + ) + for i in range(self.config["n_layer"]) + ] + + def _init_infer_layer(self): + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + num_full_attention_layers = self.config["full_attention_interval"] + + self.layers_infer = [ + ( + Qwen3NextFullAttentionTransformerLayerInfer(i, network_config=self.config) + if (i + 1) % num_full_attention_layers == 0 + else Qwen3NextGatedDeltaNetTransformerLayerInfer(i, network_config=self.config) + ) + for i in range(self.config["n_layer"]) + ] diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py new file mode 100644 index 0000000000..c6d099a2d8 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py @@ -0,0 +1,122 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/causal_conv1d.py + +from typing import Optional + +import torch + +from sgl_kernel import causal_conv1d_fwd +from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = -1, + **kwargs, +): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. 
prepended by 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + causal_conv1d_fwd( + x, + weight, + bias, + conv_states, + query_start_loc, + cache_indices, + has_initial_state, + activation in ["silu", "swish"], + pad_slot_id, + ) + return x + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = -1, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError(f"activation must be None, silu, or swish, actual: {activation}") + activation_val = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + causal_conv1d_update_kernel( + x, + conv_state, + weight, + bias, + activation_val, + cache_seqlens, + conv_state_indices, + pad_slot_id, + ) + if unsqueeze: + x = x.squeeze(-1) + return x diff --git a/lightllm/models/qwen3next/triton_kernel/fla/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py new file mode 100644 index 0000000000..2bde70bb99 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# Adapted from +# https://github.com/vllm-project/vllm diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py new file mode 100644 index 0000000000..cd3b0962a3 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from .chunk import chunk_gated_delta_rule +from .fused_recurrent import fused_recurrent_gated_delta_rule + +__all__ = [ + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", +] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py new file mode 100644 index 0000000000..7b3067bbfb --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch +from einops import rearrange + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .l2norm import l2norm_fwd +from .solve_tril import solve_tril +from .utils import SUPPRESS_LEVEL, input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: torch.LongTensor | None = None, +): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. 
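+    # Informal sketch of what the chunkwise pass computes: per head, the gated
+    # delta rule maintains a [K, V] state S with
+    #     S_t = exp(g_t) * (I - beta_t * k_t k_t^T) @ S_{t-1} + beta_t * k_t v_t^T
+    # The WY representation folds the 64 rank-1 updates of a chunk into two
+    # matrices (w, u): A below is the beta-scaled, gate-weighted K K^T of the
+    # chunk, whose unit-lower-triangular system solve_tril inverts so that
+    # recompute_w_u_fwd can emit w and u, reducing the state/output passes to
+    # a handful of GEMMs per chunk.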
+ A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, chunk_size=64, output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=64, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=64, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @input_guard + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: torch.LongTensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + ): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: torch.LongTensor | None = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
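+
+    Note:
+        This chunkwise op is the prefill-time counterpart of
+        `fused_recurrent_gated_delta_rule` (both are exported from this
+        package): the two evaluate the same gated delta rule, but here tokens
+        are grouped into fixed chunks of 64 so that intra-chunk work becomes
+        dense GEMMs instead of a per-token recurrent scan.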
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2, + ) + q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, + ) + if head_first: + o = rearrange(o, "b t h ... -> b h t ...") + return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py new file mode 100644 index 0000000000..97933b2ac2 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py @@ -0,0 +1,324 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices, prepare_chunk_offsets +from .op import exp, safe_exp +from lightllm.common.triton_utils.autotuner import autotune + +NUM_WARPS = [2, 4, 8, 16] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += ((boh * H + i_h) * K * V).to(tl.int64) + v += ((bos * H + i_h) * V).to(tl.int64) + k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + if SAVE_NEW_VALUE: + v_new += ((bos * H + i_h) * V).to(tl.int64) + stride_v = H * V + stride_h = H * K * V + stride_k = Hg * K + stride_w = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, 
b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v + + if SAVE_NEW_VALUE: + p_v = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_v = b_v * safe_exp(b_g_last - b_g)[:, None] + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k1, + mask=(o_k1 < K), + other=0.0, + ) + b_h1 *= exp(b_gk_last1)[:, None] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k2, + mask=(o_k2 < K), + other=0.0, + ) + b_h2 *= exp(b_gk_last2)[:, None] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k3, + mask=(o_k3 < K), + other=0.0, + ) + b_h3 *= exp(b_gk_last3)[:, None] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k4, + mask=(o_k4 < K), + other=0.0, + ) + b_h4 *= exp(b_gk_last4)[:, None] + b_v = b_v.to(k.dtype.element_ty) + + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v) + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = 
tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_delta_h_configs(): + return [ + {"BV": BV, "num_warps": num_warps, "num_stages": num_stages} + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ] + + +def _get_chunk_delta_h_static_key(k, u, chunk_size): + B, T, Hg, K = k.shape + V = u.shape[-1] + H = u.shape[-2] + return {"H": H, "K": K, "V": V, "BT": chunk_size} + + +def _get_chunk_delta_h_run_key(k, u): + # Return batch * heads as run key + return k.shape[0] * k.shape[2] + + +@autotune( + kernel_name="chunk_gated_delta_rule_fwd_h", + configs_gen_func=_get_chunk_delta_h_configs, + static_key_func=_get_chunk_delta_h_static_key, + run_key_func=_get_chunk_delta_h_run_key, +) +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: torch.LongTensor | None = None, + run_config=None, +) -> tuple[torch.Tensor, torch.Tensor]: + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, + len(chunk_indices), + prepare_chunk_offsets(cu_seqlens, BT), + ) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + # Extract config parameters + if run_config is None: + run_config = {"BV": 64, "num_warps": 2, "num_stages": 2} + + BV = run_config.get("BV", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + grid = (triton.cdiv(V, BV), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + gk=gk, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BV=BV, + num_warps=num_warps, + num_stages=num_stages, + ) + return h, v_new, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py new file mode 100644 index 0000000000..fc49763ecd --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import exp, safe_exp +from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper +from lightllm.common.triton_utils.autotuner import autotune + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K * V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_o_configs(): + return [ + {"BK": BK, "BV": BV, "num_warps": num_warps, "num_stages": num_stages} + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ] + + +def _get_chunk_o_static_key(q, v, chunk_size): + B, T, Hg, K = q.shape + V = 
v.shape[-1] + H = v.shape[-2] + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) + return {"H": H, "K": K, "V": V, "BT": BT} + + +def _get_chunk_o_run_key(q, v): + # Return batch * heads as run key + return q.shape[0] * q.shape[2] + + +@autotune( + kernel_name="chunk_fwd_o", + configs_gen_func=_get_chunk_o_configs, + static_key_func=_get_chunk_o_static_key, + run_key_func=_get_chunk_o_run_key, +) +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor | None = None, # cumsum of log decay + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + run_config=None, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = torch.empty_like(v) + + # Extract config parameters + if run_config is None: + run_config = {"BK": 64, "BV": 64, "num_warps": 2, "num_stages": 2} + + BK = run_config.get("BK", 64) + BV = run_config.get("BV", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + grid = (triton.cdiv(V, BV), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + num_warps=num_warps, + num_stages=num_stages, + ) + return o diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py new file mode 100644 index 0000000000..60a594c078 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import safe_exp +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A = b_A * safe_exp(b_g_diff) + + b_A *= b_beta[:, None] + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_scaled_dot_kkt_configs(): + return [ + {"BK": BK, "num_warps": num_warps, "num_stages": num_stages} + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ] + + +def _get_chunk_scaled_dot_kkt_static_key(k, beta, chunk_size=64, cu_seqlens=None): + B, T, Hg, K = k.shape + H = beta.shape[-1] + IS_VARLEN = cu_seqlens is not None + return {"H": H, "K": K, "BT": chunk_size, "IS_VARLEN": IS_VARLEN} + + +def _get_chunk_scaled_dot_kkt_run_key(k, beta): + # Return batch * heads as run key + return k.shape[0] * k.shape[2] + + +@autotune( + kernel_name="chunk_scaled_dot_kkt_fwd", + configs_gen_func=_get_chunk_scaled_dot_kkt_configs, + static_key_func=_get_chunk_scaled_dot_kkt_static_key, + run_key_func=_get_chunk_scaled_dot_kkt_run_key, +) +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + g: torch.Tensor | None = None, + beta: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, + run_config=None, +) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. 
+ cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. + B, T, Hg, K = k.shape + H = beta.shape[-1] + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + # Extract config parameters + if run_config is None: + run_config = {"BK": 64, "num_warps": 2, "num_stages": 2} + + BK = run_config.get("BK", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + g=g, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + BK=BK, + num_warps=num_warps, + num_stages=num_stages, + ) + return A diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py new file mode 100644 index 0000000000..6331e1602d --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .utils import check_shared_mem, input_guard +from lightllm.common.triton_utils.autotuner import autotune + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) + + if HEAD_FIRST: + p_s = tl.make_block_ptr( + s + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + else: + p_s = tl.make_block_ptr( + s + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_cumsum_scalar_configs(): + return [{"num_warps": num_warps} for 
num_warps in [1, 2, 4, 8]] + + +def _get_cumsum_scalar_static_key(g, chunk_size, reverse, cu_seqlens, head_first): + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + IS_VARLEN = cu_seqlens is not None + return {"B": B, "H": H, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} + + +def _get_cumsum_scalar_run_key(g): + # Return total number of elements as run key + return g.shape[0] * g.shape[1] + + +@autotune( + kernel_name="chunk_local_cumsum_scalar", + configs_gen_func=_get_cumsum_scalar_configs, + static_key_func=_get_cumsum_scalar_static_key, + run_key_func=_get_cumsum_scalar_run_key, +) +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + run_config=None, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + # Extract config parameters + if run_config is None: + run_config = {"num_warps": 2} + + num_warps = run_config.get("num_warps", 2) + + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + num_warps=num_warps, + ) + return g + + +def _get_cumsum_vector_configs(): + return [{"BS": BS, "num_warps": num_warps} for BS in BS_LIST for num_warps in [2, 4, 8]] + + +def _get_cumsum_vector_static_key(g, chunk_size, reverse, cu_seqlens, head_first): + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + IS_VARLEN = cu_seqlens is not None + return {"B": B, "H": H, "S": S, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} + + +def _get_cumsum_vector_run_key(g): + # Return batch * heads as run key + return g.shape[0] * g.shape[2] if len(g.shape) == 4 else g.shape[0] + + +@autotune( + kernel_name="chunk_local_cumsum_vector", + configs_gen_func=_get_cumsum_vector_configs, + static_key_func=_get_cumsum_vector_static_key, + run_key_func=_get_cumsum_vector_run_key, +) +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + run_config=None, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + # Extract config parameters + if run_config is None: + run_config = {"BS": 32, "num_warps": 2} + + BS = run_config.get("BS", 32) + num_warps = run_config.get("num_warps", 2) + + grid = (triton.cdiv(S, BS), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + 
H=H, + S=S, + BT=BT, + BS=BS, + HEAD_FIRST=head_first, + REVERSE=reverse, + num_warps=num_warps, + ) + return g + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + **kwargs, +) -> torch.Tensor: + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + else: + raise ValueError( + f"Unsupported input shape {g.shape}. " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise" + ) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py new file mode 100644 index 0000000000..22a93a2c99 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -0,0 +1,492 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .op import exp + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + "HAS_SEPARATE_WRITE_INDICES": lambda args: args["ssm_state_write_indices"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + ssm_state_write_indices, # NEW: separate write indices for state propagation optimization + num_accepted_tokens, + # Fused gating parameters (only used when FUSE_GATING=True) + A_log, # [HV] per-head log decay + dt_bias, # [HV] per-head dt bias + a_raw, # [B*T, HV] raw alpha values (before softplus) + b_raw, # [B*T, HV] raw beta values (before sigmoid) + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + stride_write_indices_seq: tl.constexpr, # NEW: stride for write indices + stride_write_indices_tok: tl.constexpr, # NEW: stride for write indices + SOFTPLUS_BETA: tl.constexpr, # softplus beta parameter (default 1.0) + SOFTPLUS_THRESHOLD: tl.constexpr, # softplus threshold (default 20.0) + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + 
IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + IS_KDA: tl.constexpr, + HAS_SEPARATE_WRITE_INDICES: tl.constexpr, # NEW: whether to use separate write indices + FUSE_GATING: tl.constexpr, # whether to compute g/beta inline from raw values +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if FUSE_GATING: + # Fused gating: load per-head constants once, compute g/beta inline per token + b_A_log = tl.load(A_log + i_hv).to(tl.float32) + b_dt_bias = tl.load(dt_bias + i_hv).to(tl.float32) + p_a_raw = a_raw + bos * HV + i_hv + p_b_raw = b_raw + bos * HV + i_hv + else: + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + + if not IS_KDA: + p_g = g + bos * HV + i_hv + else: + p_gk = g + (bos * HV + i_hv) * K + o_k + + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = ( + h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token + ) + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + if FUSE_GATING: + # Compute g = -exp(A_log) * softplus(a_raw + dt_bias) inline + b_a = tl.load(p_a_raw).to(tl.float32) + x = b_a + b_dt_bias + softplus_x = tl.where( + SOFTPLUS_BETA * x <= SOFTPLUS_THRESHOLD, + (1.0 / SOFTPLUS_BETA) * tl.log(1.0 + tl.exp(SOFTPLUS_BETA * x)), + x, + ) + b_g = -tl.exp(b_A_log) * softplus_x + b_h *= exp(b_g) + # Compute beta = sigmoid(b_raw) inline + b_b = tl.load(p_b_raw).to(tl.float32) + b_beta = tl.sigmoid(b_b) + else: + if not IS_KDA: + b_g = tl.load(p_g).to(tl.float32) + b_h *= exp(b_g) + else: + b_gk = tl.load(p_gk).to(tl.float32) + b_h *= exp(b_gk[:, None]) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + # Use separate write indices if provided (for state propagation optimization) + # Otherwise fall back to read indices + if HAS_SEPARATE_WRITE_INDICES: 
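+                # write_idx selects the destination row in `ht` for this token's
+                # state; it may differ from the row the state was read from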
+ write_idx = tl.load(ssm_state_write_indices + i_n * stride_write_indices_seq + i_t).to(tl.int64) + else: + write_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) + p_ht = ht + write_idx * stride_final_state_token + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + if FUSE_GATING: + p_a_raw += HV + p_b_raw += HV + else: + if not IS_KDA: + p_g += HV + else: + p_gk += HV * K + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + ssm_state_write_indices: torch.Tensor | None = None, # NEW: separate write indices + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + # Fused gating parameters + A_log: torch.Tensor | None = None, + dt_bias: torch.Tensor | None = None, + a_raw: torch.Tensor | None = None, + b_raw: torch.Tensor | None = None, + out: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK = triton.next_power_of_2(K) + if T == 1: + # Decode path: use larger BV to reduce kernel instances (4 blocks instead of 16) + # and more warps for better SM utilization at T=1 where there's no pipelining benefit + BV = min(triton.next_power_of_2(V), 32) + num_warps = 4 + num_stages = 1 + else: + # Prefill path: small BV for better pipelining across sequence length + BV = min(triton.next_power_of_2(V), 8) + num_warps = 1 + num_stages = 3 + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + + fuse_gating = A_log is not None + + if out is not None: + o = out.unsqueeze(0) if out.ndim == v.ndim else out + else: + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + # Strides for read indices + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + # Strides for write indices (if provided) + if ssm_state_write_indices is None: + stride_write_indices_seq, stride_write_indices_tok = 1, 1 + elif ssm_state_write_indices.ndim == 1: + stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1 + else: + stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + ssm_state_write_indices=ssm_state_write_indices, + num_accepted_tokens=num_accepted_tokens, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw, + b_raw=b_raw, + scale=scale, + N=N, + T=T, + B=B, 
+        H=H,
+        HV=HV,
+        K=K,
+        V=V,
+        BK=BK,
+        BV=BV,
+        stride_init_state_token=stride_init_state_token,
+        stride_final_state_token=stride_final_state_token,
+        stride_indices_seq=stride_indices_seq,
+        stride_indices_tok=stride_indices_tok,
+        stride_write_indices_seq=stride_write_indices_seq,
+        stride_write_indices_tok=stride_write_indices_tok,
+        SOFTPLUS_BETA=1.0,
+        SOFTPLUS_THRESHOLD=20.0,
+        IS_BETA_HEADWISE=False if fuse_gating else (beta.ndim == v.ndim),
+        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
+        INPLACE_FINAL_STATE=inplace_final_state,
+        IS_KDA=False,
+        FUSE_GATING=fuse_gating,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    o = o.squeeze(0)
+    return o, final_state
+
+
+class FusedRecurrentFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+        scale: float,
+        initial_state: torch.Tensor,
+        inplace_final_state: bool = True,
+        cu_seqlens: torch.LongTensor | None = None,
+        ssm_state_indices: torch.Tensor | None = None,
+        ssm_state_write_indices: torch.Tensor | None = None,
+        num_accepted_tokens: torch.Tensor | None = None,
+        use_qk_l2norm_in_kernel: bool = False,
+        A_log: torch.Tensor | None = None,
+        dt_bias: torch.Tensor | None = None,
+        a_raw: torch.Tensor | None = None,
+        b_raw: torch.Tensor | None = None,
+        out: torch.Tensor | None = None,
+    ):
+        o, final_state = fused_recurrent_gated_delta_rule_fwd(
+            q=q.contiguous(),
+            k=k.contiguous(),
+            v=v.contiguous(),
+            g=g.contiguous() if g is not None else None,
+            beta=beta.contiguous() if beta is not None else None,
+            scale=scale,
+            initial_state=initial_state,
+            inplace_final_state=inplace_final_state,
+            cu_seqlens=cu_seqlens,
+            ssm_state_indices=ssm_state_indices,
+            ssm_state_write_indices=ssm_state_write_indices,
+            num_accepted_tokens=num_accepted_tokens,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+            A_log=A_log,
+            dt_bias=dt_bias,
+            a_raw=a_raw.contiguous() if a_raw is not None else None,
+            b_raw=b_raw.contiguous() if b_raw is not None else None,
+            out=out,
+        )
+
+        return o, final_state
+
+
+def fused_recurrent_gated_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor = None,
+    beta: torch.Tensor = None,
+    scale: float = None,
+    initial_state: torch.Tensor = None,
+    inplace_final_state: bool = True,
+    cu_seqlens: torch.LongTensor | None = None,
+    ssm_state_indices: torch.Tensor | None = None,
+    ssm_state_write_indices: torch.Tensor | None = None,  # NEW: separate write indices for state propagation
+    num_accepted_tokens: torch.Tensor | None = None,
+    use_qk_l2norm_in_kernel: bool = False,
+    # Fused gating: pass raw values to compute g/beta inline in the kernel
+    A_log: torch.Tensor | None = None,
+    dt_bias: torch.Tensor | None = None,
+    a_raw: torch.Tensor | None = None,
+    b_raw: torch.Tensor | None = None,
+    out: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, H, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]`.
+        v (torch.Tensor):
+            values of shape `[B, T, HV, V]`.
+            GVA is applied if `HV > H`.
+        g (torch.Tensor):
+            g (decays) of shape `[B, T, HV]`.
+        beta (torch.Tensor):
+            betas of shape `[B, T, HV]`.
+        scale (Optional[float]):
+            Scale factor for the attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        inplace_final_state (bool):
+            Whether to store the final state in-place to save memory.
+            Default: `True`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        ssm_state_indices (Optional[torch.Tensor]):
+            Indices to map the input sequences to the initial/final states.
+        num_accepted_tokens (Optional[torch.Tensor]):
+            Number of accepted tokens for each sequence during decoding.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, HV, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, HV, K, V]`.
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
+        >>> q = torch.randn(B, T, H, K, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, HV, V, device='cuda')
+        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
+        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
+        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
+        >>> o, ht = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            cu_seqlens=cu_seqlens
+        )
+    """
+    if cu_seqlens is not None and q.shape[0] != 1:
+        raise ValueError(
+            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+            f"Please flatten variable-length inputs before processing."
+        )
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    else:
+        assert scale > 0, "scale must be positive"
+    fuse_gating = A_log is not None
+    if not fuse_gating and beta is None:
+        beta = torch.ones_like(q[..., 0])
+    o, final_state = FusedRecurrentFunction.apply(
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        inplace_final_state,
+        cu_seqlens,
+        ssm_state_indices,
+        ssm_state_write_indices,
+        num_accepted_tokens,
+        use_qk_l2norm_in_kernel,
+        A_log,
+        dt_bias,
+        a_raw,
+        b_raw,
+        out,
+    )
+    return o, final_state
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py
new file mode 100644
index 0000000000..8b1d59fc63
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py
@@ -0,0 +1,30 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +import triton + +from .utils import tensor_cache + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +@tensor_cache +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py new file mode 100644 index 0000000000..29f892ef26 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os + +import torch + +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + +BT_LIST = [8, 16, 32, 64, 128] + +USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) + + +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # tl.store(Rstd + i_t, rstd) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +@triton.jit(do_not_specialize=["NB"]) +def l2norm_fwd_kernel( + x, + y, + eps, + NB, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=1) + b_y = b_x / tl.sqrt(b_var + eps)[:, None] + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * MBLOCK + row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] + xmask = row_idx < M + rindex = tl.arange(0, N)[None, :] + xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) + square = tl.broadcast_to(xs * xs, [MBLOCK, N]) + square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] + rsqrt = tl.rsqrt(square_sum + eps) + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + + +def _get_l2norm_kernel1_configs(): + return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16, 32]] + + +def _get_l2norm_kernel1_static_key(x): + D = x.shape[-1] + return 
{"D": D} + + +def _get_l2norm_kernel1_run_key(x): + return x.shape[0] # T + + +@autotune( + kernel_name="l2norm_fwd_kernel1", + configs_gen_func=_get_l2norm_kernel1_configs, + static_key_func=_get_l2norm_kernel1_static_key, + run_key_func=_get_l2norm_kernel1_run_key, +) +def _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD, run_config=None): + if run_config is None: + run_config = {"num_warps": 4} + + num_warps = run_config.get("num_warps", 4) + T = x.shape[0] + + l2norm_fwd_kernel1[(T,)](x, y, eps=eps, D=D, BD=BD, num_warps=num_warps) + + +def _get_l2norm_kernel_configs(): + return [{"BT": BT, "num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST] + + +def _get_l2norm_kernel_static_key(x): + D = x.shape[-1] + return {"D": D} + + +def _get_l2norm_kernel_run_key(x): + return x.shape[0] # T + + +@autotune( + kernel_name="l2norm_fwd_kernel", + configs_gen_func=_get_l2norm_kernel_configs, + static_key_func=_get_l2norm_kernel_static_key, + run_key_func=_get_l2norm_kernel_run_key, +) +def _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB, run_config=None): + if run_config is None: + run_config = {"BT": 32, "num_warps": 4} + + BT = run_config.get("BT", 32) + num_warps = run_config.get("num_warps", 4) + + grid = (triton.cdiv(T, BT),) + l2norm_fwd_kernel[grid](x, y, eps, NB=NB, T=T, D=D, BT=BT, BD=BD, num_warps=num_warps) + + +def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + if not USE_DEFAULT_FLA_NORM: + MBLOCK = 32 + # M, N = x.shape + l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)]( + x, + y, + eps, + T, + D, + MBLOCK, + ) + else: + if D <= 512: + NB = triton.cdiv(T, 2048) + _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB) + else: + _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD) + + return y.view(x_shape_og) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py new file mode 100644 index 0000000000..2f69aa981d --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import triton +import triton.language as tl + +from .utils import is_gather_supported + +exp = tl.exp +log = tl.log +log2 = tl.log2 + + +@triton.jit +def safe_exp(x): + """ + Numerically stable exponential function. + Only applies exp to non-positive values, returns 0 for positive values. + This prevents numerical overflow and improves stability. 
+ """ + return exp(tl.where(x <= 0, x, float("-inf"))) + + +if not is_gather_supported: + + @triton.jit + def gather(src, index, axis, _builder=None): + """ + Gather operation that works when tl.gather is not supported. + This is a fallback implementation that returns None. + Just to make triton compiler happy. + """ + return None + +else: + gather = tl.gather + +if hasattr(triton.language, "_experimental_make_tensor_descriptor"): + # For Triton 3.3.x + make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor +elif hasattr(triton.language, "make_tensor_descriptor"): + # For Triton 3.4.x and later + make_tensor_descriptor = triton.language.make_tensor_descriptor +else: + """ + Fallback implementation when TMA is not supported. + Returns None to indicate TMA descriptors are unavailable. + Just make triton compiler happy. + """ + + @triton.jit + def make_tensor_descriptor( + base, + shape, + strides, + block_shape, + _builder=None, + ): + return None diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py new file mode 100644 index 0000000000..b5b6cfc369 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py @@ -0,0 +1,462 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import os +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import make_tensor_descriptor +from .utils import input_guard, is_amd, is_tma_supported + + +def _ensure_triton_allocator(): + """Ensure Triton has an allocator set for kernels requiring scratch memory.""" + + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + +FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee") +ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32"] if is_amd else ["ieee", "tf32", "tf32x3"] +assert ( + FLA_TRIL_PRECISION in ALLOWED_TRIL_PRECISIONS +), f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}" + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_16x16_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + A = A + (bos * H + i_h) * BT + Ai = Ai + (bos * H + i_h) * 16 + + offset = (i_t * 16) % BT + if not USE_TMA: + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) + # [16, 16] + 
b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16]) + b_A = desc.load([i_t * 16, offset]).to(tl.float32) + b_A = -tl.where(m_A, b_A, 0) + + for i in range(2, min(16, T - i_t * 16)): + # [16] + b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + b_A = tl.where((o_i == i)[:, None], b_a, b_A) + b_A += m_I + if not USE_TMA: + p_Ai = tl.make_block_ptr(Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + tl.store( + p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + + b_Ai_11 += m_I + b_Ai_22 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, + ) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + tl.store( + p_Ai_11, + 
b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_A_33 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_A_44 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32) + b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32) + b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + b_Ai_33 = -tl.where(m_A, b_Ai_33, 0) + b_Ai_44 = -tl.where(m_A, b_Ai_44, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + for i in range(32 + 2, min(48, T - i_t * BT)): + b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32) + b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0) + b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33) + for i in range(48 + 2, min(64, T - i_t * BT)): + b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48) + b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0) + b_Ai_44 = tl.where((o_i == i - 48)[:, None], 
b_a_44, b_Ai_44) + b_Ai_11 += m_I + b_Ai_22 += m_I + b_Ai_33 += m_I + b_Ai_44 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32) + b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32) + b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32) + b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32) + b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32) + + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, + ) + b_Ai_32 = -tl.dot( + tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION), + b_Ai_22, + input_precision=DOT_PRECISION, + ) + b_Ai_43 = -tl.dot( + tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION), + b_Ai_33, + input_precision=DOT_PRECISION, + ) + + b_Ai_31 = -tl.dot( + b_Ai_33, + tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION) + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_42 = -tl.dot( + b_Ai_44, + tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION) + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_41 = -tl.dot( + b_Ai_44, + tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION) + + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION) + + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + tl.store( + p_Ai_11, + b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, 
fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_33, + b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_44, + b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_31, + b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_32, + b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_41, + b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_42, + b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_43, + b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@input_guard +def solve_tril( + A: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + """ + Compute the inverse of the matrix I + A + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, BT], where BT should only be 16, 32, or 64. + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. Default: `None`. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float`. + If `None`, the output dtype will be the same as the input dtype. 
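+
+    Note:
+        For `BT` of 32 or 64, the merge kernels invert the 16x16 diagonal
+        blocks independently and fill the off-diagonal blocks via the
+        block-triangular identity (with `M = I + A`):
+        `[[M11, 0], [M21, M22]]^-1 = [[M11^-1, 0], [-M22^-1 M21 M11^-1, M22^-1]]`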
+
+    Returns:
+        (I + A)^-1 with the same shape as A
+    """
+    assert A.shape[-1] in [16, 32, 64]
+    output_dtype = A.dtype if output_dtype is None else output_dtype
+
+    B, T, H, BT = A.shape
+    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
+
+    Ai = torch.zeros_like(A, dtype=output_dtype)
+    if BT == 16:
+        merge_fn = solve_tril_16x16_kernel
+    elif BT == 32:
+        merge_fn = merge_16x16_to_32x32_inverse_kernel
+    elif BT == 64:
+        merge_fn = merge_16x16_to_64x64_inverse_kernel
+
+    # Ensure Triton allocator is set for TMA kernels that require scratch memory
+    if is_tma_supported:
+        _ensure_triton_allocator()
+
+    merge_fn[NT, B * H](
+        A=A,
+        Ai=Ai,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        T=T,
+        H=H,
+        BT=BT,
+        USE_TMA=is_tma_supported,
+        DOT_PRECISION=FLA_TRIL_PRECISION,
+    )
+    return Ai
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py
new file mode 100644
index 0000000000..cd7c2e3aeb
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py
@@ -0,0 +1,179 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+# ruff: noqa: E501
+import contextlib
+import functools
+import logging
+import os
+from collections.abc import Callable
+from enum import Enum
+from typing import Any, Literal
+
+import torch
+
+import triton
+
+logger = logging.getLogger(__name__)
+
+COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
+FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
+FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"
+
+SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
+
+
+def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+    """
+    A decorator that caches the most recent results of a function with tensor inputs.
+
+    This decorator will store the output of the decorated function for the most recent sets of input tensors.
+    The cache is limited to a fixed size (default is 8). When the cache is full, the oldest entry will be removed.
+
+    Args:
+        fn (Callable[..., torch.Tensor]):
+            The function to be decorated. It should take tensor inputs and return tensor outputs.
+
+    Returns:
+        Callable[..., torch.Tensor]:
+            A wrapped version of the input function with bounded caching.
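+
+    Example (an illustrative sketch; `prepare_lens` in `index.py` is a real
+    user of this decorator):
+
+        @tensor_cache
+        def prepare_lens(cu_seqlens):
+            return cu_seqlens[1:] - cu_seqlens[:-1]
+
+        # Calls that pass the very same tensor object hit the cache
+        # (matching is by identity, not by value); new objects recompute.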
+    """
+
+    cache_entries: list[tuple[tuple, dict, Any]] = []
+    cache_size = 8
+
+    @functools.wraps(fn)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        nonlocal cache_entries
+        for i, entry in enumerate(cache_entries):
+            last_args, last_kwargs, last_result = entry
+            if (
+                len(args) == len(last_args)
+                and len(kwargs) == len(last_kwargs)
+                and all(a is b for a, b in zip(args, last_args))
+                and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items())
+            ):
+                cache_entries = cache_entries[:i] + cache_entries[i + 1 :] + [(args, kwargs, last_result)]
+                return last_result
+
+        result = fn(*args, **kwargs)
+
+        if len(cache_entries) >= cache_size:
+            cache_entries = cache_entries[1:]
+        cache_entries.append((args, kwargs, result))
+        return result
+
+    return wrapper
+
+
+def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+    """
+    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
+    """
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
+        contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
+
+        tensor = None
+        for arg in args:
+            if isinstance(arg, torch.Tensor):
+                tensor = arg
+                break
+        if tensor is None:
+            for value in kwargs.values():
+                if isinstance(value, torch.Tensor):
+                    tensor = value
+                    break
+
+        if tensor is not None:
+            ctx = torch.cuda.device(tensor.device.index)
+        else:
+            ctx = contextlib.nullcontext()
+
+        with ctx:
+            return fn(*contiguous_args, **contiguous_kwargs)
+
+    return wrapper
+
+
+@functools.cache
+def get_available_device() -> str:
+    try:
+        return triton.runtime.driver.active.get_current_target().backend
+    except BaseException:
+        return "cpu"
+
+
+@functools.cache
+def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
+    device = get_available_device()
+    mapping = {
+        "cuda": "nvidia",
+        "hip": "amd",
+        "xpu": "intel",
+    }
+    # return the mapped value, or the original if not found
+    return mapping.get(device, device)
+
+
+# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
+# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
+# Therefore, we need to check the triton backend to determine the actual GPU vendor.
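+# As an illustration of the mapping above: on an Nvidia GPU
+# `get_available_device()` reports the triton backend "cuda", which
+# `_check_platform()` maps to "nvidia"; on AMD the backend is "hip" -> "amd".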
+device = "cuda" +device_torch_lib = getattr(torch, device, None) +device_platform = _check_platform() + +is_amd = device_platform == "amd" +is_intel = device_platform == "intel" +is_nvidia = device_platform == "nvidia" +is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0) +is_nvidia_hopper = is_nvidia and ( + "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9 +) +use_cuda_graph = True +is_gather_supported = hasattr(triton.language, "gather") +is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and ( + hasattr(triton.language, "_experimental_make_tensor_descriptor") + or hasattr(triton.language, "make_tensor_descriptor") +) + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] + for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py new file mode 100644 index 0000000000..08bb00e644 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + p_u = tl.make_block_ptr( + u + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_w = tl.make_block_ptr( + w + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: torch.LongTensor | None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u diff --git a/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py new file mode 100644 index 0000000000..6413158a66 --- /dev/null +++ 
b/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py
@@ -0,0 +1,186 @@
+import torch
+
+import triton
+import triton.language as tl
+
+from lightllm.common.triton_utils.autotuner import autotune
+
+
+@triton.jit
+def _fused_add_gemma_rmsnorm_kernel(
+    x_ptr,
+    r_ptr,
+    w_ptr,
+    y_ptr,
+    x_stride0,
+    x_stride1,
+    r_stride0,
+    r_stride1,
+    y_stride0,
+    y_stride1,
+    N: tl.constexpr,
+    EPS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Fused in-place residual add + Gemma RMSNorm.
+
+    For each row:
+    1. sum = x + residual (written back to x in-place)
+    2. rstd = 1 / sqrt(mean(sum²) + eps)
+    3. y = sum * rstd * (w + 1.0) (Gemma-style)
+    """
+    row = tl.program_id(0)
+    x_ptr = x_ptr + row * x_stride0
+    r_ptr = r_ptr + row * r_stride0
+    y_ptr = y_ptr + row * y_stride0
+
+    # Pass 1: compute sum = x + residual, write back to x, accumulate sum² for variance
+    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        mask = cols < N
+        x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32)
+        r = tl.load(r_ptr + cols * r_stride1, mask=mask, other=0.0).to(tl.float32)
+        s = x + r
+        # Write sum back to x (in-place residual add)
+        tl.store(x_ptr + cols * x_stride1, s.to(x_ptr.dtype.element_ty), mask=mask)
+        _var += s * s
+
+    var = tl.sum(_var, axis=0) / N
+    rstd = 1.0 / tl.sqrt(var + EPS)
+
+    # Pass 2: normalize and apply Gemma-style linear transformation
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        mask = cols < N
+        # Re-read x (now contains sum); hot in L2 from the write in pass 1
+        s = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32)
+        w = tl.load(w_ptr + cols, mask=mask).to(tl.float32)
+        y = s * rstd * (w + 1.0)
+        tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask)
+
+
+def _get_fused_add_gemma_rmsnorm_configs():
+    """Generate configurations for autotuning fused add + Gemma RMSNorm kernel."""
+    configs = []
+    for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]:
+        for num_warps in [1, 2, 4, 8]:
+            configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": 1})
+    return configs
+
+
+def _get_fused_add_gemma_rmsnorm_static_key(x: torch.Tensor, w: torch.Tensor):
+    """Generate static key for caching autotuned configurations."""
+    N = x.shape[-1]
+    return {
+        "x_dtype": str(x.dtype),
+        "weight_dtype": str(w.dtype),
+        "N": N,
+    }
+
+
+@autotune(
+    kernel_name="fused_add_gemma_rmsnorm:v1",
+    configs_gen_func=_get_fused_add_gemma_rmsnorm_configs,
+    static_key_func=_get_fused_add_gemma_rmsnorm_static_key,
+    run_key_func=lambda x: x.numel() // x.shape[-1],  # run key: number of rows, which varies per call
+    mutates_args=["x"],
+)
+def fused_add_gemma_rmsnorm(x, residual, w, eps, out=None, run_config: dict = None):
+    """Fused in-place residual add + Gemma RMSNorm.
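+
+    A minimal usage sketch (shapes are illustrative; assumes a CUDA device):
+
+        x = torch.randn(4, 2048, device="cuda", dtype=torch.bfloat16)
+        r = torch.randn_like(x)
+        w = torch.rand(2048, device="cuda", dtype=torch.bfloat16)
+        y = fused_add_gemma_rmsnorm(x, r, w, 1e-6)  # x now holds x + r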
+ + x: [M, N] - modified in-place (x += residual) + residual: [M, N] - residual to add (will be viewed as [-1, N]) + w: [N] - norm weight (Gemma-style: applies w + 1.0) + eps: float + out: [M, N] - output buffer (allocated if None) + Returns: out + """ + N = x.shape[-1] + y = torch.empty_like(x) if out is None else out + x_arg = x.view(-1, N) + r_arg = residual.view(-1, N) + y_arg = y.view(-1, N) + + M = x_arg.shape[0] + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This fused_add_gemma_rmsnorm doesn't support feature dim >= 64KB.") + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1} + + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + _fused_add_gemma_rmsnorm_kernel[(M,)]( + x_arg, + r_arg, + w, + y_arg, + x_stride0=x_arg.stride(0), + x_stride1=x_arg.stride(1), + r_stride0=r_arg.stride(0), + r_stride1=r_arg.stride(1), + y_stride0=y_arg.stride(0), + y_stride1=y_arg.stride(1), + N=N, + EPS=eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + return y + + +def _fused_add_gemma_rmsnorm_torch(x, residual, weight, eps): + """Reference implementation for correctness testing.""" + original_dtype = x.dtype + x = x.to(torch.float32) + residual = residual.to(torch.float32) + s = x + residual + normed = s * torch.rsqrt(s.pow(2).mean(-1, keepdim=True) + eps) + out = normed * (1.0 + weight.float()) + return s.to(original_dtype), out.to(original_dtype) + + +def test_fused_add_gemma_rmsnorm(M=128, N=2048, dtype=torch.bfloat16, eps=1e-5, device="cuda"): + """Verify fused kernel matches separate add + gemma_rmsnorm.""" + x_shape = (M, N) + w_shape = (N,) + weight = torch.rand(w_shape, dtype=dtype, device=device) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + residual = 0.1 * torch.randn(x_shape, dtype=dtype, device=device) + + # Clone x for reference (since fused modifies x in-place) + x_ref = x.clone() + x_fused = x.clone() + + # Reference: separate add + norm + x_ref_sum, y_ref = _fused_add_gemma_rmsnorm_torch(x_ref, residual, weight, eps) + + # Fused kernel + y_fused = fused_add_gemma_rmsnorm(x_fused, residual, weight, eps) + + # Check x was modified in-place (x += residual) + print(f"Test: M={M}, N={N}, dtype={dtype}") + print(f" x in-place max delta: {torch.max(torch.abs(x_fused - x_ref_sum)):.6e}") + print(f" output max delta: {torch.max(torch.abs(y_fused - y_ref)):.6e}") + + atol = 1e-2 if dtype == torch.float32 else 5e-2 + assert torch.allclose(x_fused, x_ref_sum, atol=atol, rtol=0), "x in-place update mismatch!" + assert torch.allclose(y_fused, y_ref, atol=atol, rtol=0), "output mismatch!" 
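+    # rtol=0 makes the check purely absolute; atol is sized to cover the
+    # rounding gap between the fused two-pass kernel and the fp32 reference.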
+ print(" PASSED") + + +if __name__ == "__main__": + test_fused_add_gemma_rmsnorm(M=1, N=2048) + test_fused_add_gemma_rmsnorm(M=128, N=2048) + test_fused_add_gemma_rmsnorm(M=1, N=2048, dtype=torch.float16) + test_fused_add_gemma_rmsnorm(M=64, N=4096, dtype=torch.float32) + print("All tests passed!") diff --git a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py new file mode 100644 index 0000000000..c816a20013 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py @@ -0,0 +1,87 @@ +# Adapted from https://github.com/sgl-project/sglang/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from lightllm.common.triton_utils.autotuner import autotune + + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +# beta_output = b.sigmoid() +@triton.jit +def fused_gdn_gating_kernel( + g, + beta_output, + A_log, + a, + b, + dt_bias, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_d = tl.program_id(0), tl.program_id(1) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_b = tl.load(b + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) + tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask) + + +def _get_fused_gdn_gating_configs(): + return [{"BLK_HEADS": bh, "num_warps": nw} for bh in [4, 8, 16, 32, 64] for nw in [1, 2, 4]] + + +def _get_fused_gdn_gating_static_key(a: torch.Tensor): + # group by head size and input dtype + return {"NUM_HEADS": a.shape[1], "a_dtype": str(a.dtype)} + + +@autotune( + kernel_name="fused_gdn_gating:v1", + configs_gen_func=_get_fused_gdn_gating_configs, + static_key_func=_get_fused_gdn_gating_static_key, + run_key_func=lambda a: a.shape[0], +) +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, + run_config: Optional[dict] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + if run_config is None: + run_config = {"BLK_HEADS": 8, "num_warps": 1} + + batch, num_heads = a.shape + grid = (batch, triton.cdiv(num_heads, run_config["BLK_HEADS"])) + g = torch.empty(batch, num_heads, dtype=torch.float32, device=a.device) + beta_output = torch.empty(batch, num_heads, dtype=torch.float32, device=a.device) + fused_gdn_gating_kernel[grid]( + g, + beta_output, + A_log, + a, + b, + dt_bias, + num_heads, + beta, + threshold, + run_config["BLK_HEADS"], + num_warps=run_config["num_warps"], + ) + return g, beta_output diff --git a/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py new file mode 100644 index 0000000000..f37d4911af --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py @@ -0,0 +1,163 @@ +""" +Fused QKV projection and GDN gating computation. + +This kernel fuses: +1. 
Linear projection (matmul with weight) +2. Output reorganization (split and reshape) +3. Gating computation (g and beta from a, b) + +This reduces kernel launches from 3 to 1 for the QKV+gating path. +""" + +import torch +import triton +import triton.language as tl +from typing import Tuple, Optional +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _fused_gdn_gating_only_kernel( + # Output pointers + g_ptr, + beta_ptr, + # Input pointers + a_ptr, + b_ptr, + A_log_ptr, + dt_bias_ptr, + # Dimensions + batch_size, + num_heads, + # Constants + beta_const: tl.constexpr, + threshold: tl.constexpr, + BLOCK_BATCH: tl.constexpr, + BLOCK_HEADS: tl.constexpr, +): + """ + Fused kernel for GDN gating computation with better memory access patterns. + + Computes: + - g = -exp(A_log) * softplus(a + dt_bias) + - beta = sigmoid(b) + """ + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + + batch_offs = pid_batch * BLOCK_BATCH + tl.arange(0, BLOCK_BATCH) + head_offs = pid_head * BLOCK_HEADS + tl.arange(0, BLOCK_HEADS) + + batch_mask = batch_offs < batch_size + head_mask = head_offs < num_heads + mask = batch_mask[:, None] & head_mask[None, :] + + # Load A_log and dt_bias (broadcast across batch) + A_log = tl.load(A_log_ptr + head_offs, mask=head_mask, other=0.0) + dt_bias = tl.load(dt_bias_ptr + head_offs, mask=head_mask, other=0.0) + + # Load a and b + offs = batch_offs[:, None] * num_heads + head_offs[None, :] + a = tl.load(a_ptr + offs, mask=mask, other=0.0) + b = tl.load(b_ptr + offs, mask=mask, other=0.0) + + # Compute g = -exp(A_log) * softplus(a + dt_bias) + x = a.to(tl.float32) + dt_bias.to(tl.float32) + softplus_x = tl.where(beta_const * x <= threshold, (1.0 / beta_const) * tl.log(1.0 + tl.exp(beta_const * x)), x) + g = -tl.exp(A_log.to(tl.float32)) * softplus_x + + # Compute beta = sigmoid(b) + beta_out = tl.sigmoid(b.to(tl.float32)) + + # Store outputs with layout [1, batch, num_heads] + out_offs = batch_offs[:, None] * num_heads + head_offs[None, :] + tl.store(g_ptr + out_offs, g.to(g_ptr.dtype.element_ty), mask=mask) + tl.store(beta_ptr + out_offs, beta_out.to(beta_ptr.dtype.element_ty), mask=mask) + + +def _get_fused_gating_configs(): + """Generate autotuning configurations.""" + configs = [] + for block_batch in [1, 4, 8, 16]: + for block_heads in [8, 16, 32]: + for num_warps in [2, 4, 8]: + configs.append( + { + "BLOCK_BATCH": block_batch, + "BLOCK_HEADS": block_heads, + "num_warps": num_warps, + } + ) + return configs + + +def _get_fused_gating_static_key(a: torch.Tensor): + return {"dtype": str(a.dtype), "num_heads": a.shape[1]} + + +def _get_fused_gating_run_key(a: torch.Tensor): + return a.shape[0] + + +@autotune( + kernel_name="fused_gdn_gating_v2:v1", + configs_gen_func=_get_fused_gating_configs, + static_key_func=_get_fused_gating_static_key, + run_key_func=_get_fused_gating_run_key, + mutates_args=["g", "beta"], +) +def fused_gdn_gating_v2( + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + beta_const: float = 1.0, + threshold: float = 20.0, + run_config: Optional[dict] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Optimized GDN gating with pre-allocated output tensors. 
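+
+    A minimal usage sketch (batch and head counts are illustrative; assumes CUDA):
+
+        a = torch.randn(16, 32, device="cuda")
+        b = torch.randn(16, 32, device="cuda")
+        A_log = torch.randn(32, device="cuda")
+        dt_bias = torch.randn(32, device="cuda")
+        g = torch.empty(1, 16, 32, device="cuda")
+        beta = torch.empty(1, 16, 32, device="cuda")
+        fused_gdn_gating_v2(a, b, A_log, dt_bias, g, beta)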
+ + Args: + a: Input tensor [batch, num_heads] + b: Input tensor [batch, num_heads] + A_log: Log of A parameter [num_heads] + dt_bias: Bias for dt [num_heads] + g: Output tensor [1, batch, num_heads] (pre-allocated) + beta: Output tensor [1, batch, num_heads] (pre-allocated) + beta_const: Beta constant for softplus (default: 1.0) + threshold: Threshold for softplus approximation (default: 20.0) + run_config: Optional autotuning configuration + + Returns: + Tuple of (g, beta) - same tensors passed in, now filled + """ + batch_size, num_heads = a.shape + + if run_config is None: + run_config = {"BLOCK_BATCH": 8, "BLOCK_HEADS": 16, "num_warps": 4} + + grid = ( + triton.cdiv(batch_size, run_config["BLOCK_BATCH"]), + triton.cdiv(num_heads, run_config["BLOCK_HEADS"]), + ) + + _fused_gdn_gating_only_kernel[grid]( + g, + beta, + a, + b, + A_log, + dt_bias, + batch_size, + num_heads, + beta_const, + threshold, + run_config["BLOCK_BATCH"], + run_config["BLOCK_HEADS"], + num_warps=run_config["num_warps"], + ) + + return g, beta diff --git a/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py b/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py new file mode 100644 index 0000000000..5f4433fb34 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py @@ -0,0 +1,400 @@ +""" +Fused Split-Copy Triton Kernels for GDN Decode Path + +Replaces multiple separate .copy_() calls with single kernel launches to reduce +kernel launch overhead in the decode hot path (36 GDN layers per step). + +Kernel 1 (fused_split_copy_qkvzba): 4 copies → 1 kernel + Splits GEMM output [batch, total_dim] into qkv, z, b, a destination buffers. + +Kernel 2 (fused_split_copy_qkv): 3 copies → 1 kernel + Splits conv1d output [batch, qkv_dim] into q, k, v destination buffers. + Handles non-contiguous source (stride(0) != total_dim from column slicing). +""" + +import torch +import triton +import triton.language as tl + + +# ============================================================================= +# Kernel 1: Fused split-copy for qkv, z, b, a from GEMM output +# ============================================================================= + + +@triton.jit +def _fused_split_copy_qkvzba_kernel( + # Source pointer (contiguous GEMM output) + src_ptr, + # Destination pointers (pre-allocated contiguous buffers) + dst_qkv_ptr, + dst_z_ptr, + dst_b_ptr, + dst_a_ptr, + # Row strides + src_stride0, + dst_qkv_stride0, + dst_z_stride0, + dst_b_stride0, + dst_a_stride0, + # Segment boundaries (cumulative): [0, qkv_dim) [qkv_dim, z_end) [z_end, b_end) [b_end, total_dim) + qkv_dim, + z_end, + b_end, + total_dim, + # Block size + BLOCK_N: tl.constexpr, +): + """ + One program per (row, column_block). Loads a BLOCK_N chunk from the source row, + then conditionally stores to the correct destination based on column position. 
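+    The four segment masks are disjoint and together cover [0, total_dim), so
+    every loaded element is stored to exactly one destination.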
+ + Grid: (batch, cdiv(total_dim, BLOCK_N)) + """ + row = tl.program_id(0) + col_block = tl.program_id(1) + + col_start = col_block * BLOCK_N + cols = col_start + tl.arange(0, BLOCK_N) + mask = cols < total_dim + + # Load source chunk + data = tl.load(src_ptr + row * src_stride0 + cols, mask=mask) + + # Store to qkv destination: columns [0, qkv_dim) + qkv_mask = mask & (cols < qkv_dim) + tl.store(dst_qkv_ptr + row * dst_qkv_stride0 + cols, data, mask=qkv_mask) + + # Store to z destination: columns [qkv_dim, z_end) + z_mask = mask & (cols >= qkv_dim) & (cols < z_end) + tl.store(dst_z_ptr + row * dst_z_stride0 + (cols - qkv_dim), data, mask=z_mask) + + # Store to b destination: columns [z_end, b_end) + b_mask = mask & (cols >= z_end) & (cols < b_end) + tl.store(dst_b_ptr + row * dst_b_stride0 + (cols - z_end), data, mask=b_mask) + + # Store to a destination: columns [b_end, total_dim) + a_mask = mask & (cols >= b_end) + tl.store(dst_a_ptr + row * dst_a_stride0 + (cols - b_end), data, mask=a_mask) + + +def fused_split_copy_qkvzba( + src: torch.Tensor, + dst_qkv: torch.Tensor, + dst_z: torch.Tensor, + dst_b: torch.Tensor, + dst_a: torch.Tensor, + qkv_dim: int, + z_dim: int, + b_dim: int, + a_dim: int, +): + """ + Fused split-copy from GEMM output into 4 contiguous destination buffers. + + Replaces: + conv_buf.copy_(mixed_qkvzba[:, :qkv_dim]) + z_buf.view(batch, -1).copy_(mixed_qkvzba[:, qkv_dim:z_end]) + b_buf.copy_(mixed_qkvzba[:, z_end:b_end]) + a_buf.copy_(mixed_qkvzba[:, b_end:]) + + Args: + src: [batch, total_dim] contiguous source (GEMM output) + dst_qkv: [batch, qkv_dim] contiguous destination for conv1d input + dst_z: [batch, z_dim] contiguous destination (z_buf viewed flat) + dst_b: [batch, b_dim] contiguous destination + dst_a: [batch, a_dim] contiguous destination + qkv_dim: width of qkv segment (tp_key_dim * 2 + tp_value_dim) + z_dim: width of z segment (tp_value_dim) + b_dim: width of b segment (tp_num_v_heads) + a_dim: width of a segment (tp_num_v_heads) + """ + total_dim = qkv_dim + z_dim + b_dim + a_dim + z_end = qkv_dim + z_dim + b_end = z_end + b_dim + + batch = src.shape[0] + BLOCK_N = 128 + num_col_blocks = triton.cdiv(total_dim, BLOCK_N) + + grid = (batch, num_col_blocks) + + _fused_split_copy_qkvzba_kernel[grid]( + src, + dst_qkv, + dst_z, + dst_b, + dst_a, + src.stride(0), + dst_qkv.stride(0), + dst_z.stride(0), + dst_b.stride(0), + dst_a.stride(0), + qkv_dim, + z_end, + b_end, + total_dim, + BLOCK_N=BLOCK_N, + num_warps=4, + ) + + +# ============================================================================= +# Kernel 2: Fused split-copy for q, k, v from conv1d output +# ============================================================================= + + +@triton.jit +def _fused_split_copy_qkv_kernel( + # Source pointer (may be non-contiguous column slice) + src_ptr, + # Destination pointers (contiguous buffers) + dst_q_ptr, + dst_k_ptr, + dst_v_ptr, + # Row strides + src_stride0, + dst_q_stride0, + dst_k_stride0, + dst_v_stride0, + # Segment boundaries: [0, q_dim) [q_dim, qk_end) [qk_end, total_dim) + q_dim, + qk_end, + total_dim, + # Block size + BLOCK_N: tl.constexpr, +): + """ + One program per (row, column_block). Loads a BLOCK_N chunk from the source row, + then conditionally stores to q, k, or v destination. + + Supports non-contiguous source via src_stride0 (stride may be > total_dim + when source is a column slice of a larger tensor). 
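+    Destination rows are assumed column-contiguous; only the source row stride
+    may exceed total_dim.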
+ + Grid: (batch, cdiv(total_dim, BLOCK_N)) + """ + row = tl.program_id(0) + col_block = tl.program_id(1) + + col_start = col_block * BLOCK_N + cols = col_start + tl.arange(0, BLOCK_N) + mask = cols < total_dim + + # Load source chunk (use src_stride0 for row advancement) + data = tl.load(src_ptr + row * src_stride0 + cols, mask=mask) + + # Store to q destination: columns [0, q_dim) + q_mask = mask & (cols < q_dim) + tl.store(dst_q_ptr + row * dst_q_stride0 + cols, data, mask=q_mask) + + # Store to k destination: columns [q_dim, qk_end) + k_mask = mask & (cols >= q_dim) & (cols < qk_end) + tl.store(dst_k_ptr + row * dst_k_stride0 + (cols - q_dim), data, mask=k_mask) + + # Store to v destination: columns [qk_end, total_dim) + v_mask = mask & (cols >= qk_end) + tl.store(dst_v_ptr + row * dst_v_stride0 + (cols - qk_end), data, mask=v_mask) + + +def fused_split_copy_qkv( + src: torch.Tensor, + dst_q: torch.Tensor, + dst_k: torch.Tensor, + dst_v: torch.Tensor, + q_dim: int, + k_dim: int, + v_dim: int, + src_stride0: int, +): + """ + Fused split-copy from conv1d output into 3 contiguous q/k/v buffers. + + Replaces: + q_split, k_split, v_split = torch.split(mixed_qkv, [...], dim=-1) + q_buf.view(batch, -1).copy_(q_split) + k_buf.view(batch, -1).copy_(k_split) + v_buf.view(batch, -1).copy_(v_split) + + Args: + src: [batch, total_dim] source tensor (may be non-contiguous if column slice) + dst_q: [batch, q_dim] contiguous destination + dst_k: [batch, k_dim] contiguous destination + dst_v: [batch, v_dim] contiguous destination + q_dim: width of q segment (tp_key_dim) + k_dim: width of k segment (tp_key_dim) + v_dim: width of v segment (tp_value_dim) + src_stride0: row stride of source (may be > q_dim+k_dim+v_dim) + """ + total_dim = q_dim + k_dim + v_dim + qk_end = q_dim + k_dim + + batch = src.shape[0] + BLOCK_N = 128 + num_col_blocks = triton.cdiv(total_dim, BLOCK_N) + + grid = (batch, num_col_blocks) + + _fused_split_copy_qkv_kernel[grid]( + src, + dst_q, + dst_k, + dst_v, + src_stride0, + dst_q.stride(0), + dst_k.stride(0), + dst_v.stride(0), + q_dim, + qk_end, + total_dim, + BLOCK_N=BLOCK_N, + num_warps=4, + ) + + +# ============================================================================= +# Test / Verification +# ============================================================================= + + +def test_fused_split_copy(): + """Verify fused kernels produce identical results to separate .copy_() calls.""" + torch.manual_seed(42) + device = "cuda" + dtype = torch.bfloat16 + + print("=" * 60) + print("Testing fused_split_copy_qkvzba") + print("=" * 60) + + # Typical dimensions for Qwen3-Coder-Next with TP=4 + # tp_key_dim=128, tp_value_dim=256, tp_num_v_heads=2 + qkv_dim = 128 + 128 + 256 # q + k + v = 512 + z_dim = 256 + b_dim = 2 + a_dim = 2 + total_dim = qkv_dim + z_dim + b_dim + a_dim # 772 + + for batch in [1, 4, 8, 32]: + src = torch.randn(batch, total_dim, dtype=dtype, device=device) + + # Reference: separate copies + ref_qkv = src[:, :qkv_dim].clone() + ref_z = src[:, qkv_dim : qkv_dim + z_dim].clone() + ref_b = src[:, qkv_dim + z_dim : qkv_dim + z_dim + b_dim].clone() + ref_a = src[:, qkv_dim + z_dim + b_dim :].clone() + + # Fused kernel + dst_qkv = torch.empty(batch, qkv_dim, dtype=dtype, device=device) + dst_z = torch.empty(batch, z_dim, dtype=dtype, device=device) + dst_b = torch.empty(batch, b_dim, dtype=dtype, device=device) + dst_a = torch.empty(batch, a_dim, dtype=dtype, device=device) + fused_split_copy_qkvzba(src, dst_qkv, dst_z, dst_b, dst_a, qkv_dim, z_dim, b_dim, a_dim) 
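+        # torch.equal demands bitwise identity, which holds because the kernel
+        # only relocates values and never converts or recomputes them.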
+ + assert torch.equal(dst_qkv, ref_qkv), f"qkv mismatch at batch={batch}" + assert torch.equal(dst_z, ref_z), f"z mismatch at batch={batch}" + assert torch.equal(dst_b, ref_b), f"b mismatch at batch={batch}" + assert torch.equal(dst_a, ref_a), f"a mismatch at batch={batch}" + print(f" batch={batch:3d}: PASS") + + print() + print("=" * 60) + print("Testing fused_split_copy_qkv") + print("=" * 60) + + q_dim = 128 + k_dim = 128 + v_dim = 256 + qkv_dim = q_dim + k_dim + v_dim # 512 + + for batch in [1, 4, 8, 32]: + # Test with contiguous source + src = torch.randn(batch, qkv_dim, dtype=dtype, device=device) + + ref_q = src[:, :q_dim].clone() + ref_k = src[:, q_dim : q_dim + k_dim].clone() + ref_v = src[:, q_dim + k_dim :].clone() + + dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) + dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) + dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) + fused_split_copy_qkv(src, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src.stride(0)) + + assert torch.equal(dst_q, ref_q), f"q mismatch at batch={batch} (contiguous)" + assert torch.equal(dst_k, ref_k), f"k mismatch at batch={batch} (contiguous)" + assert torch.equal(dst_v, ref_v), f"v mismatch at batch={batch} (contiguous)" + print(f" batch={batch:3d} (contiguous src): PASS") + + # Test with non-contiguous source (column slice of wider tensor) + wider = torch.randn(batch, qkv_dim + 64, dtype=dtype, device=device) + src_nc = wider[:, :qkv_dim] # Non-contiguous: stride(0) = qkv_dim + 64 + assert src_nc.stride(0) == qkv_dim + 64, "expected non-contiguous slice" + + ref_q = src_nc[:, :q_dim].clone() + ref_k = src_nc[:, q_dim : q_dim + k_dim].clone() + ref_v = src_nc[:, q_dim + k_dim :].clone() + + dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) + dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) + dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) + fused_split_copy_qkv(src_nc, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src_nc.stride(0)) + + assert torch.equal(dst_q, ref_q), f"q mismatch at batch={batch} (non-contiguous)" + assert torch.equal(dst_k, ref_k), f"k mismatch at batch={batch} (non-contiguous)" + assert torch.equal(dst_v, ref_v), f"v mismatch at batch={batch} (non-contiguous)" + print(f" batch={batch:3d} (non-contiguous src): PASS") + + print() + print("=" * 60) + print("Testing edge cases") + print("=" * 60) + + # Edge case: different dimension ratios (small q/k, large v) + q_dim, k_dim, v_dim = 32, 32, 512 + qkv_dim = q_dim + k_dim + v_dim + batch = 2 + src = torch.randn(batch, qkv_dim, dtype=dtype, device=device) + + dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) + dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) + dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) + fused_split_copy_qkv(src, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src.stride(0)) + + assert torch.equal(dst_q, src[:, :q_dim]) + assert torch.equal(dst_k, src[:, q_dim : q_dim + k_dim]) + assert torch.equal(dst_v, src[:, q_dim + k_dim :]) + print(" asymmetric dims (32, 32, 512): PASS") + + # Edge case: float32 dtype + src_f32 = torch.randn(4, 772, dtype=torch.float32, device=device) + dst_qkv = torch.empty(4, 512, dtype=torch.float32, device=device) + dst_z = torch.empty(4, 256, dtype=torch.float32, device=device) + dst_b = torch.empty(4, 2, dtype=torch.float32, device=device) + dst_a = torch.empty(4, 2, dtype=torch.float32, device=device) + fused_split_copy_qkvzba(src_f32, dst_qkv, dst_z, dst_b, dst_a, 512, 256, 2, 2) + + 
assert torch.equal(dst_qkv, src_f32[:, :512]) + assert torch.equal(dst_z, src_f32[:, 512:768]) + assert torch.equal(dst_b, src_f32[:, 768:770]) + assert torch.equal(dst_a, src_f32[:, 770:]) + print(" float32 dtype: PASS") + + # Edge case: float16 dtype + src_f16 = torch.randn(4, 772, dtype=torch.float16, device=device) + dst_qkv = torch.empty(4, 512, dtype=torch.float16, device=device) + dst_z = torch.empty(4, 256, dtype=torch.float16, device=device) + dst_b = torch.empty(4, 2, dtype=torch.float16, device=device) + dst_a = torch.empty(4, 2, dtype=torch.float16, device=device) + fused_split_copy_qkvzba(src_f16, dst_qkv, dst_z, dst_b, dst_a, 512, 256, 2, 2) + + assert torch.equal(dst_qkv, src_f16[:, :512]) + assert torch.equal(dst_z, src_f16[:, 512:768]) + assert torch.equal(dst_b, src_f16[:, 768:770]) + assert torch.equal(dst_a, src_f16[:, 770:]) + print(" float16 dtype: PASS") + + print() + print("All tests passed!") + + +if __name__ == "__main__": + test_fused_split_copy() diff --git a/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py new file mode 100644 index 0000000000..89db5e00cb --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py @@ -0,0 +1,174 @@ +import triton +import triton.language as tl +import torch +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.heuristics( + { + "HAS_BIAS": lambda args: args["B"] is not None, + } +) +@triton.jit +def gated_rmsnorm_forward_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch (required, not optional) + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
+ row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + Z += row * stride_z_row + group * N + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute variance (RMS norm doesn't use mean) + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + # RMS norm: compute variance directly without mean subtraction + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + # RMS norm: normalize without mean subtraction + x_hat = x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _get_gated_rmsnorm_configs(): + """Generate configurations for autotuning gated RMSNorm kernel.""" + configs = [] + # Different BLOCK_N sizes (powers of 2) + for block_n in [64, 128, 256, 512, 1024, 2048, 4096]: + # Different number of warps + for num_warps in [1, 2, 4, 8]: + # Skip configurations that are likely to be inefficient + if block_n >= 2048 and num_warps > 4: + continue + if block_n <= 128 and num_warps > 2: + continue + configs.append({"BLOCK_N": block_n, "num_warps": num_warps}) + return configs + + +def _get_gated_rmsnorm_static_key(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor): + """Generate static key for caching autotuned configurations.""" + M, N = x.shape + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(weight.dtype), + "N": N, + "has_bias": bias is not None, + } + + +@autotune( + kernel_name="gated_rmsnorm_forward:v1", + configs_gen_func=_get_gated_rmsnorm_configs, + static_key_func=_get_gated_rmsnorm_static_key, + run_key_func=lambda x: x.shape[0], +) +def gated_rmsnorm_forward( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + z: torch.Tensor, + out: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + run_config: dict = None, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + # z is required for gated_rmsnorm + assert z is not None, "z cannot be None for gated_rmsnorm_forward" + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + # For RMS norm, we still need rstd for the kernel + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + 
run_config = {"BLOCK_N": BLOCK_N, "num_warps": num_warps} + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + + # Validate BLOCK_N against group_size + if group_size > BLOCK_N: + # Fall back to largest valid BLOCK_N + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + grid = (M, ngroups) + gated_rmsnorm_forward_kernel[grid]( + x, + out, + weight, + bias, + z, + rstd, + x.stride(0), + out.stride(0), + z.stride(0), + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + num_warps=num_warps, + ) + return out diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py new file mode 100644 index 0000000000..5a39debaa9 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py @@ -0,0 +1,1333 @@ +""" +Optimized GDN Decode MTP (Multi-Token Prediction) Kernel + +This module provides an optimized Triton kernel for GDN decode with MTP support, +eliminating the need for sequential Python loops and reducing memory operations. + +Key optimizations: +1. Fused data reorganization from interleaved to batched layout +2. Parallel processing of all batch items with proper state indexing +3. Auto-tuned configurations for different batch sizes and model dimensions +""" + +import torch +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _reorganize_mtp_data_kernel( + # Input pointers (interleaved layout: [step0_batch0, step0_batch1, ..., step1_batch0, ...]) + src_ptr, + # Output pointers (batched layout: [batch0_step0, batch0_step1, ..., batch1_step0, ...]) + dst_ptr, + # Dimensions + batch_size, + mtp_size, + dim_size, + # Strides + src_stride_token, + src_stride_dim, + dst_stride_token, + dst_stride_dim, + # Block sizes + BLOCK_DIM: tl.constexpr, +): + """ + Reorganize data from interleaved MTP layout to batched layout. + + Input layout: [step0_batch0, step0_batch1, ..., step0_batchN, step1_batch0, ...] + Output layout: [batch0_step0, batch0_step1, ..., batch0_stepM, batch1_step0, ...] + + This enables efficient processing with the recurrent kernel. 
+ """ + batch_idx = tl.program_id(0) + step_idx = tl.program_id(1) + block_dim_idx = tl.program_id(2) + + # Calculate source and destination token indices + src_token_idx = step_idx * batch_size + batch_idx + dst_token_idx = batch_idx * mtp_size + step_idx + + # Calculate dimension offsets + dim_start = block_dim_idx * BLOCK_DIM + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + mask = dim_offsets < dim_size + + # Load from source (interleaved layout) + src_offset = src_token_idx * src_stride_token + dim_offsets * src_stride_dim + data = tl.load(src_ptr + src_offset, mask=mask, other=0.0) + + # Store to destination (batched layout) + dst_offset = dst_token_idx * dst_stride_token + dim_offsets * dst_stride_dim + tl.store(dst_ptr + dst_offset, data, mask=mask) + + +@triton.jit +def _reorganize_mtp_data_back_kernel( + # Input pointers (batched layout): [batch_size, mtp_size, num_heads, head_dim] + src_ptr, + # Output pointers (interleaved layout): [total_tokens, 1, num_heads, head_dim] + dst_ptr, + # Dimensions + batch_size, + mtp_size, + num_heads, + head_dim, + # Strides for src: [batch_size, mtp_size, num_heads, head_dim] + src_stride_batch, + src_stride_mtp, + src_stride_head, + src_stride_dim, + # Strides for dst: [total_tokens, 1, num_heads, head_dim] + dst_stride_token, + dst_stride_seq, + dst_stride_head, + dst_stride_dim, + # Block sizes + BLOCK_HEAD: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """ + Reorganize output data from batched layout back to interleaved layout. + + Input shape: [batch_size, mtp_size, num_heads, head_dim] + Output shape: [batch_size * mtp_size, 1, num_heads, head_dim] (interleaved) + + Mapping: src[b, s, h, d] -> dst[s * batch_size + b, 0, h, d] + """ + batch_idx = tl.program_id(0) + step_idx = tl.program_id(1) + block_idx = tl.program_id(2) + + # Decompose block_idx into head and dim blocks + num_dim_blocks = tl.cdiv(head_dim, BLOCK_DIM) + block_head_idx = block_idx // num_dim_blocks + block_dim_idx = block_idx % num_dim_blocks + + # Calculate destination token index (interleaved) + dst_token_idx = step_idx * batch_size + batch_idx + + # Calculate offsets + head_start = block_head_idx * BLOCK_HEAD + dim_start = block_dim_idx * BLOCK_DIM + + head_offsets = head_start + tl.arange(0, BLOCK_HEAD) + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + mask = head_mask[:, None] & dim_mask[None, :] + + # Load from source (batched layout): [batch_size, mtp_size, num_heads, head_dim] + src_base = src_ptr + batch_idx * src_stride_batch + step_idx * src_stride_mtp + src_offset = head_offsets[:, None] * src_stride_head + dim_offsets[None, :] * src_stride_dim + data = tl.load(src_base + src_offset, mask=mask, other=0.0) + + # Store to destination (interleaved layout): [total_tokens, 1, num_heads, head_dim] + # The seq dimension (1) is skipped since it's always 0 + dst_base = dst_ptr + dst_token_idx * dst_stride_token + dst_offset = head_offsets[:, None] * dst_stride_head + dim_offsets[None, :] * dst_stride_dim + tl.store(dst_base + dst_offset, data, mask=mask) + + +def _get_reorganize_mtp_configs(): + """Generate candidate configurations for MTP data reorganization.""" + configs = [] + for block_dim in [64, 128, 256, 512]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3, 4]: + configs.append( + { + "BLOCK_DIM": block_dim, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_reorganize_static_key(src: torch.Tensor, mtp_size: int): + """Static 
key based on tensor properties.""" + return { + "dtype": str(src.dtype), + "mtp_size": mtp_size, + } + + +def _get_reorganize_run_key(src: torch.Tensor, mtp_size: int): + """Run key based on batch size and dimension.""" + total_tokens = src.shape[0] + batch_size = total_tokens // mtp_size + dim_size = src.shape[-1] + return f"{batch_size}_{dim_size}" + + +@autotune( + kernel_name="gdn_decode_mtp_reorganize:v1", + configs_gen_func=_get_reorganize_mtp_configs, + static_key_func=_get_reorganize_static_key, + run_key_func=_get_reorganize_run_key, + mutates_args=["dst"], +) +def reorganize_mtp_to_batched( + src: torch.Tensor, + dst: torch.Tensor, + mtp_size: int, + run_config: dict = None, +): + """ + Reorganize data from interleaved MTP layout to batched layout. + + Args: + src: Input tensor with interleaved layout [total_tokens, dim] + Layout: [step0_batch0, step0_batch1, ..., step1_batch0, ...] + dst: Output tensor with batched layout [total_tokens, dim] + Layout: [batch0_step0, batch0_step1, ..., batch1_step0, ...] + mtp_size: Number of MTP steps + run_config: Auto-tuned configuration + """ + total_tokens = src.shape[0] + batch_size = total_tokens // mtp_size + dim_size = src.shape[-1] + + if run_config is None: + BLOCK_DIM = triton.next_power_of_2(min(dim_size, 256)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_DIM = run_config["BLOCK_DIM"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_dim = triton.cdiv(dim_size, BLOCK_DIM) + + grid = (batch_size, mtp_size, num_blocks_dim) + + _reorganize_mtp_data_kernel[grid]( + src, + dst, + batch_size, + mtp_size, + dim_size, + src.stride(0), + src.stride(-1) if src.ndim > 1 else 1, + dst.stride(0), + dst.stride(-1) if dst.ndim > 1 else 1, + BLOCK_DIM=BLOCK_DIM, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_reorganize_back_configs(): + """Generate candidate configurations for MTP output reorganization.""" + configs = [] + for block_head in [4, 8, 16, 32]: + for block_dim in [32, 64, 128]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3]: + if block_head * block_dim <= 4096: # Limit shared memory + configs.append( + { + "BLOCK_HEAD": block_head, + "BLOCK_DIM": block_dim, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_reorganize_back_static_key( + src: torch.Tensor, + batch_size: int, + mtp_size: int, + num_heads: int, + head_dim: int, +): + """Static key for output reorganization.""" + return { + "dtype": str(src.dtype), + "mtp_size": mtp_size, + "num_heads": num_heads, + "head_dim": head_dim, + } + + +def _get_reorganize_back_run_key( + src: torch.Tensor, + batch_size: int, + mtp_size: int, + num_heads: int, + head_dim: int, +): + """Run key for output reorganization.""" + return batch_size + + +@autotune( + kernel_name="gdn_decode_mtp_reorganize_back:v1", + configs_gen_func=_get_reorganize_back_configs, + static_key_func=_get_reorganize_back_static_key, + run_key_func=_get_reorganize_back_run_key, + mutates_args=["dst"], +) +def reorganize_mtp_output_to_interleaved( + src: torch.Tensor, + dst: torch.Tensor, + batch_size: int, + mtp_size: int, + num_heads: int, + head_dim: int, + run_config: dict = None, +): + """ + Reorganize output from batched layout back to interleaved layout. 
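+    This is the inverse of reorganize_mtp_to_batched, restoring the step-major
+    token order the decode pipeline expects.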
+ + Args: + src: Input tensor [batch_size, mtp_size, num_heads, head_dim] (4D) + dst: Output tensor [batch_size * mtp_size, 1, num_heads, head_dim] (4D) + batch_size: Number of batch items + mtp_size: Number of MTP steps + num_heads: Number of attention heads + head_dim: Head dimension + run_config: Auto-tuned configuration + + Mapping: src[b, s, h, d] -> dst[s * batch_size + b, 0, h, d] + """ + if run_config is None: + BLOCK_HEAD = min(triton.next_power_of_2(num_heads), 16) + BLOCK_DIM = min(triton.next_power_of_2(head_dim), 64) + num_warps = 4 + num_stages = 2 + else: + BLOCK_HEAD = run_config["BLOCK_HEAD"] + BLOCK_DIM = run_config["BLOCK_DIM"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_head_blocks = triton.cdiv(num_heads, BLOCK_HEAD) + num_dim_blocks = triton.cdiv(head_dim, BLOCK_DIM) + num_blocks_total = num_head_blocks * num_dim_blocks + + grid = (batch_size, mtp_size, num_blocks_total) + + # src is 4D: [batch_size, mtp_size, num_heads, head_dim] + # dst is 4D: [total_tokens, 1, num_heads, head_dim] + _reorganize_mtp_data_back_kernel[grid]( + src, + dst, + batch_size, + mtp_size, + num_heads, + head_dim, + src.stride(0), # batch stride + src.stride(1), # mtp stride + src.stride(2), # head stride + src.stride(3), # dim stride + dst.stride(0), # token stride + dst.stride(1), # seq stride (=1) + dst.stride(2), # head stride + dst.stride(3), # dim stride + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_DIM=BLOCK_DIM, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@triton.jit +def _prepare_mtp_indices_kernel( + # Input indices (per-step buffer indices) + buffer_idx_ptr, + # Output 2D indices for recurrent kernel + output_idx_ptr, + # Dimensions + batch_size, + mtp_size, + # Strides + input_stride, + output_stride_batch, + output_stride_step, +): + """ + Prepare 2D indices for the fused recurrent kernel. + + Input: mtp_size tensors of shape [batch_size] (buffer indices for each step) + Output: 2D tensor [batch_size, mtp_size] for ssm_state_indices + """ + batch_idx = tl.program_id(0) + step_idx = tl.program_id(1) + + # Load the buffer index for this batch and step + buffer_idx = tl.load(buffer_idx_ptr + step_idx * input_stride + batch_idx) + + # Store to the 2D output + output_offset = batch_idx * output_stride_batch + step_idx * output_stride_step + tl.store(output_idx_ptr + output_offset, buffer_idx) + + +def prepare_mtp_state_indices( + mtp_buffer_idx_list: list, + batch_size: int, + device: torch.device, +) -> torch.Tensor: + """ + Prepare 2D state indices for the fused recurrent kernel. 
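+    Example (illustrative): with mtp_size=2 and batch_size=3, step tensors
+    [a0, a1, a2] and [b0, b1, b2] become [[a0, b0], [a1, b1], [a2, b2]].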
+ + Args: + mtp_buffer_idx_list: List of buffer index tensors, one per MTP step + batch_size: Number of batch items + device: Target device + + Returns: + 2D tensor of shape [batch_size, mtp_size] for ssm_state_indices + """ + + # Stack indices to create [mtp_size, batch_size] tensor + stacked_indices = torch.stack(mtp_buffer_idx_list, dim=0) + + # Transpose to get [batch_size, mtp_size] + return stacked_indices.T.contiguous() + + +@triton.jit +def _fused_conv1d_mtp_step_kernel( + # Input/output data + mixed_qkv_ptr, + # Conv state buffer + conv_states_ptr, + # Conv weight and bias + conv_weight_ptr, + conv_bias_ptr, + # Buffer indices (one per MTP step, each [batch_size]) + buffer_indices_ptr, + next_buffer_indices_ptr, + # Dimensions + batch_size, + dim_size, + conv_width, + # Step info + step_idx, + mtp_size, + is_last_step: tl.constexpr, + # Strides + qkv_stride_token, + qkv_stride_dim, + state_stride_buffer, + state_stride_dim, + state_stride_width, + weight_stride_dim, + weight_stride_width, + # Block sizes + BLOCK_DIM: tl.constexpr, + ACTIVATION_SILU: tl.constexpr, +): + """ + Fused kernel for conv1d update in MTP decode. + + Handles one MTP step for all batch items: + 1. Reads current conv state + 2. Updates with new input + 3. Computes conv1d output + 4. Optionally copies state to next MTP step + """ + batch_idx = tl.program_id(0) + block_dim_idx = tl.program_id(1) + + # Calculate token index in interleaved layout + token_idx = step_idx * batch_size + batch_idx + + # Load buffer indices + cur_buffer_idx = tl.load(buffer_indices_ptr + batch_idx).to(tl.int64) + + # Calculate dimension offsets + dim_start = block_dim_idx * BLOCK_DIM + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + dim_mask = dim_offsets < dim_size + + # Load input value + input_offset = token_idx * qkv_stride_token + dim_offsets * qkv_stride_dim + input_val = tl.load(mixed_qkv_ptr + input_offset, mask=dim_mask, other=0.0) + + # Load conv bias + bias_val = tl.load(conv_bias_ptr + dim_offsets, mask=dim_mask, other=0.0) + + # Compute conv1d output and update state + output_val = bias_val + state_base = conv_states_ptr + cur_buffer_idx * state_stride_buffer + + # Process each position in the conv window + for w in range(conv_width): + # Load weight for this position + weight_offset = dim_offsets * weight_stride_dim + w * weight_stride_width + weight_val = tl.load(conv_weight_ptr + weight_offset, mask=dim_mask, other=0.0) + + if w < conv_width - 1: + # Load from state buffer + state_offset = dim_offsets * state_stride_dim + w * state_stride_width + state_val = tl.load(state_base + state_offset, mask=dim_mask, other=0.0) + output_val += state_val * weight_val + else: + # Use current input for the last position + output_val += input_val * weight_val + + # Update conv state (shift and insert new value) + for w in range(conv_width - 2, -1, -1): + if w == conv_width - 2: + # Insert new input at the end + state_offset = dim_offsets * state_stride_dim + w * state_stride_width + tl.store(state_base + state_offset, input_val, mask=dim_mask) + else: + # Shift state + src_offset = dim_offsets * state_stride_dim + (w + 1) * state_stride_width + dst_offset = dim_offsets * state_stride_dim + w * state_stride_width + val = tl.load(state_base + src_offset, mask=dim_mask, other=0.0) + tl.store(state_base + dst_offset, val, mask=dim_mask) + + # Apply activation (SiLU) + if ACTIVATION_SILU: + output_val = output_val * tl.sigmoid(output_val) + + # Store output + tl.store(mixed_qkv_ptr + input_offset, output_val, mask=dim_mask) + + # 
Copy state to next step if not last + if not is_last_step: + next_buffer_idx = tl.load(next_buffer_indices_ptr + batch_idx).to(tl.int64) + next_state_base = conv_states_ptr + next_buffer_idx * state_stride_buffer + + for w in range(conv_width - 1): + state_offset = dim_offsets * state_stride_dim + w * state_stride_width + val = tl.load(state_base + state_offset, mask=dim_mask, other=0.0) + tl.store(next_state_base + state_offset, val, mask=dim_mask) + + +def _get_conv1d_mtp_configs(): + """Generate candidate configurations for conv1d MTP kernel.""" + configs = [] + for block_dim in [64, 128, 256, 512]: + for num_warps in [2, 4, 8]: + for num_stages in [1, 2, 3]: + configs.append( + { + "BLOCK_DIM": block_dim, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_conv1d_mtp_static_key( + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + conv_weight: torch.Tensor, + mtp_size: int, +): + """Static key for conv1d MTP kernel.""" + return { + "dtype": str(mixed_qkv.dtype), + "dim_size": mixed_qkv.shape[-1], + "conv_width": conv_weight.shape[-1], + "mtp_size": mtp_size, + } + + +def _get_conv1d_mtp_run_key( + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + conv_weight: torch.Tensor, + mtp_size: int, +): + """Run key for conv1d MTP kernel.""" + total_tokens = mixed_qkv.shape[0] + batch_size = total_tokens // mtp_size + return batch_size + + +@autotune( + kernel_name="gdn_conv1d_mtp:v1", + configs_gen_func=_get_conv1d_mtp_configs, + static_key_func=_get_conv1d_mtp_static_key, + run_key_func=_get_conv1d_mtp_run_key, + mutates_args=["mixed_qkv", "conv_states"], +) +def fused_conv1d_mtp_update( + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + mtp_buffer_idx_list: list, + mtp_size: int, + activation_silu: bool = True, + run_config: dict = None, +): + """ + Fused conv1d update for all MTP steps. 
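+    Steps are launched sequentially because step i+1 consumes the conv state
+    written (and copied forward) by step i.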
+ + Args: + mixed_qkv: Input tensor [batch_size * mtp_size, dim] (interleaved) + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] + conv_weight: Conv weights [dim, conv_width] + conv_bias: Conv bias [dim] + mtp_buffer_idx_list: List of buffer index tensors per step + mtp_size: Number of MTP steps + activation_silu: Whether to apply SiLU activation + run_config: Auto-tuned configuration + """ + total_tokens = mixed_qkv.shape[0] + batch_size = total_tokens // mtp_size + dim_size = mixed_qkv.shape[-1] + conv_width = conv_weight.shape[-1] + + if run_config is None: + BLOCK_DIM = triton.next_power_of_2(min(dim_size, 256)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_DIM = run_config["BLOCK_DIM"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_dim = triton.cdiv(dim_size, BLOCK_DIM) + + for step_idx in range(mtp_size): + is_last_step = step_idx == mtp_size - 1 + cur_indices = mtp_buffer_idx_list[step_idx] + next_indices = mtp_buffer_idx_list[step_idx + 1] if not is_last_step else cur_indices + + grid = (batch_size, num_blocks_dim) + + _fused_conv1d_mtp_step_kernel[grid]( + mixed_qkv, + conv_states, + conv_weight, + conv_bias, + cur_indices, + next_indices, + batch_size, + dim_size, + conv_width, + step_idx, + mtp_size, + is_last_step, + mixed_qkv.stride(0), + mixed_qkv.stride(-1) if mixed_qkv.ndim > 1 else 1, + conv_states.stride(0), + conv_states.stride(1), + conv_states.stride(2), + conv_weight.stride(0), + conv_weight.stride(1), + BLOCK_DIM=BLOCK_DIM, + ACTIVATION_SILU=activation_silu, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@triton.jit +def _copy_ssm_state_kernel( + # SSM state buffer + ssm_states_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + num_heads, + key_dim, + value_dim, + # Strides + state_stride_buffer, + state_stride_head, + state_stride_key, + state_stride_value, + # Block sizes + BLOCK_KEY: tl.constexpr, + BLOCK_VALUE: tl.constexpr, +): + """ + Copy SSM states from source indices to destination indices. 
+ """ + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + block_idx = tl.program_id(2) + + # Calculate block positions + num_value_blocks = tl.cdiv(value_dim, BLOCK_VALUE) + block_key_idx = block_idx // num_value_blocks + block_value_idx = block_idx % num_value_blocks + + key_start = block_key_idx * BLOCK_KEY + value_start = block_value_idx * BLOCK_VALUE + + key_offsets = key_start + tl.arange(0, BLOCK_KEY) + value_offsets = value_start + tl.arange(0, BLOCK_VALUE) + + key_mask = key_offsets < key_dim + value_mask = value_offsets < value_dim + mask = key_mask[:, None] & value_mask[None, :] + + # Load indices + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # Calculate offsets + src_base = ssm_states_ptr + src_idx * state_stride_buffer + head_idx * state_stride_head + dst_base = ssm_states_ptr + dst_idx * state_stride_buffer + head_idx * state_stride_head + + offsets = key_offsets[:, None] * state_stride_key + value_offsets[None, :] * state_stride_value + + # Copy data + data = tl.load(src_base + offsets, mask=mask, other=0.0) + tl.store(dst_base + offsets, data, mask=mask) + + +@triton.jit +def _copy_conv_state_kernel( + # Conv state buffer [num_buffers, dim, conv_width-1] + conv_states_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + dim_size, + width_size, + num_width_blocks, # Precomputed to avoid runtime division + # Strides + state_stride_buffer, + state_stride_dim, + state_stride_width, + # Block sizes + BLOCK_DIM: tl.constexpr, + BLOCK_WIDTH: tl.constexpr, +): + """ + Copy conv states from source indices to destination indices. + + Conv state shape: [num_buffers, dim, conv_width-1] + """ + batch_idx = tl.program_id(0) + block_idx = tl.program_id(1) + + # Calculate block positions using precomputed num_width_blocks + block_dim_idx = block_idx // num_width_blocks + block_width_idx = block_idx % num_width_blocks + + dim_start = block_dim_idx * BLOCK_DIM + width_start = block_width_idx * BLOCK_WIDTH + + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + width_offsets = width_start + tl.arange(0, BLOCK_WIDTH) + + dim_mask = dim_offsets < dim_size + width_mask = width_offsets < width_size + mask = dim_mask[:, None] & width_mask[None, :] + + # Load indices + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # Calculate offsets + src_base = conv_states_ptr + src_idx * state_stride_buffer + dst_base = conv_states_ptr + dst_idx * state_stride_buffer + + offsets = dim_offsets[:, None] * state_stride_dim + width_offsets[None, :] * state_stride_width + + # Copy data + data = tl.load(src_base + offsets, mask=mask, other=0.0) + tl.store(dst_base + offsets, data, mask=mask) + + +def _get_conv_copy_configs(): + """Generate candidate configurations for conv state copy.""" + configs = [] + for block_dim in [64, 128, 256]: + for block_width in [2, 4, 8]: + for num_warps in [2, 4]: + configs.append( + { + "BLOCK_DIM": block_dim, + "BLOCK_WIDTH": block_width, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_conv_copy_static_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for conv copy.""" + return { + "dtype": str(conv_states.dtype), + "dim_size": conv_states.shape[1], + "width_size": conv_states.shape[2], + } + + +def _get_conv_copy_run_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for conv 
copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_conv_state_copy:v1", + configs_gen_func=_get_conv_copy_configs, + static_key_func=_get_conv_copy_static_key, + run_key_func=_get_conv_copy_run_key, + mutates_args=["conv_states"], +) +def copy_conv_states( + conv_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Copy conv states from source indices to destination indices. + + Args: + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + batch_size = src_indices.shape[0] + dim_size = conv_states.shape[1] + width_size = conv_states.shape[2] + + if run_config is None: + BLOCK_DIM = triton.next_power_of_2(min(dim_size, 128)) + BLOCK_WIDTH = triton.next_power_of_2(min(width_size, 4)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_DIM = run_config["BLOCK_DIM"] + BLOCK_WIDTH = run_config["BLOCK_WIDTH"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_dim_blocks = triton.cdiv(dim_size, BLOCK_DIM) + num_width_blocks = triton.cdiv(width_size, BLOCK_WIDTH) + num_blocks_total = num_dim_blocks * num_width_blocks + + grid = (batch_size, num_blocks_total) + + _copy_conv_state_kernel[grid]( + conv_states, + src_indices, + dst_indices, + batch_size, + dim_size, + width_size, + num_width_blocks, # Pass precomputed value + conv_states.stride(0), + conv_states.stride(1), + conv_states.stride(2), + BLOCK_DIM=BLOCK_DIM, + BLOCK_WIDTH=BLOCK_WIDTH, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_ssm_copy_configs(): + """Generate candidate configurations for SSM state copy.""" + configs = [] + for block_key in [16, 32, 64]: + for block_value in [16, 32, 64, 128]: + for num_warps in [2, 4, 8]: + if block_key * block_value <= 4096: + configs.append( + { + "BLOCK_KEY": block_key, + "BLOCK_VALUE": block_value, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_ssm_copy_static_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for SSM copy.""" + return { + "dtype": str(ssm_states.dtype), + "num_heads": ssm_states.shape[1], + "key_dim": ssm_states.shape[2], + "value_dim": ssm_states.shape[3], + } + + +def _get_ssm_copy_run_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for SSM copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_ssm_state_copy:v1", + configs_gen_func=_get_ssm_copy_configs, + static_key_func=_get_ssm_copy_static_key, + run_key_func=_get_ssm_copy_run_key, + mutates_args=["ssm_states"], +) +def copy_ssm_states( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Copy SSM states from source indices to destination indices. 
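+    Semantically equivalent to ssm_states[dst_indices] = ssm_states[src_indices]
+    (for non-overlapping index sets), but without materializing a gather.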
+ + Args: + ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + batch_size = src_indices.shape[0] + num_heads = ssm_states.shape[1] + key_dim = ssm_states.shape[2] + value_dim = ssm_states.shape[3] + + if run_config is None: + BLOCK_KEY = triton.next_power_of_2(min(key_dim, 32)) + BLOCK_VALUE = triton.next_power_of_2(min(value_dim, 64)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_KEY = run_config["BLOCK_KEY"] + BLOCK_VALUE = run_config["BLOCK_VALUE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_key_blocks = triton.cdiv(key_dim, BLOCK_KEY) + num_value_blocks = triton.cdiv(value_dim, BLOCK_VALUE) + num_blocks_total = num_key_blocks * num_value_blocks + + grid = (batch_size, num_heads, num_blocks_total) + + _copy_ssm_state_kernel[grid]( + ssm_states, + src_indices, + dst_indices, + batch_size, + num_heads, + key_dim, + value_dim, + ssm_states.stride(0), + ssm_states.stride(1), + ssm_states.stride(2), + ssm_states.stride(3), + BLOCK_KEY=BLOCK_KEY, + BLOCK_VALUE=BLOCK_VALUE, + num_warps=num_warps, + num_stages=num_stages, + ) + + +# ============================================================================= +# Optimized Flat Copy Kernels (for contiguous memory) +# ============================================================================= +# These kernels leverage the fact that both conv_states and ssm_states are +# contiguous in memory, allowing us to flatten the inner dimensions and use +# efficient 1D vectorized copy patterns. + + +@triton.jit +def _copy_state_flat_kernel( + # State buffer pointer (flattened view) + state_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + flat_size, # Total elements per buffer entry (flattened inner dims) + # Strides + stride_buffer, # Stride to next buffer entry (in elements) + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Optimized flat copy kernel for contiguous state buffers. + + Instead of using 2D/3D block patterns with stride calculations, this kernel + treats each buffer entry as a flat 1D array and uses vectorized loads/stores + for efficient memory transfer. 
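+    The only layout requirement is that each buffer entry's inner dimensions be
+    contiguous, so the flat offsets address the same bytes as the strided view.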
+ + Grid: (batch_size, num_blocks) where num_blocks = ceil(flat_size / BLOCK_SIZE) + """ + batch_idx = tl.program_id(0) + block_idx = tl.program_id(1) + + # Calculate element range for this block + elem_start = block_idx * BLOCK_SIZE + elem_offsets = elem_start + tl.arange(0, BLOCK_SIZE) + elem_mask = elem_offsets < flat_size + + # Load buffer indices for this batch item + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # Calculate source and destination base pointers + src_base = state_ptr + src_idx * stride_buffer + dst_base = state_ptr + dst_idx * stride_buffer + + # Vectorized copy + data = tl.load(src_base + elem_offsets, mask=elem_mask, other=0.0) + tl.store(dst_base + elem_offsets, data, mask=elem_mask) + + +@triton.jit +def _copy_states_fused_kernel( + # Conv state buffer (flattened view) + conv_state_ptr, + # SSM state buffer (flattened view) + ssm_state_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + conv_flat_size, # Total elements per conv buffer entry + ssm_flat_size, # Total elements per ssm buffer entry + # Strides (in elements) + conv_stride_buffer, + ssm_stride_buffer, + # Block sizes + CONV_BLOCK_SIZE: tl.constexpr, + SSM_BLOCK_SIZE: tl.constexpr, +): + """ + Fused kernel to copy both conv_states and ssm_states in a single launch. + + This reduces kernel launch overhead by processing both state copies together. + Each thread block handles one batch item and copies both states sequentially. + + Grid: (batch_size, max(conv_blocks, ssm_blocks)) + """ + batch_idx = tl.program_id(0) + block_idx = tl.program_id(1) + + # Load buffer indices (same for both conv and ssm) + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # ========== Copy Conv State ========== + conv_num_blocks = tl.cdiv(conv_flat_size, CONV_BLOCK_SIZE) + if block_idx < conv_num_blocks: + conv_elem_start = block_idx * CONV_BLOCK_SIZE + conv_elem_offsets = conv_elem_start + tl.arange(0, CONV_BLOCK_SIZE) + conv_mask = conv_elem_offsets < conv_flat_size + + conv_src_base = conv_state_ptr + src_idx * conv_stride_buffer + conv_dst_base = conv_state_ptr + dst_idx * conv_stride_buffer + + conv_data = tl.load(conv_src_base + conv_elem_offsets, mask=conv_mask, other=0.0) + tl.store(conv_dst_base + conv_elem_offsets, conv_data, mask=conv_mask) + + # ========== Copy SSM State ========== + ssm_num_blocks = tl.cdiv(ssm_flat_size, SSM_BLOCK_SIZE) + if block_idx < ssm_num_blocks: + ssm_elem_start = block_idx * SSM_BLOCK_SIZE + ssm_elem_offsets = ssm_elem_start + tl.arange(0, SSM_BLOCK_SIZE) + ssm_mask = ssm_elem_offsets < ssm_flat_size + + ssm_src_base = ssm_state_ptr + src_idx * ssm_stride_buffer + ssm_dst_base = ssm_state_ptr + dst_idx * ssm_stride_buffer + + ssm_data = tl.load(ssm_src_base + ssm_elem_offsets, mask=ssm_mask, other=0.0) + tl.store(ssm_dst_base + ssm_elem_offsets, ssm_data, mask=ssm_mask) + + +def _get_flat_copy_configs(): + """Generate candidate configurations for flat copy kernel.""" + configs = [] + # Larger block sizes for better memory throughput on contiguous data + for block_size in [256, 512, 1024, 2048]: + for num_warps in [4, 8]: + configs.append( + { + "BLOCK_SIZE": block_size, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_conv_flat_copy_static_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for conv flat copy.""" + return { + 
"dtype": str(conv_states.dtype), + "flat_size": conv_states.shape[1] * conv_states.shape[2], + } + + +def _get_conv_flat_copy_run_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for conv flat copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_conv_state_flat_copy:v1", + configs_gen_func=_get_flat_copy_configs, + static_key_func=_get_conv_flat_copy_static_key, + run_key_func=_get_conv_flat_copy_run_key, + mutates_args=["conv_states"], +) +def copy_conv_states_flat( + conv_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Optimized flat copy for conv states leveraging contiguous memory. + + Args: + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] (MUST be contiguous) + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + assert conv_states.is_contiguous(), "conv_states must be contiguous for flat copy" + + batch_size = src_indices.shape[0] + # Flatten inner dimensions + flat_size = conv_states.shape[1] * conv_states.shape[2] + stride_buffer = conv_states.stride(0) + + if run_config is None: + BLOCK_SIZE = 1024 + num_warps = 4 + num_stages = 2 + else: + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks = triton.cdiv(flat_size, BLOCK_SIZE) + grid = (batch_size, num_blocks) + + _copy_state_flat_kernel[grid]( + conv_states, + src_indices, + dst_indices, + batch_size, + flat_size, + stride_buffer, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_ssm_flat_copy_static_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for ssm flat copy.""" + return { + "dtype": str(ssm_states.dtype), + "flat_size": ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3], + } + + +def _get_ssm_flat_copy_run_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for ssm flat copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_ssm_state_flat_copy:v1", + configs_gen_func=_get_flat_copy_configs, + static_key_func=_get_ssm_flat_copy_static_key, + run_key_func=_get_ssm_flat_copy_run_key, + mutates_args=["ssm_states"], +) +def copy_ssm_states_flat( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Optimized flat copy for SSM states leveraging contiguous memory. 
+ + Args: + ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] (MUST be contiguous) + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + assert ssm_states.is_contiguous(), "ssm_states must be contiguous for flat copy" + + batch_size = src_indices.shape[0] + # Flatten inner dimensions (num_heads * key_dim * value_dim) + flat_size = ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3] + stride_buffer = ssm_states.stride(0) + + if run_config is None: + BLOCK_SIZE = 1024 + num_warps = 4 + num_stages = 2 + else: + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks = triton.cdiv(flat_size, BLOCK_SIZE) + grid = (batch_size, num_blocks) + + _copy_state_flat_kernel[grid]( + ssm_states, + src_indices, + dst_indices, + batch_size, + flat_size, + stride_buffer, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_fused_copy_configs(): + """Generate candidate configurations for fused copy kernel.""" + configs = [] + # Use power-of-2 block sizes for both conv and ssm + for conv_block in [256, 512, 1024]: + for ssm_block in [256, 512, 1024]: + for num_warps in [4, 8]: + configs.append( + { + "CONV_BLOCK_SIZE": conv_block, + "SSM_BLOCK_SIZE": ssm_block, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_fused_copy_static_key( + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for fused copy.""" + return { + "conv_dtype": str(conv_states.dtype), + "ssm_dtype": str(ssm_states.dtype), + "conv_flat_size": conv_states.shape[1] * conv_states.shape[2], + "ssm_flat_size": ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3], + } + + +def _get_fused_copy_run_key( + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for fused copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_states_fused_copy:v1", + configs_gen_func=_get_fused_copy_configs, + static_key_func=_get_fused_copy_static_key, + run_key_func=_get_fused_copy_run_key, + mutates_args=["conv_states", "ssm_states"], +) +def copy_states_fused( + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Fused copy for both conv and SSM states in a single kernel launch. + + This reduces kernel launch overhead by processing both state copies together. 
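+
+    Grid sizing, worked through with illustrative dims: conv_flat_size =
+    4096 * 3 = 12288 and ssm_flat_size = 16 * 128 * 128 = 262144 with
+    512-wide blocks give ceil(12288 / 512) = 24 conv blocks and
+    ceil(262144 / 512) = 512 ssm blocks, so the grid is (batch_size, 512)
+    and programs with block_idx >= 24 execute only the SSM branch.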
+ + Args: + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] (MUST be contiguous) + ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] (MUST be contiguous) + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + assert conv_states.is_contiguous(), "conv_states must be contiguous for fused copy" + assert ssm_states.is_contiguous(), "ssm_states must be contiguous for fused copy" + + batch_size = src_indices.shape[0] + + # Flatten inner dimensions + conv_flat_size = conv_states.shape[1] * conv_states.shape[2] + ssm_flat_size = ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3] + + conv_stride_buffer = conv_states.stride(0) + ssm_stride_buffer = ssm_states.stride(0) + + if run_config is None: + CONV_BLOCK_SIZE = 512 + SSM_BLOCK_SIZE = 512 + num_warps = 4 + num_stages = 2 + else: + CONV_BLOCK_SIZE = run_config["CONV_BLOCK_SIZE"] + SSM_BLOCK_SIZE = run_config["SSM_BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + # Grid covers both conv and ssm blocks + conv_num_blocks = triton.cdiv(conv_flat_size, CONV_BLOCK_SIZE) + ssm_num_blocks = triton.cdiv(ssm_flat_size, SSM_BLOCK_SIZE) + max_blocks = max(conv_num_blocks, ssm_num_blocks) + grid = (batch_size, max_blocks) + + _copy_states_fused_kernel[grid]( + conv_states, + ssm_states, + src_indices, + dst_indices, + batch_size, + conv_flat_size, + ssm_flat_size, + conv_stride_buffer, + ssm_stride_buffer, + CONV_BLOCK_SIZE=CONV_BLOCK_SIZE, + SSM_BLOCK_SIZE=SSM_BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py new file mode 100644 index 0000000000..0a2b4bd662 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py @@ -0,0 +1,141 @@ +import torch + +import triton +import triton.language as tl + +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _gemma_rmsnorm_fwd_kernel( + x_ptr, + w_ptr, + y_ptr, + x_stride0, + x_stride1, + y_stride0, + y_stride1, + N: tl.constexpr, + EPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + x_ptr = x_ptr + row * x_stride0 + y_ptr = y_ptr + row * y_stride0 + + _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + cols * x_stride1, mask=cols < N, other=0.0).to(tl.float32) + _sum += x * x + + var = tl.sum(_sum, axis=0) / N + rstd = 1 / tl.sqrt(var + EPS) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(w_ptr + cols, mask=mask).to(tl.float32) + x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + x_hat = x * rstd + w = w + 1.0 + y = x_hat * w + # Write output + tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask) + + +def _get_gemma_rmsnorm_configs(): + """Generate configurations for autotuning gemma RMSNorm kernel.""" + configs = [] + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]: + for num_warps in [1, 2, 4, 8]: + # num_stages has minimal impact on this simple kernel, use 1 + configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": 1}) + return configs + + +def _get_gemma_rmsnorm_static_key(x: torch.Tensor, w: 
torch.Tensor): + """Generate static key for caching autotuned configurations.""" + N = x.shape[-1] + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(w.dtype), + "N": N, + } + + +@autotune( + kernel_name="gemma_rmsnorm_forward:v1", + configs_gen_func=_get_gemma_rmsnorm_configs, + static_key_func=_get_gemma_rmsnorm_static_key, + run_key_func=lambda x: x.shape[-1], +) +def gemma_rmsnorm_forward(x, w, eps, out=None, run_config: dict = None): + # Inplace gemma RMS Norm + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + N = x.shape[-1] + y = torch.empty_like(x) if out is None else out + x_arg = x.view(-1, N) + y_arg = y.view(-1, N) + + M, _ = x_arg.shape + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This gemma rmsnorm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1} + + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + _gemma_rmsnorm_fwd_kernel[(M,)]( + x_arg, + w, + y_arg, + x_stride0=x.stride(0), + x_stride1=x.stride(1), + y_stride0=y.stride(0), + y_stride1=y.stride(1), + N=N, + EPS=eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + return y + + +def _gemma_rmsnorm_fwd_torch(x, weight, eps): + original_dtype = x.dtype + x = x.to(torch.float32) + x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + x = x * (1.0 + weight.float()) + return x.to(original_dtype) + + +def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + # forward pass + y_tri = gemma_rmsnorm_forward(x, weight, eps) + y_ref = _gemma_rmsnorm_fwd_torch(x, weight, eps) + + # compare + print("type:", y_tri.dtype, y_ref.dtype) + print("max delta:", torch.max(torch.abs(y_tri - y_ref))) + # Use appropriate tolerance based on dtype + atol = 1e-2 if dtype == torch.float32 else 5e-2 + assert torch.allclose(y_tri, y_ref, atol=atol, rtol=0) + return diff --git a/lightllm/models/qwen3next_mtp/__init__.py b/lightllm/models/qwen3next_mtp/__init__.py new file mode 100644 index 0000000000..779237817d --- /dev/null +++ b/lightllm/models/qwen3next_mtp/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel + +__all__ = ["Qwen3NextMTPModel"] diff --git a/lightllm/models/qwen3next_mtp/layer_infer/__init__.py b/lightllm/models/qwen3next_mtp/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py new file mode 100644 index 0000000000..2918fca79c --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py @@ -0,0 +1,16 @@ +import torch +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight +from 
lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + +class Qwen3NextMTPPostLayerInfer(LlamaPostLayerInfer): + """ + Qwen3Next MTP Post Layer Inference. + Uses gemma_rmsnorm for normalization (same as Qwen3Next). + """ + + def _norm(self, input, infer_state, layer_weight: Qwen3NextMTPPreAndPostLayerWeight) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.final_norm_weight_.weight, self.eps_, out=out) + return out diff --git a/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 0000000000..4fc207648c --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,68 @@ +import torch + +from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + +class Qwen3NextMTPPreLayerInfer(LlamaPreLayerInfer): + """ + Qwen3Next MTP Pre-Layer Inference. + Similar to DeepSeek MTP but with different weight structure. + + MTP forward flow: + 1. Get embedding from input_ids + 2. Get hidden state from main model (passed via infer_state) + 3. Normalize embedding with pre_fc_norm_embedding + 4. Normalize hidden with pre_fc_norm_hidden + 5. Concat normalized embedding and hidden + 6. Project through fc to get hidden_dim output + """ + + def __init__(self, network_config): + super().__init__(network_config) + self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + return + + def _mtp_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + + # Normalize embedding + input_embdings_normed = self.alloc_tensor(input_embdings.shape, input_embdings.dtype) + gemma_rmsnorm_forward( + input_embdings, layer_weight.pre_fc_norm_embedding_weight_.weight, self.eps_, out=input_embdings_normed + ) + + # Normalize hidden state + tgt_embdings_normed = self.alloc_tensor(tgt_embdings.shape, tgt_embdings.dtype) + gemma_rmsnorm_forward( + tgt_embdings, layer_weight.pre_fc_norm_hidden_weight_.weight, self.eps_, out=tgt_embdings_normed + ) + + # Concat normalized embedding and hidden + cat_embdings = torch.cat((input_embdings_normed, tgt_embdings_normed), dim=-1) + + # Project to hidden_size + ans_logics = self.alloc_tensor( + (cat_embdings.shape[0], layer_weight.fc_weight_.shape[1]), dtype=cat_embdings.dtype + ) + torch.mm(cat_embdings, layer_weight.fc_weight_, out=ans_logics) + + return ans_logics + + def context_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_forward(input_embdings, infer_state, layer_weight) + + def token_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight + ): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_forward(input_embdings, infer_state, layer_weight) diff --git 
a/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..03630c17c1 --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py @@ -0,0 +1,30 @@ +from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import Qwen3NextFullAttentionBaseLayerInfer +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3NextMTPTransformerLayerInfer(Qwen3NextFullAttentionBaseLayerInfer): + """ + Qwen3Next MTP Transformer Layer Inference. + MTP layers use full attention (not linear attention) with MoE FFN and shared expert. + Inherits shared methods from Qwen3NextFullAttentionBaseLayerInfer. + """ + + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + self.tp_k_head_num_ = max(self.tp_k_head_num_, 1) + self.tp_v_head_num_ = max(self.tp_v_head_num_, 1) + return + + def _bind_ffn(self): + """MTP always uses shared expert + MoE""" + from functools import partial + import os + + moe_mode = os.environ.get("MOE_MODE", "TP") + if moe_mode == "EP": + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_ep, self) + else: + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_tp, self) + return diff --git a/lightllm/models/qwen3next_mtp/layer_weights/__init__.py b/lightllm/models/qwen3next_mtp/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..8a74ef8567 --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,47 @@ +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpGEMMANormWeight + + +class Qwen3NextMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + self.wte_weight_ = None + self.lm_head_weight_ = None + + hidden_size = network_config["hidden_size"] + # Use Gemma-style normalization for all MTP norm layers + self.final_norm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.norm.weight", + data_type=self.data_type_, + ) + self.pre_fc_norm_embedding_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_embedding.weight", + data_type=self.data_type_, + ) + self.pre_fc_norm_hidden_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_hidden.weight", + data_type=self.data_type_, + ) + return + + def load_hf_weights(self, weights): + if "mtp.fc.weight" in weights: + self.fc_weight_ = self._cuda(weights["mtp.fc.weight"]).t() + + # Load weights for norm weight objects + self.final_norm_weight_.load_hf_weights(weights) + self.pre_fc_norm_embedding_weight_.load_hf_weights(weights) + self.pre_fc_norm_hidden_weight_.load_hf_weights(weights) + + return + + def verify_load(self): + # Verify all norm weights loaded correctly + return ( + self.final_norm_weight_.verify_load() + and self.pre_fc_norm_embedding_weight_.verify_load() + and self.pre_fc_norm_hidden_weight_.verify_load() + ) diff --git a/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py 
b/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..d52da5647d --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,141 @@ +import os +import torch +import math +import numpy as np +from lightllm.common.basemodel import TransformerLayerWeight +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.utils.envs_utils import enable_env_vars +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + RMSNormWeight, + QKRMSNORMWeight, + KVROWNMMWeight, +) +from functools import partial + + +class Qwen3NextMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _init_weight_names(self): + self._q_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.q_proj.weight" + self._q_norm_name = f"mtp.layers.{self.layer_num_}.self_attn.q_norm.weight" + self._q_bias_name = None + self._k_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.k_proj.weight" + self._k_norm_name = f"mtp.layers.{self.layer_num_}.self_attn.k_norm.weight" + self._k_bias_name = None + self._v_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.v_proj.weight" + self._v_bias_name = None + self._kv_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.kv_proj.weight" + self._kv_bias_name = None + self._o_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.o_proj.weight" + self._o_bias_name = None + self._att_norm_weight_name = f"mtp.layers.{self.layer_num_}.input_layernorm.weight" + self._att_norm_bias_name = None + self._ffn_norm_weight_name = f"mtp.layers.{self.layer_num_}.post_attention_layernorm.weight" + self._ffn_norm_bias_name = None + + def _init_qkv(self): + # Override parent's QKVROWNMMWeight which requires kv_head_num % tp == 0. + # Qwen3-Next has few KV heads; KVROWNMMWeight handles repeating. 
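+        # Illustrative example (numbers assumed, not read from the config):
+        # with 2 KV heads, head_dim 128 and tp=8, 2 % 8 != 0, so rather than
+        # splitting the KV projection each rank keeps a replicated copy.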
+ in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.q_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=self._q_weight_name, + data_type=self.data_type_, + bias_names=self._q_bias_name, + quant_method=self.get_quant_method("q_proj"), + ) + self.kv_proj = KVROWNMMWeight( + in_dim=in_dim, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("kv_proj"), + ) + + def _init_weight(self): + self._init_moe() + self._init_shared_expert_weight() + + hidden_size = self.network_config_["hidden_size"] + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + + self._init_qkv() + self._init_o() + self.q_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim, weight_name=self._q_norm_name, data_type=self.data_type_ + ) + self.k_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim, weight_name=self._k_norm_name, data_type=self.data_type_ + ) + self._o_gate_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" + q_out_dim = self.q_head_num_ * self.head_dim + self.o_gate_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[q_out_dim], + weight_names=self._o_gate_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) + return + + def load_hf_weights(self, weights): + self._split_q_with_gate(weights) + super().load_hf_weights(weights) + + def _init_shared_expert_weight(self): + prefix = f"mtp.layers.{self.layer_num_}.mlp.shared_expert" + hidden_size = self.network_config_["hidden_size"] + shared_inter = self.network_config_["shared_expert_intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[shared_inter, shared_inter], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=shared_inter, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + self.shared_expert_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"mtp.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + + def _split_q_with_gate(self, weights): + if self._q_weight_name in weights: + weight = weights[self._q_weight_name] + num_heads = self.q_head_num_ + weight = weight.view(num_heads * 2, self.head_dim, -1) + _q_proj = weight[0::2].reshape(-1, weight.shape[-1]) + _gate_proj = weight[1::2].reshape(-1, weight.shape[-1]) + weights[self._q_weight_name] = _q_proj + weights[self._o_gate_weight_name] = _gate_proj diff --git a/lightllm/models/qwen3next_mtp/model.py b/lightllm/models/qwen3next_mtp/model.py new file mode 100644 index 0000000000..92e4918bea --- /dev/null +++ b/lightllm/models/qwen3next_mtp/model.py @@ -0,0 +1,101 @@ +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel +from lightllm.models.qwen3next_mtp.layer_infer.pre_layer_infer import 
Qwen3NextMTPPreLayerInfer +from lightllm.models.qwen3next_mtp.layer_infer.transformer_layer_infer import Qwen3NextMTPTransformerLayerInfer +from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight +from lightllm.models.qwen3next_mtp.layer_weights.transformer_layer_weight import Qwen3NextMTPTransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel +from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights +from lightllm.models.registry import ModelRegistry + + +@ModelRegistry("qwen3next_mtp") +class Qwen3NextMTPModel(Qwen3NextTpPartModel): + + pre_and_post_weight_class = Qwen3NextMTPPreAndPostLayerWeight + pre_layer_infer_class = Qwen3NextMTPPreLayerInfer + transformer_weight_class = Qwen3NextMTPTransformerLayerWeight + transformer_layer_infer_class = Qwen3NextMTPTransformerLayerInfer + + def __init__(self, kvargs: dict): + self.mtp_n_layers = 1 + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + """Extract main model and memory layer start from kwargs.""" + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mem_layer_start = kvargs.pop("mem_layer_start") + return + + def autotune_layers(self): + return 1 + + def _init_some_value(self): + self.layers_num = self.mtp_n_layers + + def _init_config(self): + super()._init_config() + self.config["n_layers"] = self.mtp_n_layers + self.config["num_hidden_layers"] = self.mtp_n_layers + return + + def _init_custom(self): + """Initialize custom components, sharing cos/sin cache with main model.""" + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + """Share request manager with main model.""" + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + """Share memory manager with main model.""" + self.mem_manager = self.main_model.mem_manager + return + + def _check_mem_size(self): + """Skip mem size check for MTP models since they share memory with main model.""" + self.max_total_token_num = self.mem_manager.size + return + + def _init_weights(self): + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(self.mtp_n_layers) + ] + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + return + + def _init_infer_layer(self): + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + self.layers_infer = [ + self.transformer_layer_infer_class( + i * self.config["full_attention_interval"] - 1, # Ensure full attention layer + network_config=self.config, + ) + for i in range(self.mtp_n_layers) + ] + # Ensure full attention layer + for i, layer in enumerate(self.layers_infer): + layer.layer_num_ = i + self.mem_layer_start + return diff --git a/lightllm/server/api_cli.py 
b/lightllm/server/api_cli.py
index 96126744af..4d122f615d 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -128,7 +128,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--tool_call_parser",
         type=str,
-        choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "glm47", "kimi_k2"],
+        choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "glm47", "kimi_k2", "qwen3_coder"],
         default=None,
         help="tool call parser type",
     )
@@ -551,7 +551,15 @@ def make_argument_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument(
         "--mtp_mode",
-        choices=["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att", None],
+        choices=[
+            "vanilla_with_att",
+            "eagle_with_att",
+            "vanilla_no_att",
+            "eagle_no_att",
+            "qwen3next_vanilla",
+            "qwen3next_eagle",
+            None,
+        ],
         default=None,
         help="""Supported MTP modes.
         None: Disables MTP.
@@ -621,6 +629,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=False,
         help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""",
     )
+    parser.add_argument("--mamba_cache_size", type=int, default=3000, help="""The size of the linear attn cache.""")
+    parser.add_argument(
+        "--mamba_ssm_data_type",
+        type=str,
+        choices=["bfloat16", "float32"],
+        default="float32",
+        help="the data type of the mamba SSM state cache",
+    )
     parser.add_argument(
         "--hardware_platform",
         type=str,
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index d91bb1d947..fc14314ae3 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -176,10 +176,24 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
                         multimodal_params_dict["images"].append({"type": "base64", "data": data})
                     else:
                         raise ValueError("Unrecognized image input.")
+                elif img.startswith("file://"):
+                    # Local file path with file:// prefix
+                    file_path = img[7:]  # Remove "file://" prefix
+                    with open(file_path, "rb") as f:
+                        multimodal_params_dict["images"].append(
+                            {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")}
+                        )
                 else:
-                    raise ValueError(
-                        "Unrecognized image input. Supports local path, http url, base64, and PIL.Image."
-                    )
+                    # Treat as local file path
+                    if os.path.isfile(img):
+                        with open(img, "rb") as f:
+                            multimodal_params_dict["images"].append(
+                                {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")}
+                            )
+                    else:
+                        raise ValueError(
+                            "Unrecognized image input. Supports local path, http url, base64, and PIL.Image."
+ ) tools = None if request.tools and request.tool_choice != "none": diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 111def60c2..34dd69c801 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -132,7 +132,8 @@ def normal_or_p_d_start(args): # mtp params check if args.mtp_mode is not None: - assert args.mtp_draft_model_dir is not None + if args.mtp_draft_model_dir is None: + args.mtp_draft_model_dir = [args.model_dir] * args.mtp_step assert args.mtp_step > 0 else: assert args.mtp_draft_model_dir is None diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index a369cf7f7f..d8d2c6ff8b 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -31,7 +31,8 @@ class StartArgs: batch_max_tokens: Optional[int] = field(default=None) eos_id: List[int] = field(default_factory=list) tool_call_parser: Optional[str] = field( - default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]} + default=None, + metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen", "qwen3_coder"]}, ) reasoning_parser: Optional[str] = field( default=None, @@ -54,7 +55,7 @@ class StartArgs: }, ) chat_template: Optional[str] = field(default=None) - running_max_req_size: int = field(default=1000) + running_max_req_size: int = field(default=512) tp: int = field(default=1) dp: int = field(default=1) nnodes: int = field(default=1) @@ -107,7 +108,7 @@ class StartArgs: disable_cudagraph: bool = field(default=False) enable_prefill_cudagraph: bool = field(default=False) prefll_cudagraph_max_handle_token: int = field(default=512) - graph_max_batch_size: int = field(default=256) + graph_max_batch_size: int = field(default=512) graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) graph_max_len_in_batch: int = field(default=0) @@ -134,7 +135,18 @@ class StartArgs: ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) mtp_mode: Optional[str] = field( - default=None, metadata={"choices": ["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att"]} + default=None, + metadata={ + "choices": [ + "vanilla_with_att", + "eagle_with_att", + "vanilla_no_att", + "eagle_no_att", + "qwen3next_vanilla", + "qwen3next_eagle", + None, + ] + }, ) mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) @@ -162,3 +174,7 @@ class StartArgs: # multi_modal enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) + + # hybrid attention model (Qwen3Next) + mamba_cache_size: int = field(default=800) + mamba_ssm_data_type: Optional[str] = field(default="float32", metadata={"choices": ["bfloat16", "float32"]}) diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py new file mode 100644 index 0000000000..2a4fe06628 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -0,0 +1,206 @@ +from typing import Set, Protocol, List, Optional, Tuple + +import torch +from sortedcontainers import SortedSet + +from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode +from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager +from lightllm.utils.log_utils import init_logger + +logger = 
init_logger(__name__)
+
+
+class HybridRadixCache(RadixCache):
+    def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager):
+        super().__init__(unique_name, total_token_num, rank_in_node, kv_cache_mem_manager)
+        assert hasattr(kv_cache_mem_manager, "mamba_cache_mem_manager")
+        self.buffer_mem_manager: MambaCacheManager = kv_cache_mem_manager.mamba_cache_mem_manager
+        self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: (x.buffer_time,))
+
+    def free_radix_cache_to_get_enough_buffer(self, need_buffer_num):
+        if need_buffer_num > self.buffer_mem_manager.can_use_mem_size:
+            need_evict_buffer_num = need_buffer_num - self.buffer_mem_manager.can_use_mem_size
+
+            release_mems = []
+
+            def release_mem(mem_index):
+                release_mems.append(mem_index)
+                return
+
+            release_buffers = []
+
+            def release_buffer(buffer_idx):
+                release_buffers.append(buffer_idx)
+                return
+
+            self._evict_buffer(need_evict_buffer_num, release_buffer, release_mem)
+            self.buffer_mem_manager.free(release_buffers)
+            if len(release_mems) > 0:
+                mem_index = torch.concat(release_mems)
+                self.mem_manager.free(mem_index)
+        return
+
+    def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token_callback):
+        while need_evict_buffer_num > 0:
+            node = self.evict_buffer_set.pop(0)
+            assert node.buffer_idx is not None
+            evict_buffer_callback(node.buffer_idx)
+            node.buffer_idx = None
+            need_evict_buffer_num -= 1
+            # Once a node's buffer_idx becomes None it can no longer be matched later,
+            # but it must be kept while it still has children or a non-zero ref count;
+            # otherwise it should be deleted.
+            if node.is_leaf() and node.ref_counter == 0:
+                self.evict_tree_set.discard(node)
+                evict_token_callback(node.token_mem_index_value)
+                self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value)
+                parent_node: TreeNode = node.parent
+                parent_node.remove_child(node)
+                if parent_node.is_leaf():
+                    self.evict_tree_set.add(parent_node)
+        return
+
+    def insert_for_hybrid_radix_cache(self, reqs):
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+        reqs_to_insert = [req for req in reqs if req.cur_kv_len < req.get_cur_total_len()]
+
+        if len(reqs_to_insert) == 0:
+            return
+
+        self.free_radix_cache_to_get_enough_buffer(len(reqs_to_insert))
+        req_idxes = torch.tensor([req.req_idx for req in reqs_to_insert], dtype=torch.int64, device="cuda")
+        req_to_buffer_index = g_infer_context.req_manager.req_to_buffer_index
+        # Make contiguous and convert to int64 for Triton kernel compatibility
+        cur_buffer_indexes = req_to_buffer_index[req_idxes, 0].contiguous().to(torch.int64)
+
+        new_buffer_indexes = self.buffer_mem_manager.alloc(len(reqs_to_insert))
+        # Move to CUDA and convert to int64, ensure contiguous
+        new_buffer_indexes_cuda = new_buffer_indexes.to(device="cuda", dtype=torch.int64).contiguous()
+
+        self.buffer_mem_manager.copy_buffer_p2p(cur_buffer_indexes, new_buffer_indexes_cuda)
+
+        for i, req in enumerate(reqs_to_insert):
+            input_token_ids = req.get_input_token_ids()
+            key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
+            value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu()
+            prefix_len, new_shared_kv_node = super().insert(key, value)
+            old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
+            self.dec_node_ref_counter(req.shared_kv_node)
+            self.add_node_ref_counter(new_shared_kv_node)
+            self.add_buffer_idx_to_node(new_shared_kv_node, new_buffer_indexes[i].item())
+            req.extra_need_to_free_token_index.append(
g_infer_context.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]
+            )
+            req.shared_kv_node = new_shared_kv_node
+
+    def match_prefix(self, key, update_refs=False):
+        assert len(key) != 0
+        ans_value_list = []
+        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
+        miss_prefix_len = 0
+        evict_token_list = []
+        while tree_node != self.root_node and tree_node.buffer_idx is None:
+            if tree_node.is_leaf():
+                self.evict_tree_set.discard(tree_node)
+
+            # Only update ref_counter when update_refs is True to maintain consistency
+            # with _match_prefix_helper which only increments ref_counter when update_refs=True
+            if update_refs:
+                if tree_node.ref_counter == 1:
+                    self.refed_tokens_num.arr[0] -= len(tree_node.token_mem_index_value)
+                tree_node.ref_counter -= 1  # only decrement the current node, do not recurse
+
+            if tree_node.is_leaf() and tree_node.ref_counter == 0:
+                evict_token_list.append(tree_node.token_mem_index_value)
+                self.tree_total_tokens_num.arr[0] -= len(tree_node.token_mem_index_value)
+                parent_node: TreeNode = tree_node.parent
+                parent_node.remove_child(tree_node)
+                if parent_node.is_leaf():
+                    self.evict_tree_set.add(parent_node)
+                tree_node = parent_node
+            else:
+                if tree_node.is_leaf():
+                    self.evict_tree_set.add(tree_node)
+                tree_node = tree_node.parent
+            miss_prefix_len += len(ans_value_list.pop())
+
+        if len(evict_token_list) > 0:
+            evict_token_value = torch.concat(evict_token_list)
+            self.mem_manager.free(evict_token_value)
+
+        if tree_node == self.root_node:
+            return None, miss_prefix_len, None
+
+        update_node = tree_node
+        while update_node != self.root_node:
+            if update_node.buffer_idx is not None:
+                self.evict_buffer_set.discard(update_node)
+                update_node.update_buffer_time()
+                self.evict_buffer_set.add(update_node)
+            update_node = update_node.parent
+
+        value = torch.concat(ans_value_list)
+        return tree_node, miss_prefix_len, value
+
+    def add_buffer_idx_to_node(self, node: TreeNode, buffer_idx: int):
+        """Set buffer_idx for a node and add it to evict_buffer_set."""
+        self.evict_buffer_set.discard(node)
+        if node.is_leaf():
+            self.evict_tree_set.discard(node)
+        if node.buffer_idx is not None:
+            self.buffer_mem_manager.free([node.buffer_idx])
+        node.buffer_idx = buffer_idx
+        node.update_buffer_time()
+        self.evict_buffer_set.add(node)
+        if node.is_leaf():
+            self.evict_tree_set.add(node)
+        return
+
+    def free_radix_cache_to_get_enough_token(self, need_token_num):
+        assert self.mem_manager is not None
+        if need_token_num > self.mem_manager.can_use_mem_size:
+            need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size
+            release_mems = []
+
+            def release_mem(mem_index):
+                release_mems.append(mem_index)
+                return
+
+            release_buffers = []
+
+            def release_buffer(buffer_idx):
+                release_buffers.append(buffer_idx)
+                return
+
+            self.evict(need_evict_token_num, release_buffer, release_mem)
+            mem_index = torch.concat(release_mems)
+            self.mem_manager.free(mem_index)
+            if len(release_buffers) > 0:
+                self.buffer_mem_manager.free(release_buffers)
+        return
+
+    def evict(self, need_remove_tokens, evict_buffer_callback, evict_callback):
+        if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens:
+            assert False, f"""can not free tree tokens {need_remove_tokens},
+            tree_total_tokens_num {self.tree_total_tokens_num.arr[0]},
+            refed_tokens_num {self.refed_tokens_num.arr[0]}"""
+        num_evicted = 0
+        while num_evicted < need_remove_tokens:
+            node: TreeNode = self.evict_tree_set.pop(0)
+            assert (
+                node.ref_counter == 0 and
len(node.children) == 0 and node != self.root_node + ), f"error evict tree node state: {node.ref_counter}, {len(node.children)}" + num_evicted += len(node.token_mem_index_value) + evict_callback(node.token_mem_index_value) + if node.buffer_idx is not None: + self.evict_buffer_set.discard(node) + evict_buffer_callback(node.buffer_idx) + node.buffer_idx = None + # update total token num + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + return diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 09bc938f23..3b59401144 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -101,6 +101,14 @@ def get_tokenizer( tokenizer = QWen3VLTokenizer( tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) + elif model_type in ["qwen3_5", "qwen3_5_moe"] and "vision_config" in model_cfg: + from transformers import AutoProcessor + from ..models.qwen3_5.model import QWen3_5Tokenizer + + processor = AutoProcessor.from_pretrained(tokenizer_name) + tokenizer = QWen3_5Tokenizer( + tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg + ) elif model_cfg.get("thinker_config") is not None: from transformers import AutoProcessor diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index 790f185f25..fa0a9e3c71 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -87,6 +87,22 @@ def get_eos_token_ids(model_path: str) -> Optional[List[int]]: except: pass + # Qwen3.5 checkpoints can have an eos_token_id in config that differs from + # tokenizer.eos_token_id. In practice tokenizer.eos_token_id is the reliable + # stop id (<|im_end|>) for detokenization/stop behavior. + try: + config_json = get_config_json(model_path) + model_type = config_json.get("model_type") or config_json.get("text_config", {}).get("model_type") + if model_type in {"qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"}: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) + if tokenizer.eos_token_id is not None: + return [int(tokenizer.eos_token_id)] + except Exception: + # Fall back to config-based lookup below. 
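+        # (AutoTokenizer.from_pretrained can touch disk or the network, so any
+        # failure here is swallowed rather than allowed to break startup.)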
+ pass + eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"]) if isinstance(eos_token_id, int): return [eos_token_id] diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 7a7a9be121..cdafb88873 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -158,7 +158,7 @@ def get_kv_quant_calibration_inference_count(): @lru_cache(maxsize=None) def get_triton_autotune_level(): - return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0)) + return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 1)) g_model_init_done = False From c757b062f17d6a0a2623d9faa11af8cdf0fa664f Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 20 Feb 2026 03:30:55 +0000 Subject: [PATCH 091/180] refactor: simplify mamba buffer copy and integrate Triton kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Reduce Triton kernels from 6 (1D/2D/3D × p2p/broadcast) to 2 (1D only) by flattening contiguous trailing dimensions via tensor view - Wire up MambaCacheManager to use the Triton kernels instead of PyTorch advanced indexing with Python for-loops - Cast strides to int64 in kernels to prevent pointer arithmetic overflow - Add Qwen3.5 multimodal vision-language model support --- .../common/basemodel/attention_vit/fa3/fp.py | 3 +- .../triton_kernel/mamba_buffer_copy.py | 671 +----------------- .../mamba_cache_mem_manager/cache_manager.py | 91 +-- lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 4 +- lightllm/models/qwen2_vl/vision_process.py | 5 +- lightllm/models/qwen35_moe/model.py | 42 ++ lightllm/models/qwen3_5/__init__.py | 17 + lightllm/models/qwen3_5/infer_struct.py | 110 +++ .../models/qwen3_5/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 121 ++++ .../models/qwen3_5/layer_weights/__init__.py | 0 .../layer_weights/transformer_layer_weight.py | 166 +++++ lightllm/models/qwen3_5/model.py | 229 ++++++ lightllm/server/build_prompt.py | 23 +- lightllm/server/core/objs/sampling_params.py | 52 +- lightllm/server/function_call_parser.py | 224 ++++++ .../router/dynamic_prompt/radix_cache.py | 14 +- .../server/router/model_infer/infer_batch.py | 136 +++- .../model_infer/mode_backend/base_backend.py | 54 +- .../mode_backend/chunked_prefill/impl.py | 32 + .../mode_backend/dp_backend/impl.py | 28 + .../visualserver/model_infer/model_rpc.py | 2 +- test_gsmk.py | 241 +++++++ 23 files changed, 1506 insertions(+), 759 deletions(-) create mode 100644 lightllm/models/qwen35_moe/model.py create mode 100644 lightllm/models/qwen3_5/__init__.py create mode 100644 lightllm/models/qwen3_5/infer_struct.py create mode 100644 lightllm/models/qwen3_5/layer_infer/__init__.py create mode 100644 lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/qwen3_5/layer_weights/__init__.py create mode 100644 lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/qwen3_5/model.py create mode 100644 test_gsmk.py diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index 406ff7408d..d5e623b188 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -45,7 +45,8 @@ def _vit_att_fwd( False, window_size[0], window_size[1], - 0.0, + 0, # attention_chunk + 0.0, # softcap is_rotary_interleaved=False, scheduler_metadata=None, num_splits=1, diff --git 
a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py index b4a91f7861..6a1d8adbd5 100644 --- a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py +++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py @@ -1,10 +1,3 @@ -""" -Optimized Mamba Buffer Copy Kernels with Autotune Support - -This module provides auto-tuned Triton kernels for efficient buffer copying operations -in Mamba-style models, including support for MTP (Multi-Token Prediction) buffer broadcasting. -""" - import torch import triton import triton.language as tl @@ -35,6 +28,10 @@ def _copy_buffer_p2p_1d_kernel( layer_idx = tl.program_id(1) + layer_idx_offset block_d_idx = tl.program_id(2) + # Cast strides to int64 to prevent overflow in pointer arithmetic + stride_layer = stride_layer.to(tl.int64) + stride_index = stride_index.to(tl.int64) + # Load source and destination indices for this pair src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) @@ -58,66 +55,6 @@ def _copy_buffer_p2p_1d_kernel( tl.store(dst_ptr, data, mask=mask) -@triton.jit -def _copy_buffer_p2p_2d_kernel( - src_buffer_ptr, - dst_buffer_ptr, - src_indexes_ptr, - dst_indexes_ptr, - pair_idx_offset, - layer_idx_offset, - stride_layer, - stride_index, - stride_d1, - stride_d2, - d1_size, - d2_size, - num_blocks_d2, - BLOCK_D1: tl.constexpr, - BLOCK_D2: tl.constexpr, -): - """ - Kernel to copy 2D buffer from source indices to destination indices. - - Grid: (num_pairs, layer_num, num_blocks_d1 * num_blocks_d2) - Each program copies one 2D block for one (pair, layer) combination. - """ - pair_idx = tl.program_id(0) + pair_idx_offset - layer_idx = tl.program_id(1) + layer_idx_offset - block_idx = tl.program_id(2) - - # Decompose block_idx into d1 and d2 block indices - block_d1_idx = block_idx // num_blocks_d2 - block_d2_idx = block_idx % num_blocks_d2 - - # Load source and destination indices - src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) - dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) - - # Calculate offsets for this block - d1_start = block_d1_idx * BLOCK_D1 - d2_start = block_d2_idx * BLOCK_D2 - - d1_offsets = d1_start + tl.arange(0, BLOCK_D1) - d2_offsets = d2_start + tl.arange(0, BLOCK_D2) - - # Create mask for valid indices - d1_mask = d1_offsets < d1_size - d2_mask = d2_offsets < d2_size - mask = d1_mask[:, None] & d2_mask[None, :] - - # Calculate base pointers - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index - - # Calculate full offsets - offsets = d1_offsets[:, None] * stride_d1 + d2_offsets[None, :] * stride_d2 - - # Load and store - data = tl.load(base_src + offsets, mask=mask, other=0.0) - tl.store(base_dst + offsets, data, mask=mask) - - @triton.jit def _copy_buffer_broadcast_1d_kernel( src_buffer_ptr, @@ -142,6 +79,10 @@ def _copy_buffer_broadcast_1d_kernel( layer_idx = tl.program_id(1) + layer_idx_offset block_d_idx = tl.program_id(2) + # Cast strides to int64 to prevent overflow in pointer arithmetic + stride_layer = stride_layer.to(tl.int64) + stride_index = stride_index.to(tl.int64) + # Load source index src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) @@ -168,219 +109,6 @@ def _copy_buffer_broadcast_1d_kernel( tl.store(dst_ptr, data, mask=mask) -@triton.jit -def _copy_buffer_broadcast_2d_kernel( - src_buffer_ptr, - dst_buffer_ptr, - 
src_indexes_ptr, - dst_indexes_ptr, - copy_idx_offset, - layer_idx_offset, - stride_layer, - stride_index, - stride_d1, - stride_d2, - d1_size, - d2_size, - num_blocks_d2, - num_dst_per_src, - BLOCK_D1: tl.constexpr, - BLOCK_D2: tl.constexpr, -): - """ - Broadcast kernel for 2D buffer copy (one source to multiple destinations). - - Grid: (num_src, layer_num, num_blocks_d1 * num_blocks_d2) - """ - src_idx_in_batch = tl.program_id(0) + copy_idx_offset - layer_idx = tl.program_id(1) + layer_idx_offset - block_idx = tl.program_id(2) - - # Decompose block_idx - block_d1_idx = block_idx // num_blocks_d2 - block_d2_idx = block_idx % num_blocks_d2 - - # Load source index - src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) - - # Calculate offsets - d1_start = block_d1_idx * BLOCK_D1 - d2_start = block_d2_idx * BLOCK_D2 - - d1_offsets = d1_start + tl.arange(0, BLOCK_D1) - d2_offsets = d2_start + tl.arange(0, BLOCK_D2) - - d1_mask = d1_offsets < d1_size - d2_mask = d2_offsets < d2_size - mask = d1_mask[:, None] & d2_mask[None, :] - - # Calculate source pointer and load data once - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index - offsets = d1_offsets[:, None] * stride_d1 + d2_offsets[None, :] * stride_d2 - data = tl.load(base_src + offsets, mask=mask, other=0.0) - - # Broadcast to all destinations - for dst_offset in range(num_dst_per_src): - dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset - dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) - - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index - tl.store(base_dst + offsets, data, mask=mask) - - -@triton.jit -def _copy_buffer_p2p_3d_kernel( - src_buffer_ptr, - dst_buffer_ptr, - src_indexes_ptr, - dst_indexes_ptr, - pair_idx_offset, - layer_idx_offset, - stride_layer, - stride_index, - stride_d1, - stride_d2, - stride_d3, - d1_size, - d2_size, - d3_size, - num_blocks_d2, - num_blocks_d3, - BLOCK_D1: tl.constexpr, - BLOCK_D2: tl.constexpr, - BLOCK_D3: tl.constexpr, -): - """ - Optimized kernel for 3D data buffer copy (5D tensor: layer, buffer, d1, d2, d3). - - Grid: (num_pairs, layer_num, num_blocks_d1 * num_blocks_d2 * num_blocks_d3) - Each program copies one 3D block for one (pair, layer) combination. 
- """ - pair_idx = tl.program_id(0) + pair_idx_offset - layer_idx = tl.program_id(1) + layer_idx_offset - block_idx = tl.program_id(2) - - # Decompose block_idx into d1, d2, d3 block indices - block_d1_idx = block_idx // (num_blocks_d2 * num_blocks_d3) - temp = block_idx % (num_blocks_d2 * num_blocks_d3) - block_d2_idx = temp // num_blocks_d3 - block_d3_idx = temp % num_blocks_d3 - - # Load source and destination indices for this pair - src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) - dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) - - # Calculate offsets for this block - d1_start = block_d1_idx * BLOCK_D1 - d2_start = block_d2_idx * BLOCK_D2 - d3_start = block_d3_idx * BLOCK_D3 - - d1_offsets = d1_start + tl.arange(0, BLOCK_D1) - d2_offsets = d2_start + tl.arange(0, BLOCK_D2) - d3_offsets = d3_start + tl.arange(0, BLOCK_D3) - - # Create masks for valid indices - d1_mask = d1_offsets < d1_size - d2_mask = d2_offsets < d2_size - d3_mask = d3_offsets < d3_size - - # 3D mask: [BLOCK_D1, BLOCK_D2, BLOCK_D3] - mask = d1_mask[:, None, None] & d2_mask[None, :, None] & d3_mask[None, None, :] - - # Calculate base pointers - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index - - # Calculate full 3D offsets - offsets = ( - d1_offsets[:, None, None] * stride_d1 - + d2_offsets[None, :, None] * stride_d2 - + d3_offsets[None, None, :] * stride_d3 - ) - - # Load and store - data = tl.load(base_src + offsets, mask=mask, other=0.0) - tl.store(base_dst + offsets, data, mask=mask) - - -@triton.jit -def _copy_buffer_broadcast_3d_kernel( - src_buffer_ptr, - dst_buffer_ptr, - src_indexes_ptr, - dst_indexes_ptr, - copy_idx_offset, - layer_idx_offset, - stride_layer, - stride_index, - stride_d1, - stride_d2, - stride_d3, - d1_size, - d2_size, - d3_size, - num_blocks_d2, - num_blocks_d3, - num_dst_per_src, - BLOCK_D1: tl.constexpr, - BLOCK_D2: tl.constexpr, - BLOCK_D3: tl.constexpr, -): - """ - Broadcast kernel for 3D data buffer copy (5D tensor: layer, buffer, d1, d2, d3). - - Grid: (num_src, layer_num, num_blocks_d1 * num_blocks_d2 * num_blocks_d3) - Each program loads once from source and broadcasts to all destinations. 
- """ - src_idx_in_batch = tl.program_id(0) + copy_idx_offset - layer_idx = tl.program_id(1) + layer_idx_offset - block_idx = tl.program_id(2) - - # Decompose block_idx into d1, d2, d3 block indices - block_d1_idx = block_idx // (num_blocks_d2 * num_blocks_d3) - temp = block_idx % (num_blocks_d2 * num_blocks_d3) - block_d2_idx = temp // num_blocks_d3 - block_d3_idx = temp % num_blocks_d3 - - # Load source index - src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) - - # Calculate offsets for this block - d1_start = block_d1_idx * BLOCK_D1 - d2_start = block_d2_idx * BLOCK_D2 - d3_start = block_d3_idx * BLOCK_D3 - - d1_offsets = d1_start + tl.arange(0, BLOCK_D1) - d2_offsets = d2_start + tl.arange(0, BLOCK_D2) - d3_offsets = d3_start + tl.arange(0, BLOCK_D3) - - # Create masks - d1_mask = d1_offsets < d1_size - d2_mask = d2_offsets < d2_size - d3_mask = d3_offsets < d3_size - - mask = d1_mask[:, None, None] & d2_mask[None, :, None] & d3_mask[None, None, :] - - # Calculate source pointer and load data once - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index - - offsets = ( - d1_offsets[:, None, None] * stride_d1 - + d2_offsets[None, :, None] * stride_d2 - + d3_offsets[None, None, :] * stride_d3 - ) - - data = tl.load(base_src + offsets, mask=mask, other=0.0) - - # Broadcast to all destinations for this source - for dst_offset in range(num_dst_per_src): - dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset - dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) - - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index - tl.store(base_dst + offsets, data, mask=mask) - - # ==================== Config Generation Functions ==================== @@ -400,47 +128,6 @@ def _get_buffer_copy_1d_configs(): return configs -def _get_buffer_copy_2d_configs(): - """Generate candidate configurations for 2D buffer copy.""" - configs = [] - for block_d1 in [16, 32, 64, 128]: - for block_d2 in [16, 32, 64, 128, 256]: - for num_warps in [2, 4, 8]: - for num_stages in [2, 3, 4]: - configs.append( - { - "BLOCK_D1": block_d1, - "BLOCK_D2": block_d2, - "num_warps": num_warps, - "num_stages": num_stages, - } - ) - return configs - - -def _get_buffer_copy_3d_configs(): - """Generate candidate configurations for 3D buffer copy (5D tensor).""" - configs = [] - for block_d1 in [8, 16, 32]: - for block_d2 in [8, 16, 32, 64]: - for block_d3 in [8, 16, 32, 64, 128]: - for num_warps in [4, 8]: - for num_stages in [2, 3]: - # Skip configs that are too large for shared memory - if block_d1 * block_d2 * block_d3 > 32768: - continue - configs.append( - { - "BLOCK_D1": block_d1, - "BLOCK_D2": block_d2, - "BLOCK_D3": block_d3, - "num_warps": num_warps, - "num_stages": num_stages, - } - ) - return configs - - # ==================== Static and Run Key Functions ==================== @@ -450,7 +137,7 @@ def _get_buffer_copy_static_key(src_buffer: torch.Tensor): return { "ndim": len(shape), "layer_num": shape[0], - "d_sizes": str(shape[2:]), # Dimension sizes + "d_sizes": str(shape[2:]), "dtype": str(src_buffer.dtype), } @@ -483,7 +170,6 @@ def _copy_buffer_p2p_1d_autotuned( d_size = src_buffer.shape[2] if run_config is None: - # Default config if autotune is disabled BLOCK_D = triton.next_power_of_2(min(d_size, 256)) num_warps = 4 if BLOCK_D > 256 else 2 num_stages = 2 @@ -523,75 +209,6 @@ def _copy_buffer_p2p_1d_autotuned( ) -@autotune( - kernel_name="mamba_buffer_copy_p2p_2d:v1", - configs_gen_func=_get_buffer_copy_2d_configs, - 
static_key_func=_get_buffer_copy_static_key, - run_key_func=_get_buffer_copy_run_key, - mutates_args=["dst_buffer"], -) -def _copy_buffer_p2p_2d_autotuned( - src_buffer: torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, - run_config: dict = None, -): - """Auto-tuned 2D buffer copy.""" - num_pairs = src_indexes.shape[0] - layer_num = src_buffer.shape[0] - d1_size = src_buffer.shape[2] - d2_size = src_buffer.shape[3] - - if run_config is None: - # Default config if autotune is disabled - BLOCK_D1 = triton.next_power_of_2(min(d1_size, 64)) - BLOCK_D2 = triton.next_power_of_2(min(d2_size, 128)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_D1 = run_config["BLOCK_D1"] - BLOCK_D2 = run_config["BLOCK_D2"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) - num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) - num_blocks_total = num_blocks_d1 * num_blocks_d2 - - MAX_GRID_SIZE = 65535 - - for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): - pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) - pair_chunk_size = pair_chunk_end - pair_chunk_start - - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) - layer_chunk_size = layer_chunk_end - layer_chunk_start - - grid = (pair_chunk_size, layer_chunk_size, num_blocks_total) - - _copy_buffer_p2p_2d_kernel[grid]( - src_buffer, - dst_buffer, - src_indexes, - dst_indexes, - pair_chunk_start, - layer_chunk_start, - src_buffer.stride(0), - src_buffer.stride(1), - src_buffer.stride(2), - src_buffer.stride(3), - d1_size, - d2_size, - num_blocks_d2, - BLOCK_D1=BLOCK_D1, - BLOCK_D2=BLOCK_D2, - num_warps=num_warps, - num_stages=num_stages, - ) - - @autotune( kernel_name="mamba_buffer_broadcast_1d:v1", configs_gen_func=_get_buffer_copy_1d_configs, @@ -653,231 +270,19 @@ def _copy_buffer_broadcast_1d_autotuned( ) -@autotune( - kernel_name="mamba_buffer_broadcast_2d:v1", - configs_gen_func=_get_buffer_copy_2d_configs, - static_key_func=_get_buffer_copy_static_key, - run_key_func=_get_buffer_copy_run_key, - mutates_args=["dst_buffer"], -) -def _copy_buffer_broadcast_2d_autotuned( - src_buffer: torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, - run_config: dict = None, -): - """Auto-tuned 2D buffer broadcast (one src to multiple dst).""" - num_src = src_indexes.shape[0] - layer_num = src_buffer.shape[0] - d1_size = src_buffer.shape[2] - d2_size = src_buffer.shape[3] - num_dst_per_src = dst_indexes.shape[0] // num_src - - if run_config is None: - BLOCK_D1 = triton.next_power_of_2(min(d1_size, 64)) - BLOCK_D2 = triton.next_power_of_2(min(d2_size, 128)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_D1 = run_config["BLOCK_D1"] - BLOCK_D2 = run_config["BLOCK_D2"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) - num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) - num_blocks_total = num_blocks_d1 * num_blocks_d2 - - MAX_GRID_SIZE = 65535 - - for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): - src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) - src_chunk_size = src_chunk_end - src_chunk_start - - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) - layer_chunk_size = layer_chunk_end - layer_chunk_start - - grid = 
(src_chunk_size, layer_chunk_size, num_blocks_total) - - _copy_buffer_broadcast_2d_kernel[grid]( - src_buffer, - dst_buffer, - src_indexes, - dst_indexes, - src_chunk_start, - layer_chunk_start, - src_buffer.stride(0), - src_buffer.stride(1), - src_buffer.stride(2), - src_buffer.stride(3), - d1_size, - d2_size, - num_blocks_d2, - num_dst_per_src, - BLOCK_D1=BLOCK_D1, - BLOCK_D2=BLOCK_D2, - num_warps=num_warps, - num_stages=num_stages, - ) - - -@autotune( - kernel_name="mamba_buffer_copy_p2p_3d:v1", - configs_gen_func=_get_buffer_copy_3d_configs, - static_key_func=_get_buffer_copy_static_key, - run_key_func=_get_buffer_copy_run_key, - mutates_args=["dst_buffer"], -) -def _copy_buffer_p2p_3d_autotuned( - src_buffer: torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, - run_config: dict = None, -): - """Auto-tuned 3D data buffer copy (5D tensor).""" - num_pairs = src_indexes.shape[0] - layer_num = src_buffer.shape[0] - d1_size = src_buffer.shape[2] - d2_size = src_buffer.shape[3] - d3_size = src_buffer.shape[4] - - if run_config is None: - BLOCK_D1 = triton.next_power_of_2(min(d1_size, 16)) - BLOCK_D2 = triton.next_power_of_2(min(d2_size, 32)) - BLOCK_D3 = triton.next_power_of_2(min(d3_size, 64)) - num_warps = 4 if BLOCK_D1 * BLOCK_D2 * BLOCK_D3 > 4096 else 8 - num_stages = 2 - else: - BLOCK_D1 = run_config["BLOCK_D1"] - BLOCK_D2 = run_config["BLOCK_D2"] - BLOCK_D3 = run_config["BLOCK_D3"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) - num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) - num_blocks_d3 = triton.cdiv(d3_size, BLOCK_D3) - num_blocks_total = num_blocks_d1 * num_blocks_d2 * num_blocks_d3 - - MAX_GRID_SIZE = 65535 - - for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): - pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) - pair_chunk_size = pair_chunk_end - pair_chunk_start - - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) - layer_chunk_size = layer_chunk_end - layer_chunk_start - - grid = (pair_chunk_size, layer_chunk_size, num_blocks_total) - - _copy_buffer_p2p_3d_kernel[grid]( - src_buffer, - dst_buffer, - src_indexes, - dst_indexes, - pair_chunk_start, - layer_chunk_start, - src_buffer.stride(0), - src_buffer.stride(1), - src_buffer.stride(2), - src_buffer.stride(3), - src_buffer.stride(4), - d1_size, - d2_size, - d3_size, - num_blocks_d2, - num_blocks_d3, - BLOCK_D1=BLOCK_D1, - BLOCK_D2=BLOCK_D2, - BLOCK_D3=BLOCK_D3, - num_warps=num_warps, - num_stages=num_stages, - ) - - -@autotune( - kernel_name="mamba_buffer_broadcast_3d:v1", - configs_gen_func=_get_buffer_copy_3d_configs, - static_key_func=_get_buffer_copy_static_key, - run_key_func=_get_buffer_copy_run_key, - mutates_args=["dst_buffer"], -) -def _copy_buffer_broadcast_3d_autotuned( - src_buffer: torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, - run_config: dict = None, -): - """Auto-tuned 3D data buffer broadcast (5D tensor, one src to multiple dst).""" - num_src = src_indexes.shape[0] - layer_num = src_buffer.shape[0] - d1_size = src_buffer.shape[2] - d2_size = src_buffer.shape[3] - d3_size = src_buffer.shape[4] - num_dst_per_src = dst_indexes.shape[0] // num_src - - if run_config is None: - BLOCK_D1 = triton.next_power_of_2(min(d1_size, 16)) - BLOCK_D2 = triton.next_power_of_2(min(d2_size, 32)) - BLOCK_D3 = 
triton.next_power_of_2(min(d3_size, 64)) - num_warps = 4 if BLOCK_D1 * BLOCK_D2 * BLOCK_D3 > 4096 else 8 - num_stages = 2 - else: - BLOCK_D1 = run_config["BLOCK_D1"] - BLOCK_D2 = run_config["BLOCK_D2"] - BLOCK_D3 = run_config["BLOCK_D3"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) - num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) - num_blocks_d3 = triton.cdiv(d3_size, BLOCK_D3) - num_blocks_total = num_blocks_d1 * num_blocks_d2 * num_blocks_d3 - - MAX_GRID_SIZE = 65535 - - for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): - src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) - src_chunk_size = src_chunk_end - src_chunk_start +# ==================== Unified Interface ==================== - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) - layer_chunk_size = layer_chunk_end - layer_chunk_start - grid = (src_chunk_size, layer_chunk_size, num_blocks_total) +def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor: + """Flatten all dimensions after [layer_num, buffer_size] into one. - _copy_buffer_broadcast_3d_kernel[grid]( - src_buffer, - dst_buffer, - src_indexes, - dst_indexes, - src_chunk_start, - layer_chunk_start, - src_buffer.stride(0), - src_buffer.stride(1), - src_buffer.stride(2), - src_buffer.stride(3), - src_buffer.stride(4), - d1_size, - d2_size, - d3_size, - num_blocks_d2, - num_blocks_d3, - num_dst_per_src, - BLOCK_D1=BLOCK_D1, - BLOCK_D2=BLOCK_D2, - BLOCK_D3=BLOCK_D3, - num_warps=num_warps, - num_stages=num_stages, - ) - - -# ==================== Unified Interface ==================== + For a contiguous buffer of shape [L, B, d1, d2, ...], returns a view + of shape [L, B, d1*d2*...]. This is a zero-copy operation. + """ + if buffer.ndim == 3: + return buffer + L, B = buffer.shape[:2] + return buffer.view(L, B, -1) def copy_buffer_p2p( @@ -889,7 +294,8 @@ def copy_buffer_p2p( """ Copy buffers from source indices to destination indices with auto-tuning. - Supports 3D (conv states), 4D (standard buffers), and 5D (SSM states) buffers. + Supports any buffer shape [layer_num, buffer_size, ...] as long as the + trailing dimensions are contiguous (which is the default for torch.zeros). Args: src_buffer: Source buffer tensor [layer_num, buffer_size, ...] 
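A minimal sketch (not part of the patch; shapes are hypothetical) of why the unified path is safe: for a contiguous buffer, `view(layer_num, buffer_size, -1)` is a zero-copy view, so copying a buffer slot through the flattened view writes exactly the memory a per-dimension copy would.

```
import torch

# Hypothetical buffer: [layer_num=2, buffer_size=4, d1=3, d2=5]
buf = torch.arange(2 * 4 * 3 * 5, dtype=torch.float32).reshape(2, 4, 3, 5)

flat = buf.view(2, 4, -1)      # [2, 4, 15], no data copied
flat[:, 1, :] = flat[:, 0, :]  # copy buffer slot 0 -> slot 1 in every layer

# Identical result to the per-dimension copy buf[:, 1] = buf[:, 0]
assert torch.equal(buf[:, 1], buf[:, 0])
```

This is the property `_flatten_trailing_dims` relies on, which is what lets the single 1D kernel pair replace the dedicated 2D/3D kernel variants deleted above.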
@@ -901,20 +307,9 @@ def copy_buffer_p2p( assert src_indexes.shape == dst_indexes.shape assert len(src_indexes.shape) == 1 - if len(src_buffer.shape) == 3: - # 1D case: (layer_num, buffer_size, d) - _copy_buffer_p2p_1d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) - - elif len(src_buffer.shape) == 4: - # 2D case: (layer_num, buffer_size, d1, d2) - _copy_buffer_p2p_2d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) - - elif len(src_buffer.shape) == 5: - # 5D case: (layer_num, buffer_size, d1, d2, d3) - Use Triton kernel for zero extra memory - _copy_buffer_p2p_3d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) - - else: - raise ValueError(f"Unsupported buffer shape: {src_buffer.shape}") + src_flat = _flatten_trailing_dims(src_buffer) + dst_flat = _flatten_trailing_dims(dst_buffer) + _copy_buffer_p2p_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes) def copy_buffer_broadcast( @@ -939,23 +334,11 @@ def copy_buffer_broadcast( assert len(dst_indexes.shape) == 2, f"dst_indexes must be 2D, got shape {dst_indexes.shape}" num_src = src_indexes.shape[0] - assert num_src == dst_indexes.shape[0], f"Mismatch: src_indexes {num_src} vs dst_indexes {dst_indexes.shape[0]}" # Flatten dst_indexes for kernel dst_indexes_flat = dst_indexes.reshape(-1).contiguous() - if len(src_buffer.shape) == 3: - # 1D case - _copy_buffer_broadcast_1d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) - - elif len(src_buffer.shape) == 4: - # 2D case - _copy_buffer_broadcast_2d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) - - elif len(src_buffer.shape) == 5: - # 5D case: (layer_num, buffer_size, d1, d2, d3) - Use Triton kernel for zero extra memory - _copy_buffer_broadcast_3d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) - - else: - raise ValueError(f"Unsupported buffer shape: {src_buffer.shape}") + src_flat = _flatten_trailing_dims(src_buffer) + dst_flat = _flatten_trailing_dims(dst_buffer) + _copy_buffer_broadcast_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes_flat) diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py index 348b14192c..272a999bb1 100644 --- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -6,6 +6,7 @@ from lightllm.utils.dist_utils import get_current_rank_in_node from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from lightllm.common.allocator_utils import TokenAllocator +from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import copy_buffer_p2p, copy_buffer_broadcast from lightllm.utils.log_utils import init_logger from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt @@ -56,67 +57,20 @@ def get_mamba_cache(self, layer_idx: int): return conv_state, ssm_state def copy_buffer_p2p(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor): - """ - Copy buffers from source indices to destination indices using optimized Triton kernel. 
- - Args: - src_buffer_indexes: Source buffer indices (1D tensor) - dst_buffer_indexes: Destination buffer indices (1D tensor) - """ - assert src_buffer_indexes.dim() == 1 - assert dst_buffer_indexes.dim() == 1 - assert src_buffer_indexes.shape[0] == dst_buffer_indexes.shape[0] - - # Validate indices are within valid range [0, size] (size+1 is the buffer dim) - max_valid_idx = self.size # HOLD_BUFFER_INDEX = size is valid - src_max = src_buffer_indexes.max().item() if src_buffer_indexes.numel() > 0 else -1 - src_min = src_buffer_indexes.min().item() if src_buffer_indexes.numel() > 0 else -1 - dst_max = dst_buffer_indexes.max().item() if dst_buffer_indexes.numel() > 0 else -1 - dst_min = dst_buffer_indexes.min().item() if dst_buffer_indexes.numel() > 0 else -1 - - if src_min < 0 or src_max > max_valid_idx or dst_min < 0 or dst_max > max_valid_idx: - logger.error( - f"Invalid buffer indices: src=[{src_min}, {src_max}], dst=[{dst_min}, {dst_max}], " - f"valid range=[0, {max_valid_idx}], conv shape={self.conv_state_cache.buffer.shape}, " - f"ssm shape={self.ssm_state_cache.buffer.shape}" - ) - raise ValueError("Invalid buffer indices for copy_buffer_p2p") - - # Use PyTorch advanced indexing for buffer copy (safer than Triton for complex shapes) - # The buffer shape is [layer_num, buffer_size, *shape] - # We need to copy all layers for the given buffer indices - src_idx = src_buffer_indexes.long() - dst_idx = dst_buffer_indexes.long() - - # Copy conv_state: [layer_num, buffer_size, d1, d2] - self.conv_state_cache.buffer[:, dst_idx, ...] = self.conv_state_cache.buffer[:, src_idx, ...] - - # Copy ssm_state: [layer_num, buffer_size, d1, d2, d3] - self.ssm_state_cache.buffer[:, dst_idx, ...] = self.ssm_state_cache.buffer[:, src_idx, ...] - return + copy_buffer_p2p( + self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes + ) + copy_buffer_p2p( + self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes + ) def copy_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): - assert src_buffer_index.dim() == 1 - assert dst_buffer_indexes.dim() == 2 - assert src_buffer_index.shape[0] == dst_buffer_indexes.shape[0] - - # Use PyTorch advanced indexing for broadcast copy - # src_buffer_index: [num_src] - # dst_buffer_indexes: [num_src, num_dst_per_src] - src_idx = src_buffer_index.long() - dst_idx = dst_buffer_indexes.long() - - # Broadcast each source to all its destinations - # For each (src, dst_group), copy buffer[src] to buffer[dst1], buffer[dst2], ... - num_src, num_dst_per_src = dst_idx.shape - for i in range(num_src): - src = src_idx[i : i + 1] # Keep as 1D tensor with 1 element - dsts = dst_idx[i, :] # 1D tensor with num_dst_per_src elements - # Copy conv_state - self.conv_state_cache.buffer[:, dsts, ...] = self.conv_state_cache.buffer[:, src, ...] - # Copy ssm_state - self.ssm_state_cache.buffer[:, dsts, ...] = self.ssm_state_cache.buffer[:, src, ...] 
- return + copy_buffer_broadcast( + self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_index, dst_buffer_indexes + ) + copy_buffer_broadcast( + self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes + ) def copy_ssm_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): """ @@ -125,22 +79,9 @@ def copy_ssm_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_i This is used for MTP mode where each buffer maintains its own independent conv state, but SSM states need to be synchronized. """ - assert src_buffer_index.dim() == 1 - assert dst_buffer_indexes.dim() == 2 - assert src_buffer_index.shape[0] == dst_buffer_indexes.shape[0] - - # Use PyTorch advanced indexing for SSM-only broadcast copy - src_idx = src_buffer_index.long() - dst_idx = dst_buffer_indexes.long() - - # Broadcast each source to all its destinations (SSM only) - num_src = dst_idx.shape[0] - for i in range(num_src): - src = src_idx[i : i + 1] - dsts = dst_idx[i, :] - # Only copy ssm_state, NOT conv_state - self.ssm_state_cache.buffer[:, dsts, ...] = self.ssm_state_cache.buffer[:, src, ...] - return + copy_buffer_broadcast( + self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes + ) def free(self, free_index: Union[torch.Tensor, List[int]]): """ diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 7156a5ce23..825a985b46 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -227,14 +227,14 @@ def _init_datatype(self): def rot_pos_emb(self, grid_thw): pos_ids = [] s = self.spatial_merge_size - for _, h, w in grid_thw: + for t, h, w in grid_thw: pos_shape = (h // s, s, w // s, s) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1)) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() cos_full, sin_full = self.rotary_pos_emb(max_grid_size) diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index f2cd38ec8e..bc313fe467 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -187,7 +187,10 @@ def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torc if image.mode != "RGB": image = image.convert("RGB") image_arr = np.asarray(image, dtype=np.uint8) - image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True) + # Copy to ensure writable array (avoids PyTorch warning for read-only NumPy arrays) + image_data = ( + torch.from_numpy(image_arr.copy()).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True) + ) grouped_images, grouped_images_index = group_images_by_shape( [image_data], disable_grouping=self.disable_grouping diff --git a/lightllm/models/qwen35_moe/model.py b/lightllm/models/qwen35_moe/model.py new file mode 100644 index 0000000000..ee149f3a81 --- /dev/null +++ b/lightllm/models/qwen35_moe/model.py @@ -0,0 +1,42 @@ +import os +import json + +from lightllm.models.qwen3_vl.model import QWen3VLTokenizer +from lightllm.models.registry import 
ModelRegistry +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel +from lightllm.common.build_utils import repair_config +from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights + + +class QWen35Tokenizer(QWen3VLTokenizer): + def __init__(self, tokenizer=None, image_processor=None, **kwargs): + super().__init__(tokenizer, image_processor, **kwargs) + + +@ModelRegistry(["qwen3_5"], is_multimodal=True) +class Qwen35MoeTpPartModel(Qwen3NextTpPartModel): + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["text_config"] + + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + repair_config(self.config, same_names=["intermediate_size", "moe_intermediate_size"]) + + # Handle fine-tuning config if present + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + + def _load_hf_weights(self): + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] + return diff --git a/lightllm/models/qwen3_5/__init__.py b/lightllm/models/qwen3_5/__init__.py new file mode 100644 index 0000000000..47667a92d5 --- /dev/null +++ b/lightllm/models/qwen3_5/__init__.py @@ -0,0 +1,17 @@ +""" +Qwen3.5 Multimodal Model Module + +Provides Qwen3.5 multimodal models with hybrid attention and vision-language support. +""" + +from .model import ( + Qwen3_5TpPartModel, + Qwen3_5MOETpPartModel, + QWen3_5Tokenizer, +) + +__all__ = [ + "Qwen3_5TpPartModel", + "Qwen3_5MOETpPartModel", + "QWen3_5Tokenizer", +] diff --git a/lightllm/models/qwen3_5/infer_struct.py b/lightllm/models/qwen3_5/infer_struct.py new file mode 100644 index 0000000000..9ce407cacf --- /dev/null +++ b/lightllm/models/qwen3_5/infer_struct.py @@ -0,0 +1,110 @@ +""" +Qwen3.5 Multimodal Inference State + +This module provides inference state for Qwen3.5 multimodal model that combines: +- Qwen3Next features (output gating, MTP-aware batching, hybrid attention buffer management) +- Qwen3VL multimodal support (mrope position encoding for images/videos) +""" + +import torch +from typing import List + +from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen35InferStateInfo(Qwen2VLInferStateInfo): + """ + Inference state for Qwen3.5 multimodal model with: + - gate_value attribute for output gating in full attention layers + - MTP-aware batching for multi-token prediction + - Custom buffer management for hybrid attention (full + linear) + - mrope position encoding support for multimodal inputs + """ + + def __init__(self): + super().__init__() + # For output gating in full attention layers (from Qwen3Next) + self.gate_value = None + # MTP-aware attributes (from Qwen3Next) + self.b_att_seq_len = None + self.att_batch_size = None + self.real_req_idx = None + self.mtp_buffer_idx_list = None + self.b_buffer_idx = None + + def _compute_mrope_delta(self, images: List) -> int: + """Compute the position delta for mrope based on image tokens. 
+ + The position delta is the sum of all image position deltas (grid_thwd[3]) + which accounts for the extra position IDs consumed by multimodal content. + """ + position_delta = 0 + for image in images: + position_delta += image["grid_thwd"][3] + return position_delta + + def init_some_extra_state(self, model): + """Initialize Qwen3.5-specific state including mrope and MTP support""" + # First, initialize mrope position encoding using parent class + # which now has the corrected delta computation + rope_scaling = model.config.get("rope_scaling", {}) + self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) + + # Call the grandparent's (LlamaInferStateInfo) init_some_extra_state first + # to set up basic state + from lightllm.common.basemodel.infer_struct import InferStateInfo + + InferStateInfo.init_some_extra_state(self, model) + + # Now handle mrope position encoding with corrected delta computation + if self.is_prefill: + self.position_ids = self.get_mrope_position(self.multimodal_params) + else: + # Decode phase: compute correct mrope delta + b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] + for batch_idx, p in enumerate(self.multimodal_params): + b_position_delta[batch_idx] = self._compute_mrope_delta(p.get("images", [])) + + position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) + self.position_ids = position_ids.unsqueeze(0).expand(3, -1) + + self.position_ids = self.position_ids.contiguous() + self.position_cos = model._cos_cached[self.position_ids] + self.position_sin = model._sin_cached[self.position_ids] + + # Now handle MTP-aware batching (from Qwen3Next) + args_mtp_step = get_env_start_args().mtp_step + mtp_size = args_mtp_step + 1 + + if self.is_prefill: + # Prefill: Standard initialization + self.b_att_seq_len = self.b_seq_len + self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() + else: + # Decode: MTP-aware handling + # In MTP mode, each request has (mtp_step + 1) tokens + # att_batch_size is the number of unique requests + self.att_batch_size = self.batch_size // mtp_size + + # Use only the sequence lengths for the last token of each MTP group + if args_mtp_step > 0: + self.b_att_seq_len = self.b_seq_len[args_mtp_step::mtp_size].contiguous() + self.real_req_idx = self.b_req_idx[args_mtp_step::mtp_size] + else: + self.b_att_seq_len = self.b_seq_len + self.real_req_idx = self.b_req_idx + + # Buffer indices for Mamba cache (conv and SSM states) + self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.real_req_idx, :].flatten().contiguous() + + # Create per-step buffer indices for MTP + if args_mtp_step > 0: + buffer_idx_list = [] + for step_id in range(mtp_size): + buffer_idx_list.append(self.b_buffer_idx[step_id::mtp_size].tolist()) + self.mtp_buffer_idx_list = torch.tensor( + buffer_idx_list, dtype=torch.int32, device=self.b_buffer_idx.device + ) + + return diff --git a/lightllm/models/qwen3_5/layer_infer/__init__.py b/lightllm/models/qwen3_5/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..3bbc0ee3be --- /dev/null +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -0,0 +1,121 @@ +import torch +import torch.distributed as dist +from typing import Tuple + +from lightllm.models.qwen3next.layer_infer.transformer_layer_infer 
import ( + Qwen3NextFullAttentionTransformerLayerInfer, + Qwen3NextGatedDeltaNetTransformerLayerInfer, +) +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextFullAttentionTransformerLayerWeight, + Qwen3NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen35FullAttentionTransformerLayerInfer(Qwen3NextFullAttentionTransformerLayerInfer): + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + # Initialize mrope section from config + rope_scaling = network_config.get("rope_scaling", {}) + mrope_section = rope_scaling.get("mrope_section", [11, 11, 10]) + self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") + + def _get_qkv( + self, + input: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + input = input.view(-1, self.embed_dim_) + + # Q and gate projection + if not infer_state.is_prefill: + q_gate_buf = self._get_decode_buffer( + "q_gate_out", + (self._graph_max_batch_size, self.tp_q_gate_dim), + input.dtype, + input.device, + )[: input.size(0)] + q_gate = layer_weight.q_gate_proj.mm(input, out=q_gate_buf) + kv_buf = self._get_decode_buffer( + "kv_out", + (self._graph_max_batch_size, self.tp_kv_dim), + input.dtype, + input.device, + )[: input.size(0)] + kv_out = layer_weight.kv_proj.mm(input, out=kv_buf) + else: + q_gate = layer_weight.q_gate_proj.mm(input) + kv_out = layer_weight.kv_proj.mm(input) + + q_dim = self.tp_q_head_num_ * self.head_dim_ + q = q_gate[:, :q_dim].contiguous() + # In-place sigmoid for gate + infer_state.gate_value = q_gate[:, q_dim:].sigmoid_() + cache_kv = kv_out.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + # Q normalization (in-place) + from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + gemma_rmsnorm_forward( + q.view(-1, self.head_dim_), + layer_weight.q_norm_weight_.weight, + eps=self.eps_, + out=q.view(-1, self.head_dim_), + ) + + k_input = cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]) + if not infer_state.is_prefill: + k_normed = self._get_decode_buffer( + "k_norm_out", + (self._graph_max_batch_size * self.tp_k_head_num_, cache_kv.shape[-1]), + k_input.dtype, + k_input.device, + )[: k_input.shape[0]] + gemma_rmsnorm_forward(k_input, layer_weight.k_norm_weight_.weight, eps=self.eps_, out=k_normed) + else: + k_normed = gemma_rmsnorm_forward(k_input, layer_weight.k_norm_weight_.weight, eps=self.eps_) + cache_kv[:, : self.tp_k_head_num_, :] = k_normed.view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) + + if hasattr(infer_state, "position_cos") and infer_state.position_cos is not None: + rotary_dim = int(self.head_dim_ * self.partial_rotary_factor) + + q_rotary = q.view(-1, self.tp_q_head_num_, self.head_dim_)[:, :, :rotary_dim].contiguous() + k_rotary = cache_kv[:, : self.tp_k_head_num_, :rotary_dim].contiguous() + + mrope_triton_fused( + q_rotary, + k_rotary, + infer_state.position_cos, + infer_state.position_sin, + self.mrope_section, + is_interleaved=True, # Qwen3 uses interleaved mrope + ) + + q.view(-1, self.tp_q_head_num_, self.head_dim_)[:, :, :rotary_dim] = q_rotary + cache_kv[:, : self.tp_k_head_num_, :rotary_dim] = k_rotary 
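+        # Fallback when no mrope position tensors are present on infer_state:
+        # apply the standard 1D rotary embedding instead.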
+ else: + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + partial_rotary_factor=self.partial_rotary_factor, + ) + + return q, cache_kv + + +class Qwen35GatedDeltaNetTransformerLayerInfer(Qwen3NextGatedDeltaNetTransformerLayerInfer): + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + rope_scaling = network_config.get("rope_scaling", {}) + mrope_section = rope_scaling.get("mrope_section", [11, 11, 10]) + self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") diff --git a/lightllm/models/qwen3_5/layer_weights/__init__.py b/lightllm/models/qwen3_5/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..ca1f9d992e --- /dev/null +++ b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py @@ -0,0 +1,166 @@ +import torch + +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextFullAttentionTransformerLayerWeight, + Qwen3NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def split_fused_expert_weights(weights, layer_num, moe_intermediate_size): + layer_prefix = f"model.layers.{layer_num}." + keys = list(weights.keys()) + gate_up_count = 0 + down_count = 0 + num_experts = 0 + + for k in keys: + if not k.startswith(layer_prefix): + continue + + if "mlp.experts.gate_up_proj" in k: + fused_weight = weights.pop(k) # [num_experts, 2*inter_size, hidden_size] + num_experts = fused_weight.shape[0] + + prefix = k.rsplit(".gate_up_proj", 1)[0] + + gate_weight = fused_weight[:, :moe_intermediate_size, :] + up_weight = fused_weight[:, moe_intermediate_size:, :] + + for expert_idx in range(num_experts): + weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx] + weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx] + + gate_up_count += 1 + + elif "mlp.experts.down_proj" in k: + down_weight = weights.pop(k) # [num_experts, hidden_size, inter_size] + num_experts = down_weight.shape[0] + + prefix = k.rsplit(".down_proj", 1)[0] + + for expert_idx in range(num_experts): + weights[f"{prefix}.{expert_idx}.down_proj.weight"] = down_weight[expert_idx] + + down_count += 1 + + +class Qwen35NextFullAttentionTransformerLayerWeight(Qwen3NextFullAttentionTransformerLayerWeight): + def load_hf_weights(self, weights): + self._split_fused_expert_weights(weights) + super().load_hf_weights(weights) + + def _split_fused_expert_weights(self, weights): + moe_intermediate_size = self.network_config_.get("moe_intermediate_size") + if moe_intermediate_size is None: + moe_intermediate_size = self.network_config_.get("intermediate_size") + + if moe_intermediate_size is None: + logger.warning( + f"Layer {self.layer_num_}: Cannot find moe_intermediate_size in config, " + "skipping fused expert weight splitting" + ) + return + + layer_prefix = f"model.layers.{self.layer_num_}.mlp.experts" + has_fused_weights = any(layer_prefix in k and ("gate_up_proj" in k or "down_proj" in k) for k in weights.keys()) + + if has_fused_weights: + 
split_fused_expert_weights(weights, self.layer_num_, moe_intermediate_size) + + +class Qwen35NextGatedDeltaNetTransformerLayerWeight(Qwen3NextGatedDeltaNetTransformerLayerWeight): + def _init_gdn_weight(self): + # Initialize everything from parent first, then override only linear_in_proj. + super()._init_gdn_weight() + + prefix = f"model.layers.{self.layer_num_}.linear_attn" + hidden_size = self.network_config_["hidden_size"] + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + + # NOTE: keep grouped layout directly (q, k, v, z, b, a). + self.linear_in_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[ + qk_dim, + qk_dim, + v_dim, + v_dim, + self.linear_num_v_heads, + self.linear_num_v_heads, + ], + weight_names=[ + f"{prefix}.in_proj_q.weight", + f"{prefix}.in_proj_k.weight", + f"{prefix}.in_proj_v.weight", + f"{prefix}.in_proj_z.weight", + f"{prefix}.in_proj_b.weight", + f"{prefix}.in_proj_a.weight", + ], + data_type=self.data_type_, + quant_method=self.get_quant_method("in_proj_weight"), + ) + + def load_hf_weights(self, weights): + self._split_fused_expert_weights(weights) + super().load_hf_weights(weights) + + def _preprocess_weight(self, weights): + # Keep parent conv1d preprocessing path. + linear_conv1d_weight_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.weight" + linear_conv1d_bias_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.bias" + + if linear_conv1d_weight_name in weights: + weights[linear_conv1d_weight_name] = self._parse_linear_conv1d( + weights[linear_conv1d_weight_name].squeeze(1) + ) + if linear_conv1d_bias_name in weights: + weights[linear_conv1d_bias_name] = self._parse_linear_conv1d(weights[linear_conv1d_bias_name]) + + self._split_linear_in_proj_qkv(weights) + + def _split_linear_in_proj_qkv(self, weights): + prefix = f"model.layers.{self.layer_num_}.linear_attn" + qkv_name = f"{prefix}.in_proj_qkv.weight" + if qkv_name not in weights: + return + + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + expected_rows = 2 * qk_dim + v_dim + + qkv = weights[qkv_name] + if qkv.shape[0] != expected_rows: + logger.warning( + f"Layer {self.layer_num_}: unexpected in_proj_qkv shape " + f"{tuple(qkv.shape)}, expected first dim {expected_rows}; skip split" + ) + return + + q, k, v = torch.split(qkv, [qk_dim, qk_dim, v_dim], dim=0) + weights[f"{prefix}.in_proj_q.weight"] = q + weights[f"{prefix}.in_proj_k.weight"] = k + weights[f"{prefix}.in_proj_v.weight"] = v + del weights[qkv_name] + + def _split_fused_expert_weights(self, weights): + moe_intermediate_size = self.network_config_.get("moe_intermediate_size") + if moe_intermediate_size is None: + moe_intermediate_size = self.network_config_.get("intermediate_size") + + if moe_intermediate_size is None: + logger.warning( + f"Layer {self.layer_num_}: Cannot find moe_intermediate_size in config, " + "skipping fused expert weight splitting" + ) + return + + layer_prefix = f"model.layers.{self.layer_num_}.mlp.experts" + has_fused_weights = any(layer_prefix in k and ("gate_up_proj" in k or "down_proj" in k) for k in weights.keys()) + + if has_fused_weights: + split_fused_expert_weights(weights, self.layer_num_, moe_intermediate_size) diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py new file mode 100644 index 0000000000..fdbccdf787 --- /dev/null +++ b/lightllm/models/qwen3_5/model.py @@ -0,0 +1,229 @@ +import os +import json +import time 
+import gc +from safetensors import safe_open +from tqdm import tqdm +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel +from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( + Qwen35NextFullAttentionTransformerLayerWeight, + Qwen35NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.models.qwen3_vl.model import QWen3VLTokenizer +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import ( + Qwen35FullAttentionTransformerLayerInfer, + Qwen35GatedDeltaNetTransformerLayerInfer, +) +from lightllm.models.qwen3_5.infer_struct import Qwen35InferStateInfo +from lightllm.common.build_utils import repair_config +from lightllm.utils.log_utils import init_logger +import lightllm.utils.petrel_helper as utils + +logger = init_logger(__name__) + + +class QWen3_5Tokenizer(QWen3VLTokenizer): + """ + Tokenizer for Qwen3.5 multimodal model. + + Inherits all multimodal tokenization logic from Qwen3VL, + including image and video token handling. + """ + + def __init__(self, tokenizer=None, image_processor=None, **kwargs): + super().__init__(tokenizer, image_processor, **kwargs) + + +@ModelRegistry(["qwen3_5"], is_multimodal=True) +class Qwen3_5TpPartModel(Qwen3NextTpPartModel): + """ + Qwen3.5 Multimodal Model (Dense Variant) + + This model combines: + - Hybrid attention from Qwen3Next (Gated Delta Networks + Full Attention) + - Multimodal capabilities from Qwen3VL (image/video processing) + - Dense MLP layers (non-MoE) + + Architecture: + - Every Nth layer uses full attention (config: full_attention_interval) + - Other layers use linear attention (Gated Delta Networks) + - Vision encoder processes images/videos before text model + - Multimodal embeddings merged with text embeddings + """ + + # Override to use multimodal pre-layer for vision processing + pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer + + # Override to use multimodal pre/post weights (includes vision weights) + pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight + + # Override to use Qwen3.5 infer state with mrope support + infer_state_class = Qwen35InferStateInfo + + def __init__(self, kvargs): + """ + Initialize Qwen3.5 model. + + Args: + kvargs: Dictionary containing: + - weight_dir: Path to model weights + - max_total_token_num: Maximum total tokens + - Additional model configuration + """ + super().__init__(kvargs) + logger.info("Initialized Qwen3.5 multimodal model") + + def _init_config(self): + """ + Load and parse Qwen3.5 configuration. + + Qwen3.5 uses a nested config structure: + { + "model_type": "qwen3_5", + "text_config": { ... }, + "vision_config": { ... } + } + + This method extracts the text_config for the language model + and stores vision_config for multimodal processing. + """ + config_path = os.path.join(self.weight_dir_, "config.json") + + with open(config_path, "r") as json_file: + all_config = json.load(json_file) + + # Extract text config for language model + self.config = all_config["text_config"] + + # Store vision config for multimodal components + self.vision_config = all_config.get("vision_config", None) + + if self.vision_config is None: + logger.warning("No vision_config found in checkpoint. 
" "Multimodal features may not work correctly.") + + # Apply standard config repairs + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + + # Qwen3.5 uses layer_types array instead of decoder_sparse_step for MoE placement + # Set default for decoder_sparse_step (used by inherited Qwen3Next weight initialization) + # Default to 1 meaning all layers with num_experts > 0 use MoE + if "decoder_sparse_step" not in self.config: + self.config["decoder_sparse_step"] = 1 + + # Ensure mlp_only_layers exists (default to empty list) + if "mlp_only_layers" not in self.config: + self.config["mlp_only_layers"] = [] + + # Qwen3.5 MoE uses moe_intermediate_size instead of intermediate_size + # Set intermediate_size for compatibility with base layer weight classes + if "intermediate_size" not in self.config: + if "moe_intermediate_size" in self.config: + self.config["intermediate_size"] = self.config["moe_intermediate_size"] + else: + # Default fallback: 4x hidden_size (common in transformer architectures) + self.config["intermediate_size"] = self.config.get("hidden_size", 4096) * 4 + + # Qwen3.5 stores RoPE config under text_config.rope_parameters. + # Qwen3Next/llama infer path expects flattened keys like rope_theta and + # partial_rotary_factor on the main config dict. + rope_parameters = self.config.get("rope_parameters") + if isinstance(rope_parameters, dict): + if "rope_theta" in rope_parameters and "rope_theta" not in self.config: + self.config["rope_theta"] = rope_parameters["rope_theta"] + if "partial_rotary_factor" in rope_parameters and "partial_rotary_factor" not in self.config: + self.config["partial_rotary_factor"] = rope_parameters["partial_rotary_factor"] + # Preserve the richer RoPE metadata in the expected field when absent. + if "rope_scaling" not in self.config: + self.config["rope_scaling"] = rope_parameters + + # MoE routing parameters - set defaults for Qwen3.5 compatibility + if "norm_topk_prob" not in self.config: + self.config["norm_topk_prob"] = True # Standard default for MoE models + + # Handle fine-tuning config if present + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + + # Calculate num_kv_heads for KV cache memory management + # Required by parent class _init_mem_manager() in Qwen3NextTpPartModel + self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) + + def _init_weights(self): + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) + num_full_attention_layers = self.config["full_attention_interval"] + self.trans_layers_weight = [ + ( + Qwen35NextFullAttentionTransformerLayerWeight( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + if (i + 1) % num_full_attention_layers == 0 + else Qwen35NextGatedDeltaNetTransformerLayerWeight( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + ) + for i in range(self.config["n_layer"]) + ] + + def _init_infer_layer(self): + """ + Initialize inference layers for Qwen3.5 multimodal model. + + Uses mrope-enabled transformer layers to properly handle image/video + tokens with 3D position encoding (temporal, height, width). + + This overrides the parent class to use Qwen35* layer classes instead + of Qwen3Next* layer classes. 
+ """ + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + num_full_attention_layers = self.config["full_attention_interval"] + + self.layers_infer = [ + ( + Qwen35FullAttentionTransformerLayerInfer(i, network_config=self.config) + if (i + 1) % num_full_attention_layers == 0 + else Qwen35GatedDeltaNetTransformerLayerInfer(i, network_config=self.config) + ) + for i in range(self.config["n_layer"]) + ] + + +@ModelRegistry(["qwen3_5_moe"], is_multimodal=True) +class Qwen3_5MOETpPartModel(Qwen3_5TpPartModel): + """ + Qwen3.5-MoE Multimodal Model (Mixture of Experts Variant) + + Extends Qwen3.5 with sparse expert routing: + - Same hybrid attention architecture as Qwen3.5 + - MoE layers replace dense MLP layers + - Expert routing handled by Qwen3NextSparseMoeBlock (inherited) + + The MoE variant is automatically configured by inheriting from + Qwen3NextTpPartModel, which inherits from Qwen3MOEModel. + + No additional configuration needed - MoE support is built-in. + """ + + def __init__(self, kvargs): + """ + Initialize Qwen3.5-MoE model. + + Args: + kvargs: Dictionary containing: + - weight_dir: Path to model weights + - max_total_token_num: Maximum total tokens + - Additional model configuration + """ + super().__init__(kvargs) + logger.info("Initialized Qwen3.5-MoE multimodal model with expert routing") diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index f770459a55..5356da4caf 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -15,8 +15,28 @@ def init_tokenizer(args): async def build_prompt(request, tools) -> str: global tokenizer + import json + # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] + + # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility + # Qwen's chat template expects arguments to be a dict (uses |items filter) + # but OpenAI format sends arguments as a JSON string + for msg in messages: + tool_calls = msg.get("tool_calls") + if tool_calls and isinstance(tool_calls, list): + for tool_call in tool_calls: + func = tool_call.get("function") + if func and isinstance(func, dict): + args = func.get("arguments") + if isinstance(args, str) and args: + try: + func["arguments"] = json.loads(args) + except (json.JSONDecodeError, TypeError): + # Keep original string if not valid JSON + pass + kwargs = {"conversation": messages} if request.character_settings: kwargs["character_settings"] = request.character_settings @@ -32,7 +52,8 @@ async def build_prompt(request, tools) -> str: # This except branch will be triggered when the chosen model # has a different tools input format that is not compatiable # with openAI's apply_chat_template tool_call format, like Mistral. 
- tools = [t if "function" in t else {"function": t} for t in tools] + if tools is not None: + tools = [t if "function" in t else {"function": t} for t in tools] input_str = tokenizer.apply_chat_template( **kwargs, tokenize=True, diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index d955aa6a87..99331c061c 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -333,15 +333,31 @@ class SamplingParams(ctypes.Structure): def init(self, tokenizer, **kwargs): super().__init__() + # 移除kwargs中为null的参数,避免覆盖默认值 + kwargs = {k: v for k, v in kwargs.items() if v is not None} + self.best_of = kwargs.get("best_of", 1) self.n = kwargs.get("n", self.best_of) - self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) - self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) - self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) - self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) - self.temperature = kwargs.get("temperature", SamplingParams._temperature) - self.top_p = kwargs.get("top_p", SamplingParams._top_p) - self.top_k = kwargs.get("top_k", SamplingParams._top_k) + do_sample = kwargs.get("do_sample", SamplingParams._do_sample) + self.do_sample = False if do_sample is None else do_sample + + presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) + self.presence_penalty = 0.0 if presence_penalty is None else presence_penalty + + frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) + self.frequency_penalty = 0.0 if frequency_penalty is None else frequency_penalty + + repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) + self.repetition_penalty = 1.0 if repetition_penalty is None else repetition_penalty + + temperature = kwargs.get("temperature", SamplingParams._temperature) + self.temperature = 1.0 if temperature is None else temperature + + top_p = kwargs.get("top_p", SamplingParams._top_p) + self.top_p = 1.0 if top_p is None else top_p + + top_k = kwargs.get("top_k", SamplingParams._top_k) + self.top_k = -1 if top_k is None else top_k self.ignore_eos = kwargs.get("ignore_eos", False) self.image_max_patch_num = kwargs.get("image_max_patch_num", -1) self.max_new_tokens = kwargs.get("max_new_tokens", 16) @@ -408,13 +424,35 @@ def init(self, tokenizer, **kwargs): def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() + # Some checkpoints store null sampling fields in generation_config.json. + # Keep robust numeric defaults instead of propagating None into ctypes fields. 
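+        # For example, {"temperature": null, "top_k": null} in a checkpoint's
+        # generation_config.json should still resolve to 1.0 and -1 below.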
cls._do_sample = generation_cfg.get("do_sample", False) + if cls._do_sample is None: + cls._do_sample = False + cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) + if cls._presence_penalty is None: + cls._presence_penalty = 0.0 + cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) + if cls._frequency_penalty is None: + cls._frequency_penalty = 0.0 + cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) + if cls._repetition_penalty is None: + cls._repetition_penalty = 1.0 + cls._temperature = generation_cfg.get("temperature", 1.0) + if cls._temperature is None: + cls._temperature = 1.0 + cls._top_p = generation_cfg.get("top_p", 1.0) + if cls._top_p is None: + cls._top_p = 1.0 + cls._top_k = generation_cfg.get("top_k", -1) + if cls._top_k is None: + cls._top_k = -1 except: pass diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 9214715b1d..4c494d138b 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ast import json import orjson import logging @@ -1443,6 +1444,228 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami return StreamingParseResult(normal_text="", calls=calls) +class Qwen3CoderDetector(BaseFormatDetector): + """ + Detector for Qwen3-Coder XML-style function call format. + + Format Structure: + ``` + + + + value1 + + + value2 + + + + ``` + + Key differences from Qwen25Detector (JSON-based): + - Parameters are XML key-value pairs, not JSON objects + - Function name is embedded in the tag attribute + - Values need schema-aware type conversion (string by default) + + Reference: https://docs.vllm.ai/projects/recipes/en/latest/Qwen/Qwen3-Coder-480B-A35B.html + """ + + def __init__(self): + super().__init__() + self.bot_token = "" + self.eot_token = "" + self.tool_call_separator = "\n" + + # Regex patterns + self.tool_call_block_regex = re.compile(r"(.*?)", re.DOTALL) + self.function_regex = re.compile(r"||(?=)|$)", re.DOTALL + ) + self._normal_text_buffer = "" + + def has_tool_call(self, text: str) -> bool: + return " Dict: + """Extract parameter type configuration from tool definitions.""" + for tool in tools: + if tool.function.name == func_name and tool.function.parameters: + params = tool.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + return {} + + def _convert_param_value(self, value: str, param_name: str, param_config: Dict, func_name: str) -> Any: + """Convert parameter value based on schema type. 
Safe alternative to eval().""" + if value.lower() == "null": + return None + + if param_name not in param_config: + return value + + prop = param_config.get(param_name, {}) + param_type = str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string" + + if param_type in ("string", "str", "enum"): + return value + elif param_type.startswith("int") or param_type == "integer": + try: + return int(value) + except (ValueError, TypeError): + return value + elif param_type in ("number", "float", "double"): + try: + fv = float(value) + return int(fv) if fv == int(fv) else fv + except (ValueError, TypeError): + return value + elif param_type in ("boolean", "bool"): + return value.lower() == "true" + elif param_type in ("object", "array"): + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError, ValueError): + try: + return ast.literal_eval(value) + except (ValueError, SyntaxError, TypeError): + return value + return value + + def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional[ToolCallItem]: + """Parse a single ... block into a ToolCallItem.""" + try: + end_index = function_str.index(">") + except ValueError: + return None + + func_name = function_str[:end_index].strip() + tool_indices = self._get_tool_indices(tools) + if func_name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {func_name}") + return None + + parameters_text = function_str[end_index + 1 :] + param_config = self._get_param_config(func_name, tools) + param_dict = {} + + for match in self.parameter_regex.findall(parameters_text): + try: + idx = match.index(">") + except ValueError: + continue + param_name = match[:idx].strip() + param_value = match[idx + 1 :] + # Strip leading/trailing newlines from value + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name) + + return ToolCallItem( + tool_index=tool_indices[func_name], + name=func_name, + parameters=json.dumps(param_dict, ensure_ascii=False), + ) + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + + if " StreamingParseResult: + """Streaming incremental parsing for Qwen3-Coder XML tool calls.""" + self._buffer += new_text + current_text = self._buffer + + if not self.has_tool_call(current_text): + partial_len = self._ends_with_partial_token(current_text, self.bot_token) + if partial_len: + return StreamingParseResult() + self._buffer = "" + cleaned = new_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=cleaned) + + # Check for complete tool call blocks + if self.eot_token in current_text: + result = self.detect_and_parse(current_text, tools) + last_end = current_text.rfind(self.eot_token) + if last_end != -1: + self._buffer = current_text[last_end + len(self.eot_token) :].lstrip() + else: + self._buffer = "" + self.current_tool_id = -1 + self.current_tool_name_sent = False + return result + + # Partial tool call - try to extract function name for early streaming + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + calls = [] + tool_call_start = current_text.find(self.bot_token) + if tool_call_start == -1: + return StreamingParseResult() + + content_after = current_text[tool_call_start + 
len(self.bot_token) :] + func_prefix = "") + if gt_pos == -1: + return StreamingParseResult() + + func_name = after_func[:gt_pos].strip() + + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + if func_name and func_name in self._tool_indices and not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = {"name": func_name, "arguments": {}} + + return StreamingParseResult(normal_text="", calls=calls) + + class FunctionCallParser: """ Parser for function/tool calls in model outputs. @@ -1461,6 +1684,7 @@ class FunctionCallParser: "mistral": MistralDetector, "qwen": Qwen25Detector, "qwen25": Qwen25Detector, + "qwen3_coder": Qwen3CoderDetector, } def __init__(self, tools: List[Tool], tool_call_parser: str): diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 88b099459b..4403dba517 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -31,6 +31,12 @@ def __init__(self): self.node_value_len = 0 self.node_prefix_total_len = 0 + # Used by hybrid attention models (e.g., Qwen3Next) to track + # a per-request buffer_idx alongside the token-level KV cache. + # Pure attention models keep buffer_idx as None. + self.buffer_idx = None + self.buffer_time = time_gen.generate_time_id() + def get_compare_key(self): return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) @@ -78,6 +84,9 @@ def remove_child(self, child_node: "TreeNode"): def update_time(self): self.time_id = time_gen.generate_time_id() + def update_buffer_time(self): + self.buffer_time = time_gen.generate_time_id() + def is_leaf(self): return len(self.children) == 0 @@ -103,10 +112,10 @@ class RadixCache: unique_name 主要用于解决单机,多实列部署时的shm冲突 """ - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None, kv_cache_mem_manager=None): from lightllm.common.kv_cache_mem_manager import MemoryManager - self.mem_manager: MemoryManager = mem_manager + self.mem_manager: MemoryManager = kv_cache_mem_manager if kv_cache_mem_manager is not None else mem_manager self._key_dtype = torch.int64 self._value_dtype = torch.int64 @@ -359,6 +368,7 @@ def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: or parent_node.ref_counter != 0 or len(parent_node.children) != 1 or child_node.ref_counter != 0 + or parent_node.buffer_idx is not None ): return None diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 4b8b3c538f..57241de967 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -7,10 +7,11 @@ from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Callable, Any -from lightllm.common.req_manager import ReqManager +from lightllm.common.req_manager import ReqManager, ReqManagerForMamba from lightllm.utils.infer_utils import mark_start, mark_end from 
lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache from lightllm.utils.log_utils import init_logger from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.common.basemodel.infer_lock import g_infer_state_lock @@ -32,10 +33,13 @@ class InferenceContext: infer_req_ids = None vocab_size = None cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None + mtp_step: int = 0 overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream + use_mamba_model: bool = False + def register( self, backend, @@ -43,6 +47,7 @@ def register( radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int, + use_mamba_model: bool = False, ): self.args = get_env_start_args() from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend @@ -57,6 +62,14 @@ def register( self.infer_req_ids = [] self.vocab_size = vocab_size + + self.use_mamba_model = use_mamba_model + if self.use_mamba_model: + assert self.radix_cache is None or isinstance( + self.radix_cache, HybridRadixCache + ), "Mamba model only support HybridRadixCache" + assert isinstance(self.req_manager, ReqManagerForMamba), "Mamba model only support ReqManagerForMamba" + self.mtp_step = get_env_start_args().mtp_step return def init_cpu_embed_cache_client(self): @@ -73,6 +86,27 @@ def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream: self.cpu_kv_cache_stream = torch.cuda.Stream() return self.cpu_kv_cache_stream + def _alloc_and_copy_req_buffers(self, req_objs: List["InferReq"]) -> None: + """Allocate and copy buffers for requests. 
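+        Three steps: make room in the buffer pool (evicting radix-cache nodes
+        when the cache supports it, len(req_objs) * (mtp_step + 1) slots),
+        allocate one buffer row per request, then copy buffer state from each
+        request's shared prefix node when present.
+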
Delegates to req_manager which handles model-specific logic.""" + if not req_objs: + return + + if self.radix_cache is not None and hasattr(self.radix_cache, "free_radix_cache_to_get_enough_buffer"): + self.radix_cache.free_radix_cache_to_get_enough_buffer(len(req_objs) * (self.mtp_step + 1)) + + request_indices_gpu = torch.tensor([r.req_idx for r in req_objs], device="cuda", dtype=torch.int64) + self.req_manager.alloc_buffer_for_req(request_indices_gpu) + + if self.radix_cache is None: + return + + copy_data = [(r.req_idx, r.shared_kv_node.buffer_idx) for r in req_objs if r.shared_kv_node is not None] + if copy_data: + copy_indices, copy_buffers = zip(*copy_data) + copy_indices_tensor = torch.tensor(copy_indices, device="cuda", dtype=torch.int64) + copy_buffers_tensor = torch.tensor(copy_buffers, device="cuda", dtype=torch.int64) + self.req_manager.copy_buffer_from_another_buffer(copy_buffers_tensor, copy_indices_tensor) + def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]: req_objs = [] request_ids = [] @@ -111,9 +145,15 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: slave_req: InferReq = slave_req slave_req.related_master_req = master_req + self._alloc_and_copy_req_buffers(req_objs) + return req_objs def free_a_req_mem(self, free_token_index: List, req: "InferReq"): + # If no KV cache has been allocated yet, there's nothing to free + if req.cur_kv_len == 0: + return + if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) else: @@ -122,7 +162,8 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): # .cpu() 是 流内阻塞操作 value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - prefix_len, _ = self.radix_cache.insert(key, value) + prefix_len, node = self.radix_cache.insert(key, value) + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: @@ -130,6 +171,50 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None + def free_a_req_mem_for_mamba(self, free_token_index: List, req: "InferReq") -> bool: + # 返回该请求的 mamba buffer 是否需要手动释放 + if req.cur_kv_len == 0: + return True + + if self.radix_cache is None: + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) + else: + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() + + prefix_len, node = self.radix_cache.insert(key, value) + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) + if req.shared_kv_node is not None: + assert req.shared_kv_node.node_prefix_total_len <= prefix_len + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + req.shared_kv_node = None + + if len(req.extra_need_to_free_token_index) > 0: + free_token_index.extend(req.extra_need_to_free_token_index) + req.extra_need_to_free_token_index = [] + + if node.buffer_idx is None: + req_to_buffer_index = 
self.req_manager.req_to_buffer_index
+            buffer_idx = req_to_buffer_index[req.req_idx, 0].item()
+            self.radix_cache.add_buffer_idx_to_node(node, buffer_idx)
+            # This request's buffer has been inserted into the radix cache, so no manual free is needed
+            return False
+        return True
+
+    def _free_req_mem_and_buffers(self, free_token_index: List, free_buffer_index: List, req: "InferReq"):
+        """Free the request's KV cache and buffer memory."""
+        if self.use_mamba_model:
+            need_free_base_buffer = self.free_a_req_mem_for_mamba(free_token_index, req)
+            req_to_buffer_index = self.req_manager.req_to_buffer_index
+            if need_free_base_buffer:
+                free_buffer_index.extend(req_to_buffer_index[req.req_idx, :].tolist())
+            elif self.mtp_step > 0:
+                free_buffer_index.extend(req_to_buffer_index[req.req_idx, 1:].tolist())
+        else:
+            self.free_a_req_mem(free_token_index, req)

     def _save_promptcache_kvbuffer(self):
         """
         save prompt cache kv buffer
@@ -151,19 +236,23 @@ def _filter(self, finished_request_ids: List[int]):

         free_req_index = []
         free_token_index = []
+        free_buffer_index = []
         for request_id in finished_request_ids:
             req: InferReq = self.requests_mapping.pop(request_id)
             if self.args.diverse_mode:
                 req.clear_master_slave_state()
-            self.free_a_req_mem(free_token_index, req)
-
+            self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req)
             free_req_index.append(req.req_idx)
             # logger.info(f"infer release req id {req.shm_req.request_id}")
             req.shm_req.shm_infer_released = True
             self.shm_req_manager.put_back_req_obj(req.shm_req)

-        free_token_index = custom_cat(free_token_index)
-        self.req_manager.free(free_req_index, free_token_index)
+        if len(free_token_index) != 0:
+            free_token_index = custom_cat(free_token_index)
+            self.req_manager.free(free_req_index, free_token_index)
+
+        if self.use_mamba_model and len(free_buffer_index) != 0:
+            self.req_manager.free_buffer(free_buffer_index)

         finished_req_ids_set = set(finished_request_ids)
         self.infer_req_ids = [_id for _id in self.infer_req_ids if _id not in finished_req_ids_set]
@@ -191,12 +280,15 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
         if pause_reqs:
             g_infer_state_lock.acquire()
+            pause_req_indices = []
             free_token_index = []
+            free_buffer_index = []
             for req in pause_reqs:
+                pause_req_indices.append(req.req_idx)
                 if self.args.diverse_mode:
                     # 发生暂停的时候,需要清除 diverse 模式下的主从关系
                     req.clear_master_slave_state()
-                self.free_a_req_mem(free_token_index, req)
+                self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req)
                 req.cur_kv_len = 0
                 req.shm_req.shm_cur_kv_len = req.cur_kv_len
                 assert req.wait_pause is True
@@ -209,13 +301,16 @@
             free_token_index = custom_cat(free_token_index)
             self.req_manager.free_token(free_token_index)
+            if self.use_mamba_model and len(free_buffer_index) != 0:
+                self.req_manager.free_buffer(free_buffer_index)
+
             g_infer_state_lock.release()
         return self

     def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int):
         if paused_reqs:
             g_infer_state_lock.acquire()
-
+            recovered_reqs = []
             for req in paused_reqs:
                 prefill_need_token_num = req.get_cur_total_len()
                 if prefill_need_token_num > can_alloc_token_num:
@@ -226,7 +321,9 @@
                 if is_master_in_dp:
                     req.shm_req.is_paused = False
                 can_alloc_token_num -= prefill_need_token_num
+                recovered_reqs.append(req)
+
+            self._alloc_and_copy_req_buffers(recovered_reqs)
             g_infer_state_lock.release()
         return

@@ -351,6 +448,11 @@ def __init__(
self.nixl_pd_task_failed_num: int = 0 self.nixl_trans_device_id: int = -1 + # 在开启radix cache的情况下,用于标记命中情况,用于插入算法 + self.mamba_model_match_len = 0 + self.mamba_buffer_insert_len = 0 + self.extra_need_to_free_token_index = [] + # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED @@ -402,7 +504,7 @@ def _match_radix_cache(self): input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 - share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) + share_node, miss_prefix_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) if share_node is not None: self.shared_kv_node = share_node ready_cache_len = share_node.node_prefix_total_len @@ -411,6 +513,13 @@ def _match_radix_cache(self): self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 + if g_infer_context.use_mamba_model: + MAMBA_PREFILL_BLOCK_SIZE = 128 + MAMBA_MIN_INSERT_LEN = 1024 + miss_prefix_len = miss_prefix_len - miss_prefix_len % MAMBA_PREFILL_BLOCK_SIZE + if miss_prefix_len > MAMBA_MIN_INSERT_LEN: + self.mamba_buffer_insert_len = miss_prefix_len + self.shm_req.shm_cur_kv_len = self.cur_kv_len return @@ -458,13 +567,18 @@ def get_input_token_ids(self): return self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] def get_chuncked_input_token_ids(self): - chunked_start = self.cur_kv_len - chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) + # 复用 get_chuncked_input_token_len 的逻辑,保持一致性 + chunked_end = self.get_chuncked_input_token_len() return self.shm_req.shm_prompt_ids.arr[0:chunked_end] def get_chuncked_input_token_len(self): chunked_start = self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) + + if self.mamba_buffer_insert_len > 0: + chunked_end = min(self.get_cur_total_len(), chunked_start + self.mamba_buffer_insert_len) + self.mamba_buffer_insert_len = 0 + return chunked_end def set_next_gen_token_id(self, next_token_id: int, logprob: float, output_len: int): diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 8b085c45ed..0ba4b9248c 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -9,7 +9,6 @@ from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger from lightllm.models import get_model -from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -42,6 +41,7 @@ from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel from lightllm.models.mistral_mtp.model import MistralMTPModel +from lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel from lightllm.models.glm4_moe_lite_mtp.model import 
Glm4MoeLiteMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token @@ -172,12 +172,16 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) + + self.use_buffer_manager = getattr(self.model, "use_buffer_manager", False) + + radix_cache_class = self.model.get_radix_cache_class() self.radix_cache = ( - RadixCache( + radix_cache_class( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, - mem_manager=self.model.mem_manager, + kv_cache_mem_manager=self.model.mem_manager, ) if self.use_dynamic_prompt_cache else None @@ -189,12 +193,18 @@ def init_model(self, kvargs): self.logger.info(f"loaded model class {self.model.__class__}") + # Check if the model uses Mamba (linear attention) layers + from lightllm.common.req_manager import ReqManagerForMamba + + use_mamba_model = isinstance(self.model.req_manager, ReqManagerForMamba) + g_infer_context.register( backend=self, req_manager=self.model.req_manager, radix_cache=self.radix_cache, shm_req_manager=self.shm_req_manager, vocab_size=self.model.vocab_size, + use_mamba_model=use_mamba_model, ) # 初始化 dp 模式使用的通信 tensor, 对于非dp模式,不会使用到 @@ -287,21 +297,33 @@ def decode(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): raise NotImplementedError() def init_mtp_draft_model(self, main_kvargs: dict): - # 当前只支持 deepseekv3 模式的 mtp + # Support deepseekv3 and qwen3_next MTP modes self.mtp_step = self.args.mtp_step - self.draft_models: List[Deepseek3MTPModel] = [] + self.draft_models = [] os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" - if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: + if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att", "qwen3next_vanilla"]: num_mtp_modules = self.args.mtp_step - elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: + elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att", "qwen3next_eagle"]: num_mtp_modules = 1 else: assert False, f"error mtp mode {self.args.mtp_mode}" for i in range(num_mtp_modules): + # Get MTP model config first to calculate mem_layer_start mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) + + # Calculate mem_layer_start: main model layers + previous MTP model layers + # For models with integrated MTP (like qwen3_next), each MTP module has 1 layer + # For models with separate MTP configs, use the config's num_hidden_layers + model_type = mtp_model_cfg.get("model_type", "") + if model_type == "qwen3_next": + # Qwen3Next has integrated MTP with 1 layer per module + mtp_layers_per_module = 1 + else: + mtp_layers_per_module = mtp_model_cfg["num_hidden_layers"] + mem_layer_start = self.model.config["num_hidden_layers"] + i * mtp_layers_per_module mtp_model_kvargs = { "weight_dir": self.args.mtp_draft_model_dir[i], "max_total_token_num": self.model.mem_manager.size, @@ -314,7 +336,7 @@ def init_mtp_draft_model(self, main_kvargs: dict): "data_type": main_kvargs.get("data_type", "float16"), "graph_max_batch_size": main_kvargs.get("graph_max_batch_size", 16), "graph_max_len_in_batch": main_kvargs.get("graph_max_len_in_batch", 8196), - "disable_cudagraph": main_kvargs.get("disable_cudagraph", False), + "disable_cudagraph": True, # Disable CUDA graphs for MTP draft models "mem_fraction": main_kvargs["mem_fraction"], "batch_max_tokens": 
main_kvargs.get("batch_max_tokens", None), "quant_type": main_kvargs.get("quant_type", None), @@ -322,23 +344,27 @@ def init_mtp_draft_model(self, main_kvargs: dict): "run_mode": "normal", "main_model": self.model, "mtp_previous_draft_models": self.draft_models.copy(), + "mem_layer_start": mem_layer_start, + "mtp_index": i, } - mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) - if mtp_model_cfg["model_type"] == "deepseek_v3": - assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + # Select MTP model class based on model type + model_type = mtp_model_cfg.get("model_type", "") + if model_type == "deepseek_v3": self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) - elif mtp_model_cfg["model_type"] == "qwen3_moe": + elif model_type == "qwen3_moe": assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) - elif mtp_model_cfg["model_type"] == "mistral": + elif model_type == "mistral": assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) + elif model_type == "qwen3_next": + self.draft_models.append(Qwen3NextMTPModel(mtp_model_kvargs)) elif mtp_model_cfg["model_type"] == "glm4_moe_lite": assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) else: - assert False, f"error mtp mode {mtp_model_cfg['model_type']}" + raise ValueError(f"Unsupported MTP model type: {model_type}") self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index a8a5224ebc..3cabd97baa 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -24,6 +24,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache from .control_state import ControlState logger = init_logger(__name__) @@ -50,6 +51,14 @@ def __init__(self) -> None: self.classed_req_strict_prefill = False return + def _maybe_insert_hybrid_radix_cache(self, run_reqs: List[InferReq]): + # Insert hybrid radix cache entries if applicable, use for hybrid attention models. 
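+        # torch.cuda.synchronize() makes the freshly written buffer states visible
+        # before insertion, and g_infer_state_lock serializes access to the shared
+        # radix tree while entries are added.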
+ if self.use_buffer_manager and self.radix_cache is not None: + torch.cuda.synchronize() + g_infer_state_lock.acquire() + self.radix_cache.insert_for_hybrid_radix_cache(run_reqs) + g_infer_state_lock.release() + def infer_loop(self): torch.cuda.set_device(get_current_device_id()) try: @@ -136,6 +145,9 @@ def prefill_normal( extra_post_req_handle_func=self.extra_post_req_handle_func, nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) + + self._maybe_insert_hybrid_radix_cache(run_reqs) + # 第四阶段 event_pack.notify_pre_post_handle() return @@ -219,6 +231,8 @@ def prefill_mtp( nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) + self._maybe_insert_hybrid_radix_cache(run_reqs) + # 第四阶段 event_pack.notify_pre_post_handle() return @@ -258,6 +272,24 @@ def decode_mtp( key="mtp_accept_len", gpu_tensor=mtp_accept_len, ) + + # Copy accepted buffer states back to buffer[0] for MTP + # Only copy when accept_len > 1 (accept_len == 1 means buffer[0] is already correct) + mask = mtp_accept_len > 1 + if mask.sum() > 0: + actual_req_idxes = model_input.b_req_idx[b_req_mtp_start_loc[mask]] + # Source: the accepted buffer (at index accept_len - 1) + src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ + actual_req_idxes, mtp_accept_len[mask] - 1 + ] + # Destination: buffer[0] for each request + dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] + # P2P copy both conv_states and ssm_states + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): + g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + src_buffer_indexes, dst_buffer_indexes + ) + verify_event = torch.cuda.Event() verify_event.record() diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index bb0e848e76..c5dd768224 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -454,6 +454,20 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): gpu_tensor=mtp_accept_len, ) + # Copy accepted buffer states back to buffer[0] for MTP + # Only copy when accept_len > 1 + mask = mtp_accept_len > 1 + if mask.sum() > 0: + actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]] + src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ + actual_req_idxes, mtp_accept_len[mask] - 1 + ] + dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): + g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + src_buffer_indexes, dst_buffer_indexes + ) + verify_event = torch.cuda.Event() verify_event.record() @@ -767,6 +781,20 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf ) all_next_token_ids.append(next_token_ids) + # Copy accepted buffer states back to buffer[0] for MTP + # Only copy when accept_len > 1 + mask = mtp_accept_len > 1 + if mask.sum() > 0: + actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]] + src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ + actual_req_idxes, mtp_accept_len[mask] - 1 + ] + dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): + 
g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + src_buffer_indexes, dst_buffer_indexes + ) + verify_event = torch.cuda.Event() verify_event.record() diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 3e97f4de3e..ed4665e725 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -68,7 +68,7 @@ def exposed_init_model(self, kvargs): self.model = ( Qwen2_5_VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() ) - elif self.model_type in ["qwen3_vl", "qwen3_vl_moe"]: + elif self.model_type in ["qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]: self.model = ( Qwen3VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() ) diff --git a/test_gsmk.py b/test_gsmk.py new file mode 100644 index 0000000000..78a5aa467f --- /dev/null +++ b/test_gsmk.py @@ -0,0 +1,241 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py +import argparse +import ast +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import numpy as np +import requests +from tqdm import tqdm + +INVALID = -9999999 + + +def read_jsonl(filename: str): + """Read a JSONL file.""" + with open(filename) as fin: + for line in fin: + if line.startswith("#"): + continue + yield json.loads(line) + + +def dump_state_text(filename: str, states: list, mode: str = "w"): + """Dump program state in a text file.""" + with open(filename, mode) as fout: + for i, s in enumerate(states): + if isinstance(s, str): + fout.write(f"==== {i} ====\n{s}\n") + else: + fout.write(f"==== {i} ====\n{str(s)}\n") + + +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + # Check if the cache file already exists + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as file, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + size = file.write(chunk) + bar.update(size) + + return filename + + +def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): + """Call LightLLM API for text generation.""" + assert url is not None + + data = { + "inputs": prompt, + "parameters": { + "temperature": temperature, + "max_new_tokens": max_tokens, + "stop_sequences": stop, + "repetition_penalty": 1.0, + "top_p": 1.0, + "top_k": 1, + }, + } + res = requests.post(url, json=data) + assert res.status_code == 200, f"API request failed with status code {res.status_code}: {res.text}" + + response_json = res.json() + if "generated_text" not in response_json: + raise ValueError(f"Invalid API response format. 
Expected 'generated_text' key, got: {response_json.keys()}") + if not isinstance(response_json["generated_text"], list) or len(response_json["generated_text"]) == 0: + raise ValueError( + "Invalid API response format. 'generated_text' should be a non-empty list, " + f"got: {response_json['generated_text']}" + ) + + pred = response_json["generated_text"][0] + return pred + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + # First try to find the answer after "####" marker (GSM8K format) + match = re.search(r"####\s*(-?\d+)", answer_str) + if match: + try: + return ast.literal_eval(match.group(1)) + except SyntaxError: + pass + # Fallback: find all numbers and take the last one + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", type=int, default=256) + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--num-shots", type=int, default=5) + parser.add_argument("--num-questions", type=int, default=200) + parser.add_argument("--result-file", type=str, default="result.jsonl") + parser.add_argument("--data-path", type=str, default="test.jsonl") + return parser.parse_args() + + +def main(args): + # LightLLM API URL + url = f"{args.host}:{args.port}/generate" + + # Read data + url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url_data) + lines = list(read_jsonl(filename)) + + # Construct prompts + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) + + # Ensure we have enough samples and avoid data leakage + # Test questions should start after few-shot examples + max_available = len(lines) - num_shots + if num_questions > max_available: + print( + "Warning: Requested {} questions, but only {} available after reserving {} for few-shot. 
" + "Using {} questions.".format(num_questions, max_available, num_shots, max_available) + ) + num_questions = max_available + + questions = [] + labels = [] + for i in range(num_shots, num_shots + num_questions): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(label != INVALID for label in labels) + + states = [None] * len(labels) + + # Run requests using thread pool + def get_one_answer(i): + answer = call_generate_lightllm( + prompt=few_shot_examples + questions[i], + temperature=0, + max_tokens=1024, + stop=["Question", "Assistant:", "<|separator|>", "Human:", "\n\nQuestion"], + url=url, + ) + states[i] = answer + + tic = time.perf_counter() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + + latency = time.perf_counter() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + + # Print results + print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + + # Dump results + dump_state_text("tmp_output_lightllm.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "gsm8k", + "backend": "lightllm", + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + args = parse_args() + main(args) From 340d11c574aefb2a979a39ad177bdc03c46c86f6 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 21 Feb 2026 08:04:05 +0000 Subject: [PATCH 092/180] fix conv3d --- lightllm/models/qwen2_vl/qwen2_visual.py | 2 ++ .../qwen3_omni_visual.py | 2 ++ lightllm/models/qwen3_vl/qwen3_visual.py | 13 ++++++++ lightllm/server/api_models.py | 32 ++++++++++++++----- lightllm/server/api_openai.py | 8 +++++ lightllm/server/httpserver/manager.py | 8 +++++ 6 files changed, 57 insertions(+), 8 deletions(-) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 0e2af0cbb2..a29cb8758b 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -62,6 +62,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size ) + # Use channels_last_3d to enable cuDNN optimized Conv3D path + hidden_states = hidden_states.contiguous(memory_format=torch.channels_last_3d) hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) return hidden_states diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py index ffa2e19bd6..c20c227996 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -68,6 +68,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size ) + # Use channels_last_3d to enable cuDNN optimized 
Conv3D path + hidden_states = hidden_states.contiguous(memory_format=torch.channels_last_3d) hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) return hidden_states diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index 00ad6c05a7..7fc8187ddc 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -29,6 +29,9 @@ from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class Qwen3VLVisionMLP(nn.Module): @@ -68,6 +71,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size ) + # Use channels_last_3d to enable cuDNN optimized Conv3D path + hidden_states = hidden_states.contiguous(memory_format=torch.channels_last_3d) hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) return hidden_states @@ -374,7 +379,15 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) + orig_size = image_data.size pixel_values, image_grid_thw = self.processor.preprocess(image_data) + + # Debug logging for image processing + logger.debug( + f"[VISUAL_DEBUG] Image {i}: orig_size={orig_size}, " + f"pixel_values.shape={pixel_values.shape}, grid_thw={image_grid_thw}" + ) + img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index f30ecc55fe..7c7d40698c 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -115,6 +115,7 @@ class CompletionRequest(BaseModel): prompt: Union[str, List[str], List[int], List[List[int]]] suffix: Optional[str] = None max_tokens: Optional[int] = 8192 + max_completion_tokens: Optional[int] = None # OpenAI's newer parameter, alias for max_tokens temperature: Optional[float] = 1.0 top_p: Optional[float] = 1.0 n: Optional[int] = 1 @@ -169,10 +170,17 @@ def load_generation_cfg(cls, weight_dir: str): @classmethod def apply_loaded_defaults(cls, data: Any): """Apply loaded default values if field is not provided.""" - if isinstance(data, dict) and cls._loaded_defaults: - for key, value in cls._loaded_defaults.items(): - if key not in data: - data[key] = value + if isinstance(data, dict): + # Map max_completion_tokens to max_tokens if provided + # (OpenAI's newer parameter name) + if "max_completion_tokens" in data and data["max_completion_tokens"] is not None: + if "max_tokens" not in data or data["max_tokens"] is None: + data["max_tokens"] = data["max_completion_tokens"] + + if cls._loaded_defaults: + for key, value in cls._loaded_defaults.items(): + if key not in data: + data[key] = value return data @@ -187,6 +195,7 @@ class ChatCompletionRequest(BaseModel): stream_options: Optional[StreamOptions] = None stop: Optional[Union[str, List[str]]] = None max_tokens: Optional[int] = 8192 + max_completion_tokens: Optional[int] = None # OpenAI's newer parameter, alias for max_tokens presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None @@ -246,10 +255,17 @@ def 
load_generation_cfg(cls, weight_dir: str): @classmethod def apply_loaded_defaults(cls, data: Any): """Apply loaded default values if field is not provided.""" - if isinstance(data, dict) and cls._loaded_defaults: - for key, value in cls._loaded_defaults.items(): - if key not in data: - data[key] = value + if isinstance(data, dict): + # Map max_completion_tokens to max_tokens if provided + # (OpenAI's newer parameter name) + if "max_completion_tokens" in data and data["max_completion_tokens"] is not None: + if "max_tokens" not in data or data["max_tokens"] is None: + data["max_tokens"] = data["max_completion_tokens"] + + if cls._loaded_defaults: + for key, value in cls._loaded_defaults.items(): + if key not in data: + data[key] = value return data diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index fc14314ae3..de1423c496 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -276,6 +276,14 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req text = "".join(final_output_dict[sub_req_id]) full_text = text + # Debug logging for empty responses + if not text or len(text.strip()) == 0: + logger.warning( + f"[EMPTY_RESPONSE_DEBUG] sub_req_id={sub_req_id}, " + f"completion_tokens={completion_tokens}, finish_reason={finish_reason}, " + f"prompt_tokens={prompt_tokens}, output_chunks={len(final_output_dict[sub_req_id])}" + ) + # Handle reasoning content reasoning_text = None reasoning_parser = get_env_start_args().reasoning_parser diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index d51f88cdda..c290880c73 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -288,6 +288,14 @@ async def generate( if self.pd_mode.is_P_or_NORMAL(): await multimodal_params.verify_and_preload(request) + # Debug logging for multimodal requests + if multimodal_params and multimodal_params.images: + logger.debug( + f"[MULTIMODAL_DEBUG] req_id={group_request_id}, " + f"num_images={len(multimodal_params.images)}, " + f"max_new_tokens={sampling_params.max_new_tokens}" + ) + # 记录请求到达的相关信息 await self._log_req_header(request_headers, group_request_id) # encode From a6a2435d1ba82f49140a9ab63c37b7cb0999c771 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 26 Feb 2026 07:02:44 +0000 Subject: [PATCH 093/180] [draft] qwen3.5 dense --- .../{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json | 8 + .../{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json | 7 + ...ARLEN=true,REVERSE=false}_NVIDIA_H200.json | 38 ++++ ...=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json | 7 + ...4,a_dtype=torch.bfloat16}_NVIDIA_H200.json | 50 ++++++ ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 48 +++++ ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 + ...=12,dtype=torch.bfloat16}_NVIDIA_H200.json | 50 ++++++ .../layer_weights/transformer_layer_weight.py | 8 +- lightllm/models/qwen3_moe/model.py | 4 +- .../layer_infer/transformer_layer_infer.py | 18 +- .../layer_weights/transformer_layer_weight.py | 168 ++++++++++++------ lightllm/models/qwen3next/model.py | 4 +- 15 files changed, 368 insertions(+), 65 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..131da59770 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..131da59770 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..131da59770 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 
0000000000..2af1b86e90 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..808ed9a7fc --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 1 + }, + "100": { + "num_warps": 8 + }, + "1024": { + "num_warps": 2 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 8 + }, + "256": { + "num_warps": 8 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 2 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..fb62cf8259 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BK": 32, + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..ad8d397d3b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 8, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "1024": { + "BLK_HEADS": 32, + "num_warps": 1 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "16384": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "2048": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 16, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json index f525d11257..55ccb24a65 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -11,6 +11,10 @@ "BLOCK_N": 128, "num_warps": 1 }, + "1536": { + "BLOCK_N": 128, + "num_warps": 1 + }, "16": { "BLOCK_N": 256, "num_warps": 4 @@ -23,10 +27,26 @@ "BLOCK_N": 128, "num_warps": 1 }, + "192": { + "BLOCK_N": 128, + "num_warps": 1 + }, "2048": { "BLOCK_N": 64, "num_warps": 2 }, + "24": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "2400": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "24576": { + "BLOCK_N": 128, + "num_warps": 1 + }, "256": { "BLOCK_N": 512, "num_warps": 2 @@ -35,18 +55,38 @@ "BLOCK_N": 128, "num_warps": 1 }, + "3072": { + "BLOCK_N": 128, + "num_warps": 1 + }, "32768": { "BLOCK_N": 128, "num_warps": 1 }, + "384": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "393216": { + "BLOCK_N": 128, + "num_warps": 1 + }, "4096": { "BLOCK_N": 64, "num_warps": 2 }, + "49152": { + "BLOCK_N": 128, + "num_warps": 1 + }, "512": { "BLOCK_N": 256, "num_warps": 4 }, + "6144": { + "BLOCK_N": 128, + "num_warps": 1 + }, "64": { "BLOCK_N": 64, "num_warps": 2 @@ -55,6 +95,10 @@ "BLOCK_N": 128, "num_warps": 1 }, + "768": { + "BLOCK_N": 256, + "num_warps": 2 + }, "8": { "BLOCK_N": 64, "num_warps": 2 @@ -66,5 +110,9 @@ "8192": { "BLOCK_N": 128, "num_warps": 2 + }, + "98304": { + "BLOCK_N": 128, + "num_warps": 1 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..df501847ec --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "5120": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..ada783ef92 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "num_stages": 2, + "num_warps": 1 + }, + "1024": { + "num_stages": 2, + "num_warps": 1 + }, + "128": { + "num_stages": 5, + "num_warps": 2 + }, + "16": { + "num_stages": 3, + "num_warps": 2 + }, + "16384": { + 
"num_stages": 3, + "num_warps": 1 + }, + "2048": { + "num_stages": 5, + "num_warps": 2 + }, + "256": { + "num_stages": 3, + "num_warps": 2 + }, + "32": { + "num_stages": 3, + "num_warps": 1 + }, + "4096": { + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "num_stages": 3, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 13ba6cbe0f..e525cb2d20 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -5,11 +5,11 @@ class Qwen3MOETransformerLayerWeight(Qwen3TransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): - self.n_routed_experts = network_config["num_experts"] + self.n_routed_experts = network_config.get("num_experts", 0) self.is_moe = ( - network_config["num_experts"] > 0 - and layer_num not in network_config["mlp_only_layers"] - and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 ) super().__init__(layer_num, data_type, network_config, quant_cfg) return diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index 10a5051276..b71d7f4878 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -25,4 +25,6 @@ def __init__(self, kvargs): def _init_custom(self): super()._init_custom() - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + # Only initialize DeepEP group for MoE models with num_experts + if "num_experts" in self.config and self.config["num_experts"] > 0: + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index cd5fd67d53..dc44c64434 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -143,14 +143,19 @@ def _ffn_core(self, input, layer_weight, is_decode=False): def _standard_ffn(self, input, infer_state, layer_weight): """Standard FFN using shared expert weights (non-MoE layers).""" + # For dense models without shared experts, return zeros (no FFN computation) + if not hasattr(layer_weight, "shared_expert_gate_up_proj") or layer_weight.shared_expert_gate_up_proj is None: + return torch.zeros_like(input) ffn2_out, _ = self._ffn_core(input, layer_weight, is_decode=not infer_state.is_prefill) return ffn2_out def _compute_shared_expert(self, input, layer_weight, is_decode=False): """Compute shared expert FFN output with gating.""" ffn2_out, input_view = self._ffn_core(input, layer_weight, is_decode=is_decode) - gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() - ffn2_out.mul_(gate) + # Dense models don't have shared_expert_gate + if layer_weight.shared_expert_gate is not None: + gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() + ffn2_out.mul_(gate) return ffn2_out, input_view def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): @@ -488,14 +493,19 @@ def _ffn_core(self, input, layer_weight, 
is_decode=False): def _standard_ffn(self, input, infer_state, layer_weight): """Standard FFN using shared expert weights (non-MoE layers).""" + # For dense models without shared experts, return zeros (no FFN computation) + if not hasattr(layer_weight, "shared_expert_gate_up_proj") or layer_weight.shared_expert_gate_up_proj is None: + return torch.zeros_like(input) ffn2_out, _ = self._ffn_core(input, layer_weight, is_decode=not infer_state.is_prefill) return ffn2_out def _compute_shared_expert(self, input, layer_weight, is_decode=False): """Compute shared expert FFN output with gating.""" ffn2_out, input_view = self._ffn_core(input, layer_weight, is_decode=is_decode) - gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() - ffn2_out.mul_(gate) + # Dense models don't have shared_expert_gate + if layer_weight.shared_expert_gate is not None: + gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() + ffn2_out.mul_(gate) return ffn2_out, input_view def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index d4e16555d9..3e72041f8a 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -47,6 +47,11 @@ def _init_weight(self): self._init_gate_shared_expert_weight() return + def _init_ffn(self): + # Qwen3Next architecture uses _init_gate_shared_expert_weight() for FFN-like component + # No standard MLP FFN weights needed for this architecture + pass + def load_hf_weights(self, weights): self._split_q_with_gate(weights) super().load_hf_weights(weights) @@ -62,41 +67,65 @@ def _split_q_with_gate(self, weights): weights[self._o_gate_weight_name] = _gate_proj def _init_gate_shared_expert_weight(self): - prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" hidden_size = self.network_config_["hidden_size"] - shared_inter = self.network_config_["shared_expert_intermediate_size"] - self.shared_expert_gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[shared_inter, shared_inter], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_gate_up_proj"), - ) - self.shared_expert_down_proj = COLMMWeight( - in_dim=shared_inter, - out_dims=[hidden_size], - weight_names=f"{prefix}.down_proj.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_down_proj"), - ) - self.shared_expert_gate = ROWMMWeight( - in_dim=hidden_size, - out_dims=[1], - weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", - data_type=self.data_type_, - bias_names=None, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) + + # Check if this is a MoE model with shared_expert or a dense model + if "shared_expert_intermediate_size" in self.network_config_: + # MoE model with shared expert + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + inter_size = self.network_config_["shared_expert_intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=inter_size, + 
out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + self.shared_expert_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + else: + # Dense model with standard MLP + prefix = f"model.layers.{self.layer_num_}.mlp" + inter_size = self.network_config_["intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + # No shared_expert_gate for dense models + self.shared_expert_gate = None class Qwen3NextGatedDeltaNetTransformerLayerWeight(Qwen3MOETransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): self.is_moe = ( - network_config["num_experts"] > 0 - and layer_num not in network_config["mlp_only_layers"] - and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 ) super().__init__(layer_num, data_type, network_config, quant_cfg) @@ -126,6 +155,11 @@ def _init_weight(self): self._init_ffn() self._init_gate_shared_expert_weight() + def _init_ffn(self): + # GatedDeltaNet architecture uses _init_gate_shared_expert_weight() for FFN-like component + # No standard MLP FFN weights needed for this architecture + pass + def _init_gdn_weight(self): prefix = f"model.layers.{self.layer_num_}.linear_attn" hidden_size = self.network_config_["hidden_size"] @@ -284,30 +318,54 @@ def _parse_linear_conv1d(self, weight): return new_weight def _init_gate_shared_expert_weight(self): - prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" hidden_size = self.network_config_["hidden_size"] - shared_inter = self.network_config_["shared_expert_intermediate_size"] - self.shared_expert_gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[shared_inter, shared_inter], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_gate_up_proj"), - ) - self.shared_expert_down_proj = COLMMWeight( - in_dim=shared_inter, - out_dims=[hidden_size], - weight_names=f"{prefix}.down_proj.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_down_proj"), - ) - self.shared_expert_gate = ROWMMWeight( - in_dim=hidden_size, - out_dims=[1], - weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", - data_type=self.data_type_, - bias_names=None, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) + + # Check if this is a MoE model with shared_expert or a dense model + if "shared_expert_intermediate_size" in self.network_config_: + # MoE model with shared expert + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + inter_size = 
self.network_config_["shared_expert_intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + self.shared_expert_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + else: + # Dense model with standard MLP + prefix = f"model.layers.{self.layer_num_}.mlp" + inter_size = self.network_config_["intermediate_size"] + self.shared_expert_gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_gate_up_proj"), + ) + self.shared_expert_down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_expert_down_proj"), + ) + # No shared_expert_gate for dense models + self.shared_expert_gate = None diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 1234a659ed..d15b357608 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -60,7 +60,9 @@ def _init_config(self): def _init_custom(self): super()._init_custom() - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + # Only initialize DeepEP group for MoE models with num_experts + if "num_experts" in self.config and self.config["num_experts"] > 0: + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 From 054035d84ad22ff8e00c747a01fddaa9dcdf8bbf Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 26 Feb 2026 10:02:22 +0000 Subject: [PATCH 094/180] split dense and moe --- ...num=8,use_fp8_w8a8=false}_NVIDIA_H200.json | 38 +++++++++++++++++ ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 38 +++++++++++++++++ .../{topk_num=8}_NVIDIA_H200.json | 12 ++++++ ...orch.bfloat16,topk_num=8}_NVIDIA_H200.json | 18 ++++++++ ...M=8,dtype=torch.bfloat16}_NVIDIA_H200.json | 18 ++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 24 +++++++++++ lightllm/models/qwen35_moe/model.py | 42 ------------------- .../layer_infer/transformer_layer_infer.py | 12 +++--- 8 files changed, 154 insertions(+), 48 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json delete mode 100644 lightllm/models/qwen35_moe/model.py diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..662875ecdb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..1f8134fa64 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json index 002b842cbb..bf2afabaef 100644 --- 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json @@ -19,6 +19,14 @@ "BLOCK_SIZE": 128, "num_warps": 8 }, + "16384": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, "256": { "BLOCK_SIZE": 128, "num_warps": 8 @@ -27,6 +35,10 @@ "BLOCK_SIZE": 128, "num_warps": 8 }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, "64": { "BLOCK_SIZE": 128, "num_warps": 8 diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json index bc904bb7f8..b32622e3b1 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -29,6 +29,18 @@ "NUM_STAGE": 1, "num_warps": 2 }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, "256": { "BLOCK_DIM": 1024, "BLOCK_M": 1, @@ -41,6 +53,12 @@ "NUM_STAGE": 4, "num_warps": 4 }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, "64": { "BLOCK_DIM": 128, "BLOCK_M": 1, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..5b3e656b6d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,18 @@ +{ + "1024": { + "num_stages": 5, + "num_warps": 1 + }, + "16384": { + "num_stages": 4, + "num_warps": 2 + }, + "2048": { + "num_stages": 2, + "num_warps": 2 + }, + "4096": { + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json index e08a58baf5..0a0f01fe7a 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -23,12 +23,24 @@ "NUM_STAGES": 1, "num_warps": 1 }, + "131072": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "160": { "BLOCK_M": 1, "BLOCK_N": 256, "NUM_STAGES": 1, 
"num_warps": 1 }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "163840": { "BLOCK_M": 1, "BLOCK_N": 128, @@ -53,6 +65,12 @@ "NUM_STAGES": 2, "num_warps": 1 }, + "32768": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "40960": { "BLOCK_M": 32, "BLOCK_N": 128, @@ -70,5 +88,11 @@ "BLOCK_N": 128, "NUM_STAGES": 2, "num_warps": 1 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/models/qwen35_moe/model.py b/lightllm/models/qwen35_moe/model.py deleted file mode 100644 index ee149f3a81..0000000000 --- a/lightllm/models/qwen35_moe/model.py +++ /dev/null @@ -1,42 +0,0 @@ -import os -import json - -from lightllm.models.qwen3_vl.model import QWen3VLTokenizer -from lightllm.models.registry import ModelRegistry -from lightllm.models.qwen3next.model import Qwen3NextTpPartModel -from lightllm.common.build_utils import repair_config -from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights - - -class QWen35Tokenizer(QWen3VLTokenizer): - def __init__(self, tokenizer=None, image_processor=None, **kwargs): - super().__init__(tokenizer, image_processor, **kwargs) - - -@ModelRegistry(["qwen3_5"], is_multimodal=True) -class Qwen35MoeTpPartModel(Qwen3NextTpPartModel): - def _init_config(self): - with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: - all_config = json.load(json_file) - self.config = all_config["text_config"] - - repair_config(self.config, same_names=["num_attention_heads", "n_head"]) - repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) - repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) - repair_config(self.config, same_names=["intermediate_size", "moe_intermediate_size"]) - - # Handle fine-tuning config if present - if self.finetune_config: - self.config["vocab_size"] = self.finetune_config.vocab_size - - def _load_hf_weights(self): - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 9eccddffc1..4f96506b14 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -21,14 +21,14 @@ class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer): def __init__(self, layer_num, network_config): - self.n_routed_experts = network_config["num_experts"] + self.n_routed_experts = network_config.get("num_experts", 0) self.is_moe = ( - network_config["num_experts"] > 0 - and layer_num not in network_config["mlp_only_layers"] - and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 ) - self.num_experts_per_tok = network_config["num_experts_per_tok"] - self.norm_topk_prob = network_config["norm_topk_prob"] + self.num_experts_per_tok = network_config.get("num_experts_per_tok", 0) + self.norm_topk_prob = network_config.get("norm_topk_prob", True) super().__init__(layer_num, 
network_config) self.head_dim_ = network_config["head_dim"] self.tp_k_head_num_ = max(self.tp_k_head_num_, 1) From 01b112a388e2e295fadbb83d4987c3ac5a6e5fcc Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 26 Feb 2026 11:40:11 +0000 Subject: [PATCH 095/180] feat: add mamba_cache_ratio for automatic memory allocation - Add mamba_cache_ratio parameter (default 0.5) - Change mamba_cache_size default from 3000 to None - Implement automatic memory allocation based on ratio - Add clear error messages with solutions when memory insufficient - Maintain backward compatibility with explicit mamba_cache_size Ratio formula: mamba_memory = total_available * ratio / (1 + ratio) - ratio=0.5 -> 33% mamba, 67% KV - ratio=1.0 -> 50% mamba, 50% KV - ratio=2.0 -> 67% mamba, 33% KV --- lightllm/models/qwen3next/model.py | 83 +++++++++++++++++++- lightllm/server/api_cli.py | 17 +++- lightllm/server/core/objs/start_args_type.py | 3 +- lightllm/utils/envs_utils.py | 2 +- 4 files changed, 98 insertions(+), 7 deletions(-) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index d15b357608..205eb1dc9b 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -54,6 +54,75 @@ def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch def autotune_layers(self): return self.config["full_attention_interval"] + def _calculate_mamba_cache_size(self, start_args: StartArgs) -> int: + """Calculate mamba cache size based on available memory and mamba_cache_ratio.""" + from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory + import torch.distributed as dist + + use_ratio = self.max_total_token_num is None and start_args.mamba_cache_size is None + + world_size = dist.get_world_size() + total_memory = get_total_gpu_memory() + available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - self.mem_fraction) + + conv_kernel_size = self.config["linear_conv_kernel_dim"] + conv_dim = ( + self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads + ) // self.tp_world_size_ + + num_linear_layers = self.config["n_layer"] - (self.config["n_layer"] // self.config["full_attention_interval"]) + + conv_cell_size = ( + num_linear_layers * conv_dim * (conv_kernel_size - 1) * torch._utils._element_size(self.data_type) + ) + + ssm_dtype = torch.bfloat16 if start_args.mamba_ssm_data_type == "bfloat16" else torch.float32 + ssm_cell_size = ( + num_linear_layers + * (self.num_linear_v_heads // self.tp_world_size_) + * self.head_linear_k_dim + * self.head_linear_v_dim + * torch._utils._element_size(ssm_dtype) + ) + + total_cell_size = conv_cell_size + ssm_cell_size + + if use_ratio: + mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5 + mamba_memory_gb = available_memory * mamba_cache_ratio / (1 + mamba_cache_ratio) + else: + mamba_memory_gb = available_memory + mamba_cache_ratio = None + + mamba_cache_size = int(mamba_memory_gb * 1024 ** 3 / total_cell_size) + + if mamba_cache_size < start_args.running_max_req_size: + ratio = mamba_cache_ratio if mamba_cache_ratio is not None else 0.5 + raise ValueError( + f"Insufficient memory for mamba cache allocation!\n\n" + f"Calculated mamba_cache_size ({mamba_cache_size}) < " + f"running_max_req_size ({start_args.running_max_req_size})\n\n" + f"Memory budget:\n" + f" Available for mamba cache: {mamba_memory_gb:.2f} GB\n" + f" Memory per buffer: {total_cell_size / 1024 ** 
2:.2f} MB\n" + f" Calculated buffers: {mamba_cache_size}\n" + f" Required buffers: {start_args.running_max_req_size}\n\n" + f"Solutions:\n" + f" 1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n" + f" 2. Increase --mamba_cache_ratio from {ratio} to " + f"{start_args.running_max_req_size * (1 + ratio) / mamba_cache_size - 1:.3f} or higher\n" + f" 3. Increase --mem_fraction to leave more memory for caches\n" + ) + + logger.info( + f"Mamba cache allocation:\n" + f" Available memory: {mamba_memory_gb:.2f} GB\n" + f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n" + f" Calculated mamba_cache_size: {mamba_cache_size}" + ) + + return mamba_cache_size + def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) @@ -69,16 +138,22 @@ def _init_mem_manager(self): start_args: StartArgs = get_env_start_args() mamba_cache_size = start_args.mamba_cache_size - if mamba_cache_size is not None: - assert ( - mamba_cache_size >= start_args.running_max_req_size - ), "mamba_cache_size must be greater than running_max_req_size" self.num_linear_k_heads = self.config["linear_num_key_heads"] self.num_linear_v_heads = self.config["linear_num_value_heads"] self.head_linear_k_dim = self.config["linear_key_head_dim"] self.head_linear_v_dim = self.config["linear_value_head_dim"] + if mamba_cache_size is None: + mamba_cache_size = self._calculate_mamba_cache_size(start_args) + else: + if mamba_cache_size < start_args.running_max_req_size: + raise ValueError( + f"Explicitly set mamba_cache_size ({mamba_cache_size}) < " + f"running_max_req_size ({start_args.running_max_req_size})\n" + f"Please increase mamba_cache_size to at least {start_args.running_max_req_size}" + ) + conv_kernel_size = self.config["linear_conv_kernel_dim"] conv_dim = ( self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 4d122f615d..25365491d3 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -629,7 +629,22 @@ def make_argument_parser() -> argparse.ArgumentParser: default=False, help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""", ) - parser.add_argument("--mamba_cache_size", type=int, default=3000, help="""The size of linear attn cache. """) + parser.add_argument( + "--mamba_cache_size", + type=int, + default=None, + help="""The size of linear attn cache. If not specified, will be calculated + automatically based on mamba_cache_ratio or max_total_token_num.""", + ) + parser.add_argument( + "--mamba_cache_ratio", + type=float, + default=0.5, + help="""Ratio of available memory to allocate for mamba cache (after model + weights and dynamic memory reservation). Only effective when both + mamba_cache_size and max_total_token_num are not set. 
Default is 0.5 + (50%% of available memory for mamba cache, rest for KV cache).""", + ) parser.add_argument( "--mamba_ssm_data_type", type=str, diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index d8d2c6ff8b..0baa11383a 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -176,5 +176,6 @@ class StartArgs: enable_multimodal_audio: bool = field(default=False) # hybrid attention model (Qwen3Next) - mamba_cache_size: int = field(default=800) + mamba_cache_size: Optional[int] = field(default=None) + mamba_cache_ratio: Optional[float] = field(default=0.5) mamba_ssm_data_type: Optional[str] = field(default="float32", metadata={"choices": ["bfloat16", "float32"]}) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index cdafb88873..7a7a9be121 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -158,7 +158,7 @@ def get_kv_quant_calibration_inference_count(): @lru_cache(maxsize=None) def get_triton_autotune_level(): - return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 1)) + return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0)) g_model_init_done = False From 174757d8c6bd7a3e32987a08233d2566aaede131 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 26 Feb 2026 11:49:25 +0000 Subject: [PATCH 096/180] refactor: simplify mamba_cache_ratio to direct percentage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change ratio meaning from complex formula to simple percentage: - Old: ratio = mamba / kv, mamba = total * ratio / (1+ratio) - New: ratio = mamba / total, mamba = total * ratio This makes the ratio more intuitive: - 0.3 → 30% mamba, 70% KV - 0.5 → 50% mamba, 50% KV (default) - 0.7 → 70% mamba, 30% KV Also simplifies error message recommendation formula. --- MAMBA_CACHE_USAGE.md | 53 ++++++++++++++++++++++++++++++ lightllm/models/qwen3next/model.py | 5 +-- lightllm/server/api_cli.py | 8 ++--- 3 files changed, 60 insertions(+), 6 deletions(-) create mode 100644 MAMBA_CACHE_USAGE.md diff --git a/MAMBA_CACHE_USAGE.md b/MAMBA_CACHE_USAGE.md new file mode 100644 index 0000000000..e8bebdec89 --- /dev/null +++ b/MAMBA_CACHE_USAGE.md @@ -0,0 +1,53 @@ +# Mamba Cache Ratio-Based Allocation + +## Parameters + +- `--mamba_cache_ratio ` (default: 0.5) - Percentage of cache memory for mamba +- `--mamba_cache_size ` (default: None) - Explicit buffer count (backward compatible) + +## Ratio Meaning + +`mamba_cache_ratio = mamba_memory / total_cache_memory` + +Examples: +- `0.3` → 30% mamba, 70% KV +- `0.5` → 50% mamba, 50% KV (default) +- `0.7` → 70% mamba, 30% KV + +## Usage Examples + +### Automatic (recommended) +```bash +python -m lightllm.server.api_server \ + --model_dir /path/to/qwen3next \ + --mem_fraction 0.9 +# Uses default ratio 0.5 → 50% mamba, 50% KV +``` + +### Custom ratio +```bash +# For long-context workloads (more KV cache) +python -m lightllm.server.api_server \ + --model_dir /path/to/qwen3next \ + --mamba_cache_ratio 0.3 # 30% mamba, 70% KV + +# For high-concurrency workloads (more mamba cache) +python -m lightllm.server.api_server \ + --model_dir /path/to/qwen3next \ + --mamba_cache_ratio 0.7 # 70% mamba, 30% KV +``` + +### Explicit size (backward compatible) +```bash +python -m lightllm.server.api_server \ + --model_dir /path/to/qwen3next \ + --mamba_cache_size 3000 +``` + +## Troubleshooting + +### Error: "Insufficient memory for mamba cache allocation!" 
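+
+The numbers in this error come from a simple calculation. A back-of-envelope
+sketch of it (illustrative only: the 20 GB budget and 0.35 MB per-request state
+size below are made-up example values, not numbers read from your model; the
+actual sizing is done internally by the server):
+
+```python
+# Sketch of the ratio-based sizing described above (not the exact server code).
+available_gb = 20.0   # cache memory left after model weights (assumed example)
+ratio = 0.5           # --mamba_cache_ratio: mamba share of total cache memory
+cell_mb = 0.35        # conv + SSM state per request, model dependent (assumed)
+
+mamba_gb = available_gb * ratio                     # 10.0 GB for mamba
+mamba_cache_size = int(mamba_gb * 1024 / cell_mb)   # ~29257 buffers
+# The error fires when mamba_cache_size < --running_max_req_size.
+```
+
+If the computed buffer count is too small, the solutions below raise it.
+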
**Solution 1**: Reduce `--running_max_req_size` to the calculated value or lower
+**Solution 2**: Increase `--mamba_cache_ratio` to give more memory to mamba
+**Solution 3**: Increase `--mem_fraction` to leave more memory for caches
diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py
index 205eb1dc9b..263d1c622d 100644
--- a/lightllm/models/qwen3next/model.py
+++ b/lightllm/models/qwen3next/model.py
@@ -88,8 +88,9 @@ def _calculate_mamba_cache_size(self, start_args: StartArgs) -> int:
         total_cell_size = conv_cell_size + ssm_cell_size

         if use_ratio:
+            # mamba_cache_ratio = mamba_memory / total_cache_memory
             mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5
-            mamba_memory_gb = available_memory * mamba_cache_ratio / (1 + mamba_cache_ratio)
+            mamba_memory_gb = available_memory * mamba_cache_ratio
         else:
             mamba_memory_gb = available_memory
             mamba_cache_ratio = None
@@ -110,7 +111,7 @@ def _calculate_mamba_cache_size(self, start_args: StartArgs) -> int:
                 f"Solutions:\n"
                 f"  1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n"
                 f"  2. Increase --mamba_cache_ratio from {ratio} to "
-                f"{start_args.running_max_req_size * (1 + ratio) / mamba_cache_size - 1:.3f} or higher\n"
+                f"{start_args.running_max_req_size / mamba_cache_size * ratio:.3f} or higher\n"
                 f"  3. Increase --mem_fraction to leave more memory for caches\n"
             )

diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 25365491d3..eec9a05cf2 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -640,10 +640,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
         "--mamba_cache_ratio",
         type=float,
         default=0.5,
-        help="""Ratio of available memory to allocate for mamba cache (after model
-        weights and dynamic memory reservation). Only effective when both
-        mamba_cache_size and max_total_token_num are not set. Default is 0.5
-        (50%% of available memory for mamba cache, rest for KV cache).""",
+        help="""Ratio of mamba cache to total cache memory (mamba + KV).
+        Only effective when both mamba_cache_size and max_total_token_num are not set.
+        Default is 0.5 (50%% mamba cache, 50%% KV cache).
+ Example: 0.3 -> 30%% mamba, 70%% KV; 0.7 -> 70%% mamba, 30%% KV.""", ) parser.add_argument( "--mamba_ssm_data_type", From dd2516e60c68cf078182eea494df0fbbe70d89ed Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 26 Feb 2026 13:04:36 +0000 Subject: [PATCH 097/180] add H100 config --- ...12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 14 +++ ...12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 14 +++ ...12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 8 ++ ...=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 14 +++ ...12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json | 12 ++ ...,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json | 38 ++++++ ...,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json | 38 ++++++ ...,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json | 38 ++++++ ...LEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...LEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...LEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json | 12 ++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 118 ++++++++++++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++ .../{topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...t16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 74 +++++++++++ ...t16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 74 +++++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 74 +++++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 74 +++++++++++ 40 files changed, 1708 insertions(+) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..cc5c68eb79 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..b6e5109b62 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..511935b4cf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 128, + "num_stages": 4, + "num_warps": 4 + }, + "4": { + "BK": 128, + "BV": 128, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..cc5c68eb79 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..1038611f6a --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..4bc06d07d9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..7421097fa4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..f1159e4357 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..892c20e78d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..d831f32c4a --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..2af1b86e90 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "8": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..40cdc996b9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,12 @@ +{ + "2": { + "BV": 32, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..833062ec2f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 2 + }, + "100": { + "num_warps": 1 + }, + "1024": { + "num_warps": 1 + }, + "128": { + "num_warps": 1 + }, + "16": { + "num_warps": 2 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 1 + }, + "256": { + "num_warps": 1 + }, + "32": { + "num_warps": 1 + }, + "4096": { + "num_warps": 2 + }, + "64": { + "num_warps": 8 + }, + "8": { + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..5f2cf9465b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 8 + }, + "100": { + 
"num_warps": 8 + }, + "1024": { + "num_warps": 2 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 2 + }, + "2048": { + "num_warps": 2 + }, + "256": { + "num_warps": 2 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 2 + }, + "64": { + "num_warps": 2 + }, + "8": { + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..c8a1841674 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 8 + }, + "100": { + "num_warps": 2 + }, + "1024": { + "num_warps": 8 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 8 + }, + "256": { + "num_warps": 4 + }, + "32": { + "num_warps": 2 + }, + "4096": { + "num_warps": 8 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..a97cabf8b2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "4": { + "BK": 64, + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..786624883f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "8": { + "BK": 64, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..eaca03cf75 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,12 @@ +{ + "2": { + "BK": 64, + 
"num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 64, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..d9064e5d6a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "128": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "16": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "2048": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "4096": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..baef19d90c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "128": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16384": { + "BLK_HEADS": 32, + "num_warps": 1 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "64": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "8": { + "BLK_HEADS": 64, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..90ac24c408 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "100": { + "BLK_HEADS": 4, + "num_warps": 1 + }, + "1024": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "128": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 8, + "num_warps": 2 + }, 
+ "16384": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "2048": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 4, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 8, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 16, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..31d7a6e203 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,118 @@ +{ + "1024": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "12": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "1200": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "12288": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "128": { + "BLOCK_N": 256, + "num_warps": 8 + }, + "131072": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "1536": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "16": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "1600": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "16384": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "192": { + "BLOCK_N": 512, + "num_warps": 1 + }, + "196608": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "2048": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "24576": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "256": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "262144": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "3072": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "32768": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "384": { + "BLOCK_N": 512, + "num_warps": 2 + }, + "4096": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "49152": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "512": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "64": { + "BLOCK_N": 256, + "num_warps": 8 + }, + "65536": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "768": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "8": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "800": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "8192": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "96": { + "BLOCK_N": 512, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..864d1d3f18 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 4096, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..bcf56e01f7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 128, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ba1dc8a75d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "3072": { + "BLOCK_SIZE": 2048, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..6f109e1c6e --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "5120": { + "BLOCK_SIZE": 32768, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..0042ef8a2a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + 
"num_warps": 4 + }, + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..54a5967071 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, 
+ "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..bb78d1dd84 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + 
"NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..1552d8bf1a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..08cbfd85c3 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..13a070b8f0 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + 
"num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..169a148799 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 512, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 128, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..5022588ef5 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 2, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..4ae96d02d1 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 32, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 16 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..28c654f3d2 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 3, + "num_warps": 2 + }, + "100": { + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "num_stages": 5, + "num_warps": 1 + }, + "128": { + "num_stages": 3, + "num_warps": 2 + }, + "16": { + "num_stages": 3, + "num_warps": 2 + }, + "16384": { + "num_stages": 2, + "num_warps": 1 + }, + "2048": { + "num_stages": 4, + "num_warps": 1 + }, + "256": { + "num_stages": 3, + "num_warps": 1 + }, + "32": { + "num_stages": 3, + "num_warps": 2 + }, + "4096": { + "num_stages": 4, + "num_warps": 1 + }, + "64": { + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "num_stages": 2, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..08b0d5e5bc --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "num_stages": 3, + "num_warps": 1 + }, + "1024": { + "num_stages": 2, + "num_warps": 1 + }, + "128": { + "num_stages": 4, + "num_warps": 4 + }, + "16": { + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "num_stages": 3, + "num_warps": 2 + }, + "256": { + "num_stages": 5, + "num_warps": 2 + }, + "32": { + "num_stages": 5, + "num_warps": 8 + }, + "4096": { + "num_stages": 2, + "num_warps": 1 + }, + "64": { + "num_stages": 2, + "num_warps": 1 + }, + "8": { + "num_stages": 5, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..0d871841ed --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "num_stages": 1, + "num_warps": 1 + }, + "1024": { + "num_stages": 5, + "num_warps": 1 + }, + "128": { + "num_stages": 2, + "num_warps": 1 + }, + "16": { + "num_stages": 3, + "num_warps": 2 + }, + "16384": { + "num_stages": 4, + "num_warps": 1 + }, + "2048": { + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "num_stages": 4, + "num_warps": 1 + }, + "32": { + "num_stages": 2, + "num_warps": 1 + }, + "4096": { + "num_stages": 3, + "num_warps": 1 + }, + "64": { + "num_stages": 4, + "num_warps": 1 + }, + "8": { + "num_stages": 3, + "num_warps": 1 + } +} \ No 
newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..9f3a8dcb25 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "131072": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "32768": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..72026f01c4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "131072": { + "BLOCK_M": 64, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file From 326ae227d55c1798bcc40e5e7f43cb2018d5203f Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 06:15:24 +0000 Subject: [PATCH 098/180] refactor: align 
radix_cache_class with infer_state_class style - Replace get_radix_cache_class() classmethod with radix_cache_class class attribute in TpPartBaseModel and Qwen3NextTpPartModel - Move RadixCache/HybridRadixCache imports to module top-level - Update base_backend.py to access radix_cache_class directly - Replace alloc_buffer_for_req_triton with simpler indexed PyTorch assignment - Remove now-unused alloc_buffer_kernel.py Triton kernel - Revert LOADWORKER default to 1 and remove language_model. prefix stripping --- lightllm/common/basemodel/basemodel.py | 8 +- .../basemodel/layer_weights/hf_load_utils.py | 10 +-- .../triton_kernel/alloc_buffer_kernel.py | 80 ------------------- lightllm/common/req_manager.py | 4 +- lightllm/models/qwen3next/model.py | 7 +- .../model_infer/mode_backend/base_backend.py | 2 +- 6 files changed, 9 insertions(+), 102 deletions(-) delete mode 100644 lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index caa90462cc..1d36c72d0b 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -11,6 +11,7 @@ from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.req_manager import ReqManager @@ -53,11 +54,8 @@ class TpPartBaseModel: # infer state class infer_state_class = InferStateInfo - @classmethod - def get_radix_cache_class(cls): - from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache - - return RadixCache + # radix cache class + radix_cache_class = RadixCache def __init__(self, kvargs): self.args = get_env_start_args() diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 304b04ab44..8cf66a5ad6 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -18,14 +18,6 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay weights = {k: weights.get_tensor(k) for k in weights.keys()} else: weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") - new_weight = {} - for k, v in weights.items(): - if "language_model." 
in k: - new_weight[k[len("language_model.") :]] = v - else: - new_weight[k] = v - del weights - weights = new_weight if pre_post_layer is not None: pre_post_layer.load_hf_weights(weights) @@ -68,7 +60,7 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye transformer_layer_list=transformer_layer_list, weight_dir=weight_dir, ) # noqa - worker = int(os.environ.get("LOADWORKER", 18)) + worker = int(os.environ.get("LOADWORKER", 1)) with Pool(worker) as p: iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" diff --git a/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py b/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py deleted file mode 100644 index b6444449b1..0000000000 --- a/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def alloc_buffer_for_req_kernel( - req_index_ptr, # [num_reqs] - indices of requests to allocate buffers for - buffer_indexes_ptr, # [num_reqs * num_buffers_per_req] - buffer indices to assign (from CPU) - req_to_buffer_index_ptr, # [max_request_num + 1, num_buffers_per_req] - tensor mapping req_idx to buffer_idx - num_reqs, # number of requests to process - stride_buffer, # stride for req_to_buffer_index second dimension - NUM_BUFFERS_PER_REQ: tl.constexpr, # number of buffers per request (mtp_step + 1) - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - # Mask for valid indices - mask = offsets < num_reqs - - # Load request indices - req_indices = tl.load(req_index_ptr + offsets, mask=mask, other=0) - - # For each request, allocate NUM_BUFFERS_PER_REQ buffers - for buf_idx in tl.static_range(NUM_BUFFERS_PER_REQ): - # Load buffer index for this position - buffer_offset = offsets * NUM_BUFFERS_PER_REQ + buf_idx - buffer_indices = tl.load(buffer_indexes_ptr + buffer_offset, mask=mask, other=0) - - # Update req_to_buffer_index[req_indices, buf_idx] = buffer_indices - output_offset = req_indices * stride_buffer + buf_idx - tl.store(req_to_buffer_index_ptr + output_offset, buffer_indices, mask=mask) - - -def alloc_buffer_for_req_triton( - req_index: torch.Tensor, # [num_reqs] int32/int64 tensor on CUDA - buffer_indexes: torch.Tensor, # [num_reqs * (mtp_step + 1)] int32 tensor (can be CPU or CUDA) - req_to_buffer_index: torch.Tensor, # [max_request_num + 1, mtp_step + 1] int32 tensor on CUDA - mtp_step: int = 0, # number of additional buffers per request (default 0 for non-MTP mode) -): - num_reqs = req_index.shape[0] - num_buffers_per_req = mtp_step + 1 - - # Ensure inputs are on CUDA - if not req_index.is_cuda: - req_index = req_index.cuda() - if not buffer_indexes.is_cuda: - buffer_indexes = buffer_indexes.cuda() - - # Ensure correct dtypes - if req_index.dtype not in [torch.int32, torch.int64]: - req_index = req_index.to(torch.int32) - if buffer_indexes.dtype != torch.int32: - buffer_indexes = buffer_indexes.to(torch.int32) - - # Validate buffer_indexes size - expected_size = num_reqs * num_buffers_per_req - assert buffer_indexes.shape[0] == expected_size, ( - f"Expected {expected_size} buffer indices for {num_reqs} requests " - f"with mtp_step={mtp_step}, but got {buffer_indexes.shape[0]}" - ) - - # Get stride for the second dimension of req_to_buffer_index - stride_buffer = req_to_buffer_index.stride(0) - - # 
Launch kernel - BLOCK_SIZE = 256 - grid = (triton.cdiv(num_reqs, BLOCK_SIZE),) - - alloc_buffer_for_req_kernel[grid]( - req_index, - buffer_indexes, - req_to_buffer_index, - num_reqs, - stride_buffer, - NUM_BUFFERS_PER_REQ=num_buffers_per_req, - BLOCK_SIZE=BLOCK_SIZE, - ) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 573fe50842..bad3fa0557 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,6 +1,5 @@ import torch import collections -from lightllm.common.basemodel.triton_kernel.alloc_buffer_kernel import alloc_buffer_for_req_triton from lightllm.utils.log_utils import init_logger from .kv_cache_mem_manager import MemoryManager from typing import List, Optional @@ -268,7 +267,8 @@ def alloc_buffer_for_req(self, req_index: torch.Tensor): num_reqs = req_index.shape[0] num_buffers_per_req = self.mtp_step + 1 buffer_indexes = self.buffer_mem_manager.alloc(num_reqs * num_buffers_per_req) - alloc_buffer_for_req_triton(req_index, buffer_indexes, self.req_to_buffer_index, self.mtp_step) + # Pure PyTorch: indexed assignment is already a fused GPU kernel + self.req_to_buffer_index[req_index] = buffer_indexes.view(num_reqs, num_buffers_per_req) def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor): # 获取目标请求的所有 MTP buffer (从 buffer[0] 到 buffer[mtp_step]) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 263d1c622d..b3f0f53cac 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -20,6 +20,7 @@ from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.common.req_manager import ReqManagerForMamba from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache logger = init_logger(__name__) @@ -33,11 +34,7 @@ class Qwen3NextTpPartModel(Qwen3MOEModel): is_hybrid_attention = True # Indicates model uses hybrid (full + linear) attention use_buffer_manager = True # Indicates model needs per-request buffer management for linear attention states - @classmethod - def get_radix_cache_class(cls): - from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache - - return HybridRadixCache + radix_cache_class = HybridRadixCache def __init__(self, kvargs) -> None: self.mem_manager: Qwen3NextHybridMemManager = None diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 0ba4b9248c..57a3508e93 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -175,7 +175,7 @@ def init_model(self, kvargs): self.use_buffer_manager = getattr(self.model, "use_buffer_manager", False) - radix_cache_class = self.model.get_radix_cache_class() + radix_cache_class = self.model.radix_cache_class self.radix_cache = ( radix_cache_class( get_unique_server_name(), From e996cd249d717481c7967ea918360d6db48c662e Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 06:19:23 +0000 Subject: [PATCH 099/180] fix: add missing attention_chunk param to flashattention_nopad.py The sgl_kernel.fwd.default API requires attention_chunk before softcap. This file was missed when the parameter was added in commit a4ab210f. Also update sgl-kernel from 0.3.7.post1 to 0.3.21 which supports this API. 
--- lightllm/models/vit/triton_kernel/flashattention_nopad.py | 3 ++- requirements.txt | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index 8428e52996..b43f8f95af 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -195,7 +195,8 @@ def flash_attention_v3_fwd( False, window_size[0], window_size[1], - 0.0, + 0, # attention_chunk + 0.0, # softcap is_rotary_interleaved=False, scheduler_metadata=None, num_splits=1, diff --git a/requirements.txt b/requirements.txt index 25cdab955d..521038f719 100644 --- a/requirements.txt +++ b/requirements.txt @@ -81,7 +81,7 @@ atomics==1.0.3 easydict==1.13 hypercorn==0.18.0 flashinfer-python==0.2.4 -sgl-kernel==0.3.7.post1 +sgl-kernel==0.3.21 httpx==0.28.1 librosa==0.11.0 cuda_bindings==12.9.0 From 5e5cdbe84b54336aa6097c8f0b16785e6324a317 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 06:31:05 +0000 Subject: [PATCH 100/180] refactor: clarify naming in mamba_buffer_copy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename copy_buffer_p2p → copy_mamba_buffer (indexed 1:1 slot copy) - Rename copy_buffer_broadcast → fork_mamba_buffer (1:N MTP fork) - Unify chunk offset param name (pair_idx_offset/copy_idx_offset → chunk_offset) - Rename stride_index → stride_slot to reflect the slot/cache dimension - Rename src_idx_in_batch → src_chunk_idx in fork kernel - Extract _MAX_GRID_DIM = 65535 module constant (was duplicated inline) - Add divisibility assertion before implicit // in fork autotuned wrapper - Update autotuner cache keys to match new names --- .../triton_kernel/mamba_buffer_copy.py | 133 +++++++++--------- 1 file changed, 68 insertions(+), 65 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py index 6a1d8adbd5..21301570d3 100644 --- a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py +++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py @@ -3,36 +3,38 @@ import triton.language as tl from lightllm.common.triton_utils.autotuner import autotune +_MAX_GRID_DIM = 65535 + @triton.jit -def _copy_buffer_p2p_1d_kernel( +def _copy_mamba_buffer_1d_kernel( src_buffer_ptr, dst_buffer_ptr, src_indexes_ptr, dst_indexes_ptr, - pair_idx_offset, + chunk_offset, layer_idx_offset, stride_layer, - stride_index, + stride_slot, stride_d, d_size, BLOCK_D: tl.constexpr, ): """ - Optimized kernel for 1D buffer copy. + Indexed 1:1 copy kernel for Mamba recurrent state buffers. Grid: (num_pairs, layer_num, num_blocks_d) Each program copies one block of dimension d for one (pair, layer) combination. 
""" - pair_idx = tl.program_id(0) + pair_idx_offset + pair_idx = tl.program_id(0) + chunk_offset layer_idx = tl.program_id(1) + layer_idx_offset block_d_idx = tl.program_id(2) # Cast strides to int64 to prevent overflow in pointer arithmetic stride_layer = stride_layer.to(tl.int64) - stride_index = stride_index.to(tl.int64) + stride_slot = stride_slot.to(tl.int64) - # Load source and destination indices for this pair + # Load source and destination slot indices for this pair src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) @@ -44,8 +46,8 @@ def _copy_buffer_p2p_1d_kernel( mask = d_offsets < d_size # Calculate source and destination pointers for this layer and pair - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_slot + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_slot src_ptr = base_src + d_offsets * stride_d dst_ptr = base_dst + d_offsets * stride_d @@ -56,54 +58,53 @@ def _copy_buffer_p2p_1d_kernel( @triton.jit -def _copy_buffer_broadcast_1d_kernel( +def _fork_mamba_buffer_1d_kernel( src_buffer_ptr, dst_buffer_ptr, src_indexes_ptr, dst_indexes_ptr, - copy_idx_offset, + chunk_offset, layer_idx_offset, stride_layer, - stride_index, + stride_slot, stride_d, d_size, num_dst_per_src, BLOCK_D: tl.constexpr, ): """ - Broadcast kernel for 1D buffer copy (one source to multiple destinations). + Fork kernel for Mamba recurrent state buffers: one source slot → N destination slots. + Used for MTP speculation where one parent state is copied to multiple child slots. Grid: (num_src, layer_num, num_blocks_d) """ - src_idx_in_batch = tl.program_id(0) + copy_idx_offset + src_chunk_idx = tl.program_id(0) + chunk_offset layer_idx = tl.program_id(1) + layer_idx_offset block_d_idx = tl.program_id(2) # Cast strides to int64 to prevent overflow in pointer arithmetic stride_layer = stride_layer.to(tl.int64) - stride_index = stride_index.to(tl.int64) + stride_slot = stride_slot.to(tl.int64) - # Load source index - src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) + # Load source slot index + src_idx = tl.load(src_indexes_ptr + src_chunk_idx).to(tl.int64) # Calculate offsets for this block d_start = block_d_idx * BLOCK_D d_offsets = d_start + tl.arange(0, BLOCK_D) mask = d_offsets < d_size - # Calculate source pointer - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + # Calculate source pointer and load data once + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_slot src_ptr = base_src + d_offsets * stride_d - - # Load data once data = tl.load(src_ptr, mask=mask, other=0.0) - # Broadcast to all destinations for this source + # Write to each destination slot for this source for dst_offset in range(num_dst_per_src): - dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset + dst_idx_in_batch = src_chunk_idx * num_dst_per_src + dst_offset dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_slot dst_ptr = base_dst + d_offsets * stride_d tl.store(dst_ptr, data, mask=mask) @@ -151,20 +152,20 @@ def _get_buffer_copy_run_key(src_indexes: torch.Tensor): @autotune( - 
kernel_name="mamba_buffer_copy_p2p_1d:v1", + kernel_name="mamba_buffer_copy_1d:v1", configs_gen_func=_get_buffer_copy_1d_configs, static_key_func=_get_buffer_copy_static_key, run_key_func=_get_buffer_copy_run_key, mutates_args=["dst_buffer"], ) -def _copy_buffer_p2p_1d_autotuned( +def _copy_mamba_buffer_1d_autotuned( src_buffer: torch.Tensor, dst_buffer: torch.Tensor, src_indexes: torch.Tensor, dst_indexes: torch.Tensor, run_config: dict = None, ): - """Auto-tuned 1D buffer copy.""" + """Auto-tuned indexed 1:1 copy of Mamba recurrent state buffer slots.""" num_pairs = src_indexes.shape[0] layer_num = src_buffer.shape[0] d_size = src_buffer.shape[2] @@ -180,19 +181,17 @@ def _copy_buffer_p2p_1d_autotuned( num_blocks_d = triton.cdiv(d_size, BLOCK_D) - MAX_GRID_SIZE = 65535 - - for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): - pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) + for pair_chunk_start in range(0, num_pairs, _MAX_GRID_DIM): + pair_chunk_end = min(pair_chunk_start + _MAX_GRID_DIM, num_pairs) pair_chunk_size = pair_chunk_end - pair_chunk_start - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + for layer_chunk_start in range(0, layer_num, _MAX_GRID_DIM): + layer_chunk_end = min(layer_chunk_start + _MAX_GRID_DIM, layer_num) layer_chunk_size = layer_chunk_end - layer_chunk_start grid = (pair_chunk_size, layer_chunk_size, num_blocks_d) - _copy_buffer_p2p_1d_kernel[grid]( + _copy_mamba_buffer_1d_kernel[grid]( src_buffer, dst_buffer, src_indexes, @@ -210,23 +209,26 @@ def _copy_buffer_p2p_1d_autotuned( @autotune( - kernel_name="mamba_buffer_broadcast_1d:v1", + kernel_name="mamba_buffer_fork_1d:v1", configs_gen_func=_get_buffer_copy_1d_configs, static_key_func=_get_buffer_copy_static_key, run_key_func=_get_buffer_copy_run_key, mutates_args=["dst_buffer"], ) -def _copy_buffer_broadcast_1d_autotuned( +def _fork_mamba_buffer_1d_autotuned( src_buffer: torch.Tensor, dst_buffer: torch.Tensor, src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, + dst_indexes: torch.Tensor, # flat 1D: [num_src * num_dst_per_src] run_config: dict = None, ): - """Auto-tuned 1D buffer broadcast (one src to multiple dst).""" + """Auto-tuned fork: copy each source Mamba slot to N destination slots.""" num_src = src_indexes.shape[0] layer_num = src_buffer.shape[0] d_size = src_buffer.shape[2] + assert ( + dst_indexes.shape[0] % num_src == 0 + ), f"dst_indexes length {dst_indexes.shape[0]} must be divisible by num_src {num_src}" num_dst_per_src = dst_indexes.shape[0] // num_src if run_config is None: @@ -240,19 +242,17 @@ def _copy_buffer_broadcast_1d_autotuned( num_blocks_d = triton.cdiv(d_size, BLOCK_D) - MAX_GRID_SIZE = 65535 - - for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): - src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) + for src_chunk_start in range(0, num_src, _MAX_GRID_DIM): + src_chunk_end = min(src_chunk_start + _MAX_GRID_DIM, num_src) src_chunk_size = src_chunk_end - src_chunk_start - for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): - layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + for layer_chunk_start in range(0, layer_num, _MAX_GRID_DIM): + layer_chunk_end = min(layer_chunk_start + _MAX_GRID_DIM, layer_num) layer_chunk_size = layer_chunk_end - layer_chunk_start grid = (src_chunk_size, layer_chunk_size, num_blocks_d) - _copy_buffer_broadcast_1d_kernel[grid]( + _fork_mamba_buffer_1d_kernel[grid]( src_buffer, dst_buffer, src_indexes, 
@@ -285,23 +285,23 @@ def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor: return buffer.view(L, B, -1) -def copy_buffer_p2p( +def copy_mamba_buffer( src_buffer: torch.Tensor, dst_buffer: torch.Tensor, src_indexes: torch.Tensor, dst_indexes: torch.Tensor, ): """ - Copy buffers from source indices to destination indices with auto-tuning. + Indexed 1:1 copy of Mamba recurrent state buffer slots. - Supports any buffer shape [layer_num, buffer_size, ...] as long as the - trailing dimensions are contiguous (which is the default for torch.zeros). + Copies slot src_indexes[i] → dst_indexes[i] for all layers simultaneously. + Used for cache eviction/restore and normal token state management. Args: - src_buffer: Source buffer tensor [layer_num, buffer_size, ...] - dst_buffer: Destination buffer tensor [layer_num, buffer_size, ...] - src_indexes: Source buffer indices [num_pairs] - dst_indexes: Destination buffer indices [num_pairs] + src_buffer: [layer_num, num_slots, ...] + dst_buffer: [layer_num, num_slots, ...] + src_indexes: source slot indices [num_pairs] + dst_indexes: destination slot indices [num_pairs] """ assert src_buffer.shape == dst_buffer.shape assert src_indexes.shape == dst_indexes.shape @@ -309,36 +309,39 @@ def copy_buffer_p2p( src_flat = _flatten_trailing_dims(src_buffer) dst_flat = _flatten_trailing_dims(dst_buffer) - _copy_buffer_p2p_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes) + _copy_mamba_buffer_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes) -def copy_buffer_broadcast( +def fork_mamba_buffer( src_buffer: torch.Tensor, dst_buffer: torch.Tensor, src_indexes: torch.Tensor, dst_indexes: torch.Tensor, ): """ - Broadcast buffers from source indices to multiple destination indices (MTP use case). + Fork Mamba recurrent state slots: copy one source slot to N destination slots. - Each source buffer is copied to multiple destination buffers. + Used for MTP (Multi-Token Prediction) speculation, where a parent token's + recurrent state must be replicated into each speculative child slot. Args: - src_buffer: Source buffer tensor [layer_num, buffer_size, ...] - dst_buffer: Destination buffer tensor [layer_num, buffer_size, ...] - src_indexes: Source buffer indices [num_src] - dst_indexes: Destination buffer indices [num_src, num_dst_per_src] (2D tensor) + src_buffer: [layer_num, num_slots, ...] + dst_buffer: [layer_num, num_slots, ...] 
+ src_indexes: source slot indices [num_src] + dst_indexes: destination slot indices [num_src, num_dst_per_src] """ assert src_buffer.shape == dst_buffer.shape assert len(src_indexes.shape) == 1 - assert len(dst_indexes.shape) == 2, f"dst_indexes must be 2D, got shape {dst_indexes.shape}" + assert len(dst_indexes.shape) == 2, f"dst_indexes must be 2D [num_src, num_dst_per_src], got {dst_indexes.shape}" num_src = src_indexes.shape[0] - assert num_src == dst_indexes.shape[0], f"Mismatch: src_indexes {num_src} vs dst_indexes {dst_indexes.shape[0]}" + assert ( + num_src == dst_indexes.shape[0] + ), f"Mismatch: src_indexes {num_src} vs dst_indexes rows {dst_indexes.shape[0]}" - # Flatten dst_indexes for kernel + # Flatten dst_indexes to 1D for kernel; kernel reconstructs the 2D layout via num_dst_per_src dst_indexes_flat = dst_indexes.reshape(-1).contiguous() src_flat = _flatten_trailing_dims(src_buffer) dst_flat = _flatten_trailing_dims(dst_buffer) - _copy_buffer_broadcast_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes_flat) + _fork_mamba_buffer_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes_flat) From 9cf783c9f616972f276692d575a4e885e1868e38 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 06:35:06 +0000 Subject: [PATCH 101/180] clean --- MAMBA_CACHE_USAGE.md | 53 ------------------- .../mamba_cache_mem_manager/cache_manager.py | 14 ++--- 2 files changed, 7 insertions(+), 60 deletions(-) delete mode 100644 MAMBA_CACHE_USAGE.md diff --git a/MAMBA_CACHE_USAGE.md b/MAMBA_CACHE_USAGE.md deleted file mode 100644 index e8bebdec89..0000000000 --- a/MAMBA_CACHE_USAGE.md +++ /dev/null @@ -1,53 +0,0 @@ -# Mamba Cache Ratio-Based Allocation - -## Parameters - -- `--mamba_cache_ratio ` (default: 0.5) - Percentage of cache memory for mamba -- `--mamba_cache_size ` (default: None) - Explicit buffer count (backward compatible) - -## Ratio Meaning - -`mamba_cache_ratio = mamba_memory / total_cache_memory` - -Examples: -- `0.3` → 30% mamba, 70% KV -- `0.5` → 50% mamba, 50% KV (default) -- `0.7` → 70% mamba, 30% KV - -## Usage Examples - -### Automatic (recommended) -```bash -python -m lightllm.server.api_server \ - --model_dir /path/to/qwen3next \ - --mem_fraction 0.9 -# Uses default ratio 0.5 → 50% mamba, 50% KV -``` - -### Custom ratio -```bash -# For long-context workloads (more KV cache) -python -m lightllm.server.api_server \ - --model_dir /path/to/qwen3next \ - --mamba_cache_ratio 0.3 # 30% mamba, 70% KV - -# For high-concurrency workloads (more mamba cache) -python -m lightllm.server.api_server \ - --model_dir /path/to/qwen3next \ - --mamba_cache_ratio 0.7 # 70% mamba, 30% KV -``` - -### Explicit size (backward compatible) -```bash -python -m lightllm.server.api_server \ - --model_dir /path/to/qwen3next \ - --mamba_cache_size 3000 -``` - -## Troubleshooting - -### Error: "Insufficient memory for mamba cache allocation!" 
-
-**Solution 1**: Reduce `--running_max_req_size` to calculated value or lower
-**Solution 2**: Increase `--mamba_cache_ratio` to give more memory to mamba
-**Solution 3**: Increase `--mem_fraction` to leave more memory for caches
diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py
index 272a999bb1..9b0933f22f 100644
--- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py
+++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py
@@ -6,7 +6,7 @@
 from lightllm.utils.dist_utils import get_current_rank_in_node
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 from lightllm.common.allocator_utils import TokenAllocator
-from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import copy_buffer_p2p, copy_buffer_broadcast
+from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import copy_mamba_buffer, fork_mamba_buffer
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt

@@ -57,29 +57,29 @@ def get_mamba_cache(self, layer_idx: int):
         return conv_state, ssm_state

     def copy_buffer_p2p(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor):
-        copy_buffer_p2p(
+        copy_mamba_buffer(
             self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes
         )
-        copy_buffer_p2p(
+        copy_mamba_buffer(
             self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes
         )

     def copy_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor):
-        copy_buffer_broadcast(
+        fork_mamba_buffer(
             self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_index, dst_buffer_indexes
         )
-        copy_buffer_broadcast(
+        fork_mamba_buffer(
             self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes
         )

     def copy_ssm_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor):
         """
-        Broadcast ONLY SSM states (not conv states) from source indices to destination indices.
+        Fork ONLY SSM states (not conv states) from source indices to destination indices.

         This is used for MTP mode where each buffer maintains its own independent conv state,
         but SSM states need to be synchronized.
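+        Expects a 1D src_buffer_index [num_src] and a 2D dst_buffer_indexes [num_src, num_dst_per_src].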
""" - copy_buffer_broadcast( + fork_mamba_buffer( self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes ) From e120edbc09051529db2b94bf2df0d15eb860e0fd Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 07:36:35 +0000 Subject: [PATCH 102/180] fix --- lightllm/common/req_manager.py | 3 ++- lightllm/server/api_models.py | 32 ++++++++------------------------ lightllm/server/api_openai.py | 21 +++------------------ 3 files changed, 13 insertions(+), 43 deletions(-) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index bad3fa0557..3a5e048fb9 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -267,7 +267,8 @@ def alloc_buffer_for_req(self, req_index: torch.Tensor): num_reqs = req_index.shape[0] num_buffers_per_req = self.mtp_step + 1 buffer_indexes = self.buffer_mem_manager.alloc(num_reqs * num_buffers_per_req) - # Pure PyTorch: indexed assignment is already a fused GPU kernel + if not buffer_indexes.is_cuda: + buffer_indexes = buffer_indexes.cuda() self.req_to_buffer_index[req_index] = buffer_indexes.view(num_reqs, num_buffers_per_req) def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor): diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 7c7d40698c..f30ecc55fe 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -115,7 +115,6 @@ class CompletionRequest(BaseModel): prompt: Union[str, List[str], List[int], List[List[int]]] suffix: Optional[str] = None max_tokens: Optional[int] = 8192 - max_completion_tokens: Optional[int] = None # OpenAI's newer parameter, alias for max_tokens temperature: Optional[float] = 1.0 top_p: Optional[float] = 1.0 n: Optional[int] = 1 @@ -170,17 +169,10 @@ def load_generation_cfg(cls, weight_dir: str): @classmethod def apply_loaded_defaults(cls, data: Any): """Apply loaded default values if field is not provided.""" - if isinstance(data, dict): - # Map max_completion_tokens to max_tokens if provided - # (OpenAI's newer parameter name) - if "max_completion_tokens" in data and data["max_completion_tokens"] is not None: - if "max_tokens" not in data or data["max_tokens"] is None: - data["max_tokens"] = data["max_completion_tokens"] - - if cls._loaded_defaults: - for key, value in cls._loaded_defaults.items(): - if key not in data: - data[key] = value + if isinstance(data, dict) and cls._loaded_defaults: + for key, value in cls._loaded_defaults.items(): + if key not in data: + data[key] = value return data @@ -195,7 +187,6 @@ class ChatCompletionRequest(BaseModel): stream_options: Optional[StreamOptions] = None stop: Optional[Union[str, List[str]]] = None max_tokens: Optional[int] = 8192 - max_completion_tokens: Optional[int] = None # OpenAI's newer parameter, alias for max_tokens presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None @@ -255,17 +246,10 @@ def load_generation_cfg(cls, weight_dir: str): @classmethod def apply_loaded_defaults(cls, data: Any): """Apply loaded default values if field is not provided.""" - if isinstance(data, dict): - # Map max_completion_tokens to max_tokens if provided - # (OpenAI's newer parameter name) - if "max_completion_tokens" in data and data["max_completion_tokens"] is not None: - if "max_tokens" not in data or data["max_tokens"] is None: - data["max_tokens"] = data["max_completion_tokens"] - - if cls._loaded_defaults: - for key, value in 
cls._loaded_defaults.items(): - if key not in data: - data[key] = value + if isinstance(data, dict) and cls._loaded_defaults: + for key, value in cls._loaded_defaults.items(): + if key not in data: + data[key] = value return data diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index de1423c496..33f342822f 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -184,16 +184,9 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")} ) else: - # Treat as local file path - if os.path.isfile(img): - with open(img, "rb") as f: - multimodal_params_dict["images"].append( - {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")} - ) - else: - raise ValueError( - "Unrecognized image input. Supports local path, http url, base64, and PIL.Image." - ) + raise ValueError( + "Unrecognized image input. Supports local path, http url, base64, and PIL.Image." + ) tools = None if request.tools and request.tool_choice != "none": @@ -276,14 +269,6 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req text = "".join(final_output_dict[sub_req_id]) full_text = text - # Debug logging for empty responses - if not text or len(text.strip()) == 0: - logger.warning( - f"[EMPTY_RESPONSE_DEBUG] sub_req_id={sub_req_id}, " - f"completion_tokens={completion_tokens}, finish_reason={finish_reason}, " - f"prompt_tokens={prompt_tokens}, output_chunks={len(final_output_dict[sub_req_id])}" - ) - # Handle reasoning content reasoning_text = None reasoning_parser = get_env_start_args().reasoning_parser From f3330cf9b0c11bbb2ec8db5c5d462d810b8a1281 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 27 Feb 2026 10:23:02 +0000 Subject: [PATCH 103/180] clean --- lightllm/models/qwen3_5/model.py | 13 ++- lightllm/models/qwen3next/buffer_pool.py | 83 ------------------- .../layer_infer/shared_expert_mixin.py | 7 +- lightllm/server/api_cli.py | 6 +- 4 files changed, 12 insertions(+), 97 deletions(-) delete mode 100644 lightllm/models/qwen3next/buffer_pool.py diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py index fdbccdf787..2f7413bc87 100644 --- a/lightllm/models/qwen3_5/model.py +++ b/lightllm/models/qwen3_5/model.py @@ -1,9 +1,5 @@ import os import json -import time -import gc -from safetensors import safe_open -from tqdm import tqdm from lightllm.models.registry import ModelRegistry from lightllm.models.qwen3next.model import Qwen3NextTpPartModel from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( @@ -11,8 +7,12 @@ Qwen35NextGatedDeltaNetTransformerLayerWeight, ) from lightllm.models.qwen3_vl.model import QWen3VLTokenizer -from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer -from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import ( + Qwen3VLMultimodalPreLayerInfer, +) +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import ( + Qwen3VLPreAndPostLayerWeight, +) from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import ( Qwen35FullAttentionTransformerLayerInfer, Qwen35GatedDeltaNetTransformerLayerInfer, @@ -20,7 +20,6 @@ from lightllm.models.qwen3_5.infer_struct import Qwen35InferStateInfo from lightllm.common.build_utils import repair_config from lightllm.utils.log_utils import init_logger -import 
lightllm.utils.petrel_helper as utils logger = init_logger(__name__) diff --git a/lightllm/models/qwen3next/buffer_pool.py b/lightllm/models/qwen3next/buffer_pool.py deleted file mode 100644 index 42c4bcafc7..0000000000 --- a/lightllm/models/qwen3next/buffer_pool.py +++ /dev/null @@ -1,83 +0,0 @@ -# lightllm/models/qwen3next/buffer_pool.py -import torch -from typing import Dict, Tuple - - -class Qwen3NextBufferPool: - """ - Buffer pool for Qwen3Next inference to reduce allocations. - - NOT thread-safe. Each GPU worker process should have its own pool instance. - - Manages reusable buffers for: - - Attention norm outputs - - FFN norm outputs - - FFN intermediate activations - - GDN intermediate tensors - """ - - def __init__(self, enable_stats: bool = False, max_buffers: int = 64): - self._buffers: Dict[Tuple[tuple, torch.dtype, torch.device], torch.Tensor] = {} - self._in_use: set = set() - self._max_buffers = max_buffers - self._access_order: list = [] # Track LRU order - self._enable_stats = enable_stats - self._stats = {"hits": 0, "misses": 0, "peak_buffers": 0, "evictions": 0} if enable_stats else None - - def get_buffer( - self, - shape: Tuple[int, ...], - dtype: torch.dtype, - device: torch.device, - ) -> torch.Tensor: - """Get a buffer from the pool or allocate a new one.""" - key = (shape, dtype, device) - - # Check if we have a matching buffer not in use - if key in self._buffers and key not in self._in_use: - self._in_use.add(key) - # Update LRU order - if key in self._access_order: - self._access_order.remove(key) - self._access_order.append(key) - if self._enable_stats: - self._stats["hits"] += 1 - return self._buffers[key] - - # Evict oldest unused buffer if at capacity - if len(self._buffers) >= self._max_buffers: - self._evict_one() - - # Allocate new buffer - buffer = torch.empty(shape, dtype=dtype, device=device) - self._buffers[key] = buffer - self._in_use.add(key) - self._access_order.append(key) - if self._enable_stats: - self._stats["misses"] += 1 - self._stats["peak_buffers"] = max(self._stats["peak_buffers"], len(self._buffers)) - return buffer - - def _evict_one(self): - """Evict oldest unused buffer (LRU).""" - for key in self._access_order: - if key not in self._in_use and key in self._buffers: - del self._buffers[key] - self._access_order.remove(key) - if self._enable_stats: - self._stats["evictions"] += 1 - return - - def release_all(self): - """Release all buffers back to the pool (call after forward pass).""" - self._in_use.clear() - - def clear(self): - """Clear all buffers (call when changing batch size significantly).""" - self._buffers.clear() - self._in_use.clear() - self._access_order.clear() - - def get_stats(self): - """Return buffer pool statistics (if enabled).""" - return self._stats.copy() if self._stats else None diff --git a/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py b/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py index 2da106dbb2..be9000fcad 100644 --- a/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py +++ b/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py @@ -32,12 +32,7 @@ def _ffn_core(self, input, layer_weight): """Core FFN computation: gate_up -> silu_and_mul -> down.""" input = input.view(-1, self.embed_dim_) up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) - - if hasattr(self, "buffer_pool") and self.buffer_pool: - ffn1_out = self.buffer_pool.get_buffer((input.size(0), up_gate_out.size(1) // 2), input.dtype, input.device) - else: - ffn1_out = 
self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype)
-
+        ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype)
         silu_and_mul_fwd(up_gate_out, ffn1_out)
         ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out)
         return ffn2_out, input
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index eec9a05cf2..47111f76bc 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -638,7 +638,15 @@ def make_argument_parser() -> argparse.ArgumentParser:
     )
+
+    def _mamba_cache_ratio_type(value: str) -> float:
+        # Reject out-of-range ratios at parse time with a clear argparse error.
+        ratio = float(value)
+        if not 0.0 <= ratio <= 1.0:
+            raise argparse.ArgumentTypeError(f"--mamba_cache_ratio must be between 0.0 and 1.0, got {value}")
+        return ratio
+
     parser.add_argument(
         "--mamba_cache_ratio",
-        type=float,
+        type=_mamba_cache_ratio_type,
         default=0.5,
         help="""Ratio of mamba cache to total cache memory (mamba + KV).
        Only effective when both mamba_cache_size and max_total_token_num are not set.

From d030a67ed76f457accf938b5f1219aad70fdce8a Mon Sep 17 00:00:00 2001
From: sufubao
Date: Fri, 27 Feb 2026 13:37:03 +0000
Subject: [PATCH 104/180] split

---
 lightllm/models/__init__.py | 6 +--
 lightllm/models/qwen3_5/__init__.py | 7 ++-
 lightllm/models/qwen3_5/model.py | 30 ------------
 lightllm/models/qwen3_5_moe/__init__.py | 0
 .../qwen3_5_moe/layer_infer/__init__.py | 0
 .../qwen3_5_moe/layer_weights/__init__.py | 0
 lightllm/models/qwen3_5_moe/model.py | 48 +++++++++++++++++++
 7 files changed, 53 insertions(+), 38 deletions(-)
 create mode 100644 lightllm/models/qwen3_5_moe/__init__.py
 create mode 100644 lightllm/models/qwen3_5_moe/layer_infer/__init__.py
 create mode 100644 lightllm/models/qwen3_5_moe/layer_weights/__init__.py
 create mode 100644 lightllm/models/qwen3_5_moe/model.py

diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index af13e34cd9..ad040cdf25 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -40,8 +40,6 @@
 )
 from lightllm.models.gpt_oss.model import GptOssTpPartModel
 from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel
-from lightllm.models.qwen3_5.model import (
-    Qwen3_5TpPartModel,
-    Qwen3_5MOETpPartModel,
-)
+from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel
+from lightllm.models.qwen3_5_moe.model import Qwen3_5MOETpPartModel
 from .registry import get_model, get_model_class
diff --git a/lightllm/models/qwen3_5/__init__.py b/lightllm/models/qwen3_5/__init__.py
index 47667a92d5..56a41a228a 100644
--- a/lightllm/models/qwen3_5/__init__.py
+++ b/lightllm/models/qwen3_5/__init__.py
@@ -1,17 +1,16 @@
 """
-Qwen3.5 Multimodal Model Module
+Qwen3.5 Multimodal Model Module (Dense Variant)
 
-Provides Qwen3.5 multimodal models with hybrid attention and vision-language support.
+Provides Qwen3.5 dense multimodal model with hybrid attention and vision-language support.
+For MoE variant, see qwen3_5_moe module.
""" from .model import ( Qwen3_5TpPartModel, - Qwen3_5MOETpPartModel, QWen3_5Tokenizer, ) __all__ = [ "Qwen3_5TpPartModel", - "Qwen3_5MOETpPartModel", "QWen3_5Tokenizer", ] diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py index 2f7413bc87..3d093b3939 100644 --- a/lightllm/models/qwen3_5/model.py +++ b/lightllm/models/qwen3_5/model.py @@ -196,33 +196,3 @@ def _init_infer_layer(self): ) for i in range(self.config["n_layer"]) ] - - -@ModelRegistry(["qwen3_5_moe"], is_multimodal=True) -class Qwen3_5MOETpPartModel(Qwen3_5TpPartModel): - """ - Qwen3.5-MoE Multimodal Model (Mixture of Experts Variant) - - Extends Qwen3.5 with sparse expert routing: - - Same hybrid attention architecture as Qwen3.5 - - MoE layers replace dense MLP layers - - Expert routing handled by Qwen3NextSparseMoeBlock (inherited) - - The MoE variant is automatically configured by inheriting from - Qwen3NextTpPartModel, which inherits from Qwen3MOEModel. - - No additional configuration needed - MoE support is built-in. - """ - - def __init__(self, kvargs): - """ - Initialize Qwen3.5-MoE model. - - Args: - kvargs: Dictionary containing: - - weight_dir: Path to model weights - - max_total_token_num: Maximum total tokens - - Additional model configuration - """ - super().__init__(kvargs) - logger.info("Initialized Qwen3.5-MoE multimodal model with expert routing") diff --git a/lightllm/models/qwen3_5_moe/__init__.py b/lightllm/models/qwen3_5_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_moe/layer_infer/__init__.py b/lightllm/models/qwen3_5_moe/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_moe/layer_weights/__init__.py b/lightllm/models/qwen3_5_moe/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_moe/model.py b/lightllm/models/qwen3_5_moe/model.py new file mode 100644 index 0000000000..069992bb37 --- /dev/null +++ b/lightllm/models/qwen3_5_moe/model.py @@ -0,0 +1,48 @@ +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel +from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager + +logger = init_logger(__name__) + + +@ModelRegistry("qwen3_5_moe", is_multimodal=True) +class Qwen3_5MOETpPartModel(Qwen3_5TpPartModel): + """ + Qwen3.5-MoE Multimodal Model (Mixture of Experts Variant) + + Extends Qwen3.5 with sparse expert routing: + - Same hybrid attention architecture as Qwen3.5 + - MoE layers replace dense MLP layers + - Expert routing handled by inherited MoE infrastructure + + This model combines: + - Hybrid attention from Qwen3Next (Gated Delta Networks + Full Attention) + - Multimodal capabilities from Qwen3VL (image/video processing) + - MoE sparse routing for efficient scaling + """ + + def __init__(self, kvargs): + """ + Initialize Qwen3.5-MoE model. + + Args: + kvargs: Dictionary containing: + - weight_dir: Path to model weights + - max_total_token_num: Maximum total tokens + - Additional model configuration + """ + super().__init__(kvargs) + logger.info("Initialized Qwen3.5-MoE multimodal model with expert routing") + + def _init_custom(self): + """ + Initialize MoE-specific components. + + Sets up DeepEP communication group for expert parallelism + when the model has experts configured. 
+        """
+        super()._init_custom()
+        # Initialize DeepEP group for MoE models with num_experts
+        if "num_experts" in self.config and self.config["num_experts"] > 0:
+            dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])

From e1f6129d8de0b2e8d7de322fd54a619511a3008d Mon Sep 17 00:00:00 2001
From: sufubao
Date: Sun, 1 Mar 2026 17:11:01 +0000
Subject: [PATCH 105/180] refactor: rewrite mamba_buffer_copy kernels and autotune keys

Replaces the chunked-grid copy/fork kernels with single-launch variants,
reworks the autotune config space and cache keys, and adds a default
launch heuristic for when no tuned config is cached. Black formatting
applied by the pre-commit hook.
---
 .../triton_kernel/mamba_buffer_copy.py        | 397 ++++++++----------
 1 file changed, 186 insertions(+), 211 deletions(-)

diff --git a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py
index 21301570d3..b198ed5d1e 100644
--- a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py
+++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py
@@ -7,282 +7,262 @@

 @triton.jit
-def _copy_mamba_buffer_1d_kernel(
-    src_buffer_ptr,
-    dst_buffer_ptr,
-    src_indexes_ptr,
-    dst_indexes_ptr,
-    chunk_offset,
-    layer_idx_offset,
+def _copy_buffer_kernel(
+    src_ptr,
+    dst_ptr,
+    src_idx_ptr,
+    dst_idx_ptr,
     stride_layer,
     stride_slot,
-    stride_d,
     d_size,
     BLOCK_D: tl.constexpr,
 ):
-    """
-    Indexed 1:1 copy kernel for Mamba recurrent state buffers.
-
-    Grid: (num_pairs, layer_num, num_blocks_d)
-    Each program copies one block of dimension d for one (pair, layer) combination.
-    """
-    pair_idx = tl.program_id(0) + chunk_offset
-    layer_idx = tl.program_id(1) + layer_idx_offset
-    block_d_idx = tl.program_id(2)
+    pair_idx = tl.program_id(0)
+    layer_idx = tl.program_id(1)
+    block_d = tl.program_id(2)

-    # Cast strides to int64 to prevent overflow in pointer arithmetic
     stride_layer = stride_layer.to(tl.int64)
     stride_slot = stride_slot.to(tl.int64)

-    # Load source and destination slot indices for this pair
-    src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64)
-    dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64)
-
-    # Calculate offsets for this block
-    d_start = block_d_idx * BLOCK_D
-    d_offsets = d_start + tl.arange(0, BLOCK_D)
+    src_slot = tl.load(src_idx_ptr + pair_idx).to(tl.int64)
+    dst_slot = tl.load(dst_idx_ptr + pair_idx).to(tl.int64)

-    # Create mask for valid indices
-    mask = d_offsets < d_size
+    offs = block_d * BLOCK_D + tl.arange(0, BLOCK_D)
+    mask = offs < d_size

-    # Calculate source and destination pointers for this layer and pair
-    base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_slot
-    base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_slot
-
-    src_ptr = base_src + d_offsets * stride_d
-    dst_ptr = base_dst + d_offsets * stride_d
-
-    # Load and store
-    data = tl.load(src_ptr, mask=mask, other=0.0)
-    tl.store(dst_ptr, data, mask=mask)
+    base = layer_idx * stride_layer
+    tl.store(
+        dst_ptr + base + dst_slot * stride_slot + offs,
+        tl.load(src_ptr + base + src_slot * stride_slot + offs, mask=mask),
+        mask=mask,
+    )


 @triton.jit
-def _fork_mamba_buffer_1d_kernel(
-    src_buffer_ptr,
-    dst_buffer_ptr,
-    src_indexes_ptr,
-    dst_indexes_ptr,
-    chunk_offset,
-    layer_idx_offset,
+def _fork_buffer_kernel(
+    src_ptr,
+    dst_ptr,
+    src_idx_ptr,
+    dst_idx_ptr,
     stride_layer,
     stride_slot,
-    stride_d,
     d_size,
     num_dst_per_src,
     BLOCK_D: tl.constexpr,
 ):
-    """
-    Fork kernel for Mamba recurrent state buffers: one source slot → N destination slots.
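+    # Axis 0 flattens all (src, dst) pairs into num_src * num_dst_per_src programs;
+    # the integer division below recovers the source row each destination forks from.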
+ flat_pair = tl.program_id(0) + layer_idx = tl.program_id(1) + block_d = tl.program_id(2) - Used for MTP speculation where one parent state is copied to multiple child slots. - Grid: (num_src, layer_num, num_blocks_d) - """ - src_chunk_idx = tl.program_id(0) + chunk_offset - layer_idx = tl.program_id(1) + layer_idx_offset - block_d_idx = tl.program_id(2) + src_chunk = flat_pair // num_dst_per_src - # Cast strides to int64 to prevent overflow in pointer arithmetic stride_layer = stride_layer.to(tl.int64) stride_slot = stride_slot.to(tl.int64) - # Load source slot index - src_idx = tl.load(src_indexes_ptr + src_chunk_idx).to(tl.int64) + src_slot = tl.load(src_idx_ptr + src_chunk).to(tl.int64) + dst_slot = tl.load(dst_idx_ptr + flat_pair).to(tl.int64) - # Calculate offsets for this block - d_start = block_d_idx * BLOCK_D - d_offsets = d_start + tl.arange(0, BLOCK_D) - mask = d_offsets < d_size + offs = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offs < d_size - # Calculate source pointer and load data once - base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_slot - src_ptr = base_src + d_offsets * stride_d - data = tl.load(src_ptr, mask=mask, other=0.0) + base = layer_idx * stride_layer + tl.store( + dst_ptr + base + dst_slot * stride_slot + offs, + tl.load(src_ptr + base + src_slot * stride_slot + offs, mask=mask), + mask=mask, + ) - # Write to each destination slot for this source - for dst_offset in range(num_dst_per_src): - dst_idx_in_batch = src_chunk_idx * num_dst_per_src + dst_offset - dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) - - base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_slot - dst_ptr = base_dst + d_offsets * stride_d - - tl.store(dst_ptr, data, mask=mask) +def _get_buffer_copy_configs(): + configs = [] + for block_d in [128, 256, 512, 1024, 2048, 4096]: + for num_warps in [1, 2, 4, 8]: + for num_stages in [1, 2]: + configs.append({"BLOCK_D": block_d, "num_warps": num_warps, "num_stages": num_stages}) + return configs -# ==================== Config Generation Functions ==================== +def _get_copy_static_key( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, +): + """Static key for copy kernel cache: dtype, d_size, layer_num. -def _get_buffer_copy_1d_configs(): - """Generate candidate configurations for 1D buffer copy.""" - configs = [] - for block_d in [32, 64, 128, 256, 512, 1024]: - for num_warps in [2, 4, 8]: - for num_stages in [2, 3, 4]: - configs.append( - { - "BLOCK_D": block_d, - "num_warps": num_warps, - "num_stages": num_stages, - } - ) - return configs + Different models (35B vs 397B) have different optimal configs, so each + should get its own cache file. 
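+    For non-3D buffers, d_size falls back to numel // (layer_num * num_slots).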
+ """ + d_size = ( + src_buffer.shape[2] + if src_buffer.ndim == 3 + else src_buffer.numel() // (src_buffer.shape[0] * src_buffer.shape[1]) + ) + return { + "dtype": str(src_buffer.dtype), + "d_size": d_size, + "layer_num": src_buffer.shape[0], + "ndim": src_buffer.ndim, + } -# ==================== Static and Run Key Functions ==================== +def _get_copy_run_key( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, +): + """Run key: constant since static_key already uniquely identifies config.""" + return 0 -def _get_buffer_copy_static_key(src_buffer: torch.Tensor): - """Static key based on buffer shape and dtype.""" - shape = src_buffer.shape +def _get_fork_static_key( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes_flat: torch.Tensor, + num_dst_per_src: int, +): + """Static key for fork kernel cache: dtype, d_size, layer_num.""" + d_size = ( + src_buffer.shape[2] + if src_buffer.ndim == 3 + else src_buffer.numel() // (src_buffer.shape[0] * src_buffer.shape[1]) + ) return { - "ndim": len(shape), - "layer_num": shape[0], - "d_sizes": str(shape[2:]), "dtype": str(src_buffer.dtype), + "d_size": d_size, + "layer_num": src_buffer.shape[0], + "ndim": src_buffer.ndim, } -def _get_buffer_copy_run_key(src_indexes: torch.Tensor): - """Run key based on number of copy pairs.""" - return src_indexes.shape[0] +def _get_fork_run_key( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes_flat: torch.Tensor, + num_dst_per_src: int, +): + """Run key: constant since static_key already uniquely identifies config.""" + return 0 -# ==================== Auto-tuned Buffer Copy Functions ==================== +# ─── Helper functions ───────────────────────────────────────────────────────── + + +def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor: + """Flatten dims after [layer_num, buffer_size] into one. 
Zero-copy for contiguous tensors.""" + if buffer.ndim == 3: + return buffer + L, B = buffer.shape[:2] + return buffer.view(L, B, -1) + + +# ─── Autotuned implementations ──────────────────────────────────────────────── @autotune( kernel_name="mamba_buffer_copy_1d:v1", - configs_gen_func=_get_buffer_copy_1d_configs, - static_key_func=_get_buffer_copy_static_key, - run_key_func=_get_buffer_copy_run_key, + configs_gen_func=_get_buffer_copy_configs, + static_key_func=_get_copy_static_key, + run_key_func=_get_copy_run_key, mutates_args=["dst_buffer"], ) -def _copy_mamba_buffer_1d_autotuned( +def _copy_mamba_buffer_autotuned( src_buffer: torch.Tensor, dst_buffer: torch.Tensor, src_indexes: torch.Tensor, dst_indexes: torch.Tensor, run_config: dict = None, ): - """Auto-tuned indexed 1:1 copy of Mamba recurrent state buffer slots.""" + """Autotuned indexed copy implementation.""" + # Default heuristic when autotune is disabled or no config cached + if not run_config: + d_size = src_buffer.shape[2] + # For memory-bound copy, larger BLOCK_D is better (reduces grid size) + BLOCK_D = min(4096, triton.next_power_of_2(d_size)) + num_warps = 4 if BLOCK_D >= 1024 else 2 + run_config = {"BLOCK_D": BLOCK_D, "num_warps": num_warps, "num_stages": 1} + + config = run_config + BLOCK_D = config["BLOCK_D"] num_pairs = src_indexes.shape[0] layer_num = src_buffer.shape[0] d_size = src_buffer.shape[2] - if run_config is None: - BLOCK_D = triton.next_power_of_2(min(d_size, 256)) - num_warps = 4 if BLOCK_D > 256 else 2 - num_stages = 2 - else: - BLOCK_D = run_config["BLOCK_D"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - num_blocks_d = triton.cdiv(d_size, BLOCK_D) - for pair_chunk_start in range(0, num_pairs, _MAX_GRID_DIM): - pair_chunk_end = min(pair_chunk_start + _MAX_GRID_DIM, num_pairs) - pair_chunk_size = pair_chunk_end - pair_chunk_start - - for layer_chunk_start in range(0, layer_num, _MAX_GRID_DIM): - layer_chunk_end = min(layer_chunk_start + _MAX_GRID_DIM, layer_num) - layer_chunk_size = layer_chunk_end - layer_chunk_start - - grid = (pair_chunk_size, layer_chunk_size, num_blocks_d) - - _copy_mamba_buffer_1d_kernel[grid]( - src_buffer, - dst_buffer, - src_indexes, - dst_indexes, - pair_chunk_start, - layer_chunk_start, - src_buffer.stride(0), - src_buffer.stride(1), - src_buffer.stride(2), - d_size, - BLOCK_D=BLOCK_D, - num_warps=num_warps, - num_stages=num_stages, - ) + assert num_pairs <= _MAX_GRID_DIM, f"num_pairs={num_pairs} exceeds grid limit {_MAX_GRID_DIM}" + assert layer_num <= _MAX_GRID_DIM, f"layer_num={layer_num} exceeds grid limit {_MAX_GRID_DIM}" + + grid = (num_pairs, layer_num, num_blocks_d) + _copy_buffer_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + src_buffer.stride(0), + src_buffer.stride(1), + d_size, + BLOCK_D=BLOCK_D, + num_warps=config["num_warps"], + num_stages=config["num_stages"], + ) @autotune( kernel_name="mamba_buffer_fork_1d:v1", - configs_gen_func=_get_buffer_copy_1d_configs, - static_key_func=_get_buffer_copy_static_key, - run_key_func=_get_buffer_copy_run_key, + configs_gen_func=_get_buffer_copy_configs, + static_key_func=_get_fork_static_key, + run_key_func=_get_fork_run_key, mutates_args=["dst_buffer"], ) -def _fork_mamba_buffer_1d_autotuned( +def _fork_mamba_buffer_autotuned( src_buffer: torch.Tensor, dst_buffer: torch.Tensor, src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, # flat 1D: [num_src * num_dst_per_src] + dst_indexes_flat: torch.Tensor, + num_dst_per_src: int, run_config: dict = None, ): - 
"""Auto-tuned fork: copy each source Mamba slot to N destination slots.""" + """Autotuned fork implementation.""" + # Default heuristic when autotune is disabled or no config cached + if not run_config: + d_size = src_buffer.shape[2] + BLOCK_D = min(4096, triton.next_power_of_2(d_size)) + num_warps = 4 if BLOCK_D >= 1024 else 2 + run_config = {"BLOCK_D": BLOCK_D, "num_warps": num_warps, "num_stages": 1} + + config = run_config + BLOCK_D = config["BLOCK_D"] num_src = src_indexes.shape[0] layer_num = src_buffer.shape[0] d_size = src_buffer.shape[2] - assert ( - dst_indexes.shape[0] % num_src == 0 - ), f"dst_indexes length {dst_indexes.shape[0]} must be divisible by num_src {num_src}" - num_dst_per_src = dst_indexes.shape[0] // num_src - - if run_config is None: - BLOCK_D = triton.next_power_of_2(min(d_size, 256)) - num_warps = 4 if BLOCK_D > 256 else 2 - num_stages = 2 - else: - BLOCK_D = run_config["BLOCK_D"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] num_blocks_d = triton.cdiv(d_size, BLOCK_D) + total_pairs = num_src * num_dst_per_src - for src_chunk_start in range(0, num_src, _MAX_GRID_DIM): - src_chunk_end = min(src_chunk_start + _MAX_GRID_DIM, num_src) - src_chunk_size = src_chunk_end - src_chunk_start - - for layer_chunk_start in range(0, layer_num, _MAX_GRID_DIM): - layer_chunk_end = min(layer_chunk_start + _MAX_GRID_DIM, layer_num) - layer_chunk_size = layer_chunk_end - layer_chunk_start - - grid = (src_chunk_size, layer_chunk_size, num_blocks_d) + assert total_pairs <= _MAX_GRID_DIM, f"total_pairs={total_pairs} exceeds grid limit {_MAX_GRID_DIM}" + assert layer_num <= _MAX_GRID_DIM, f"layer_num={layer_num} exceeds grid limit {_MAX_GRID_DIM}" - _fork_mamba_buffer_1d_kernel[grid]( - src_buffer, - dst_buffer, - src_indexes, - dst_indexes, - src_chunk_start, - layer_chunk_start, - src_buffer.stride(0), - src_buffer.stride(1), - src_buffer.stride(2), - d_size, - num_dst_per_src, - BLOCK_D=BLOCK_D, - num_warps=num_warps, - num_stages=num_stages, - ) + grid = (total_pairs, layer_num, num_blocks_d) + _fork_buffer_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes_flat, + src_buffer.stride(0), + src_buffer.stride(1), + d_size, + num_dst_per_src, + BLOCK_D=BLOCK_D, + num_warps=config["num_warps"], + num_stages=config["num_stages"], + ) -# ==================== Unified Interface ==================== - - -def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor: - """Flatten all dimensions after [layer_num, buffer_size] into one. - - For a contiguous buffer of shape [L, B, d1, d2, ...], returns a view - of shape [L, B, d1*d2*...]. This is a zero-copy operation. - """ - if buffer.ndim == 3: - return buffer - L, B = buffer.shape[:2] - return buffer.view(L, B, -1) +# ─── Public API ─────────────────────────────────────────────────────────────── def copy_mamba_buffer( @@ -294,8 +274,7 @@ def copy_mamba_buffer( """ Indexed 1:1 copy of Mamba recurrent state buffer slots. - Copies slot src_indexes[i] → dst_indexes[i] for all layers simultaneously. - Used for cache eviction/restore and normal token state management. + Copies slot src_indexes[i] -> dst_indexes[i] for all layers simultaneously. Args: src_buffer: [layer_num, num_slots, ...] 
@@ -304,12 +283,11 @@ def copy_mamba_buffer( dst_indexes: destination slot indices [num_pairs] """ assert src_buffer.shape == dst_buffer.shape - assert src_indexes.shape == dst_indexes.shape - assert len(src_indexes.shape) == 1 + assert src_indexes.shape == dst_indexes.shape and src_indexes.ndim == 1 src_flat = _flatten_trailing_dims(src_buffer) dst_flat = _flatten_trailing_dims(dst_buffer) - _copy_mamba_buffer_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes) + _copy_mamba_buffer_autotuned(src_flat, dst_flat, src_indexes, dst_indexes) def fork_mamba_buffer( @@ -319,10 +297,9 @@ def fork_mamba_buffer( dst_indexes: torch.Tensor, ): """ - Fork Mamba recurrent state slots: copy one source slot to N destination slots. + Fork Mamba recurrent state slots: one source -> N destinations. - Used for MTP (Multi-Token Prediction) speculation, where a parent token's - recurrent state must be replicated into each speculative child slot. + Used for MTP speculation where parent state is replicated to child slots. Args: src_buffer: [layer_num, num_slots, ...] @@ -331,17 +308,15 @@ def fork_mamba_buffer( dst_indexes: destination slot indices [num_src, num_dst_per_src] """ assert src_buffer.shape == dst_buffer.shape - assert len(src_indexes.shape) == 1 - assert len(dst_indexes.shape) == 2, f"dst_indexes must be 2D [num_src, num_dst_per_src], got {dst_indexes.shape}" - - num_src = src_indexes.shape[0] + assert src_indexes.ndim == 1 + assert dst_indexes.ndim == 2, f"dst_indexes must be 2D [num_src, num_dst_per_src], got {dst_indexes.shape}" assert ( - num_src == dst_indexes.shape[0] - ), f"Mismatch: src_indexes {num_src} vs dst_indexes rows {dst_indexes.shape[0]}" + dst_indexes.shape[0] == src_indexes.shape[0] + ), f"Mismatch: src_indexes {src_indexes.shape[0]} vs dst_indexes rows {dst_indexes.shape[0]}" - # Flatten dst_indexes to 1D for kernel; kernel reconstructs the 2D layout via num_dst_per_src + num_dst_per_src = dst_indexes.shape[1] dst_indexes_flat = dst_indexes.reshape(-1).contiguous() src_flat = _flatten_trailing_dims(src_buffer) dst_flat = _flatten_trailing_dims(dst_buffer) - _fork_mamba_buffer_1d_autotuned(src_flat, dst_flat, src_indexes, dst_indexes_flat) + _fork_mamba_buffer_autotuned(src_flat, dst_flat, src_indexes, dst_indexes_flat, num_dst_per_src) From 74f82d13506c33ccffd2d5451642ce4d6ec30c8d Mon Sep 17 00:00:00 2001 From: sufubao Date: Sun, 1 Mar 2026 17:11:35 +0000 Subject: [PATCH 106/180] perf: add autotune configs for mamba_buffer_copy/fork kernels on H200 Configs for Qwen3.5-35B (layer_num=30) and 397B (layer_num=48): - SSM state (float32): d_size=262144/393216 - Conv state (bf16): d_size=12288/15360 --- ...pe=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json | 7 +++++++ ...pe=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json | 7 +++++++ ...ype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json | 7 +++++++ ...ype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json | 7 +++++++ ...pe=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json | 7 +++++++ ...pe=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json | 7 +++++++ ...ype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json | 7 +++++++ ...ype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json | 7 +++++++ 8 files changed, 56 insertions(+) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..69a0e9ca42 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 2048, + "num_stages": 2, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..9de6716c3c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 4096, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..2e3a3febbb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 512, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..5c9f40590b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 1024, + "num_stages": 1, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..0a3facae38 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 2048, + "num_stages": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..9cdaab5ace --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 4096, + "num_stages": 2, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..2e3a3febbb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 512, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..889f6ab71b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 
1024, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file From c1ea7697b729ceeb667cc41e5047f2c53ef81d4d Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 2 Mar 2026 03:43:50 +0000 Subject: [PATCH 107/180] refactor: rename buffer copy methods for clarity - copy_buffer_p2p -> copy_state_buffers (removes misleading "p2p") - copy_buffer_broadcast -> fork_state_buffers (aligns with fork_mamba_buffer kernel) - copy_ssm_buffer_broadcast -> fork_ssm_buffers (consistent naming) - Remove redundant docstrings in mamba_buffer_copy.py --- .../triton_kernel/mamba_buffer_copy.py | 33 ------------------- .../mamba_cache_mem_manager/cache_manager.py | 6 ++-- lightllm/common/req_manager.py | 2 +- .../dynamic_prompt/hybrid_radix_cache.py | 2 +- .../mode_backend/chunked_prefill/impl.py | 4 +-- .../mode_backend/dp_backend/impl.py | 8 ++--- 6 files changed, 11 insertions(+), 44 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py index b198ed5d1e..361c0565ae 100644 --- a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py +++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py @@ -112,7 +112,6 @@ def _get_copy_run_key( src_indexes: torch.Tensor, dst_indexes: torch.Tensor, ): - """Run key: constant since static_key already uniquely identifies config.""" return 0 @@ -123,7 +122,6 @@ def _get_fork_static_key( dst_indexes_flat: torch.Tensor, num_dst_per_src: int, ): - """Static key for fork kernel cache: dtype, d_size, layer_num.""" d_size = ( src_buffer.shape[2] if src_buffer.ndim == 3 @@ -144,13 +142,9 @@ def _get_fork_run_key( dst_indexes_flat: torch.Tensor, num_dst_per_src: int, ): - """Run key: constant since static_key already uniquely identifies config.""" return 0 -# ─── Helper functions ───────────────────────────────────────────────────────── - - def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor: """Flatten dims after [layer_num, buffer_size] into one. Zero-copy for contiguous tensors.""" if buffer.ndim == 3: @@ -159,9 +153,6 @@ def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor: return buffer.view(L, B, -1) -# ─── Autotuned implementations ──────────────────────────────────────────────── - - @autotune( kernel_name="mamba_buffer_copy_1d:v1", configs_gen_func=_get_buffer_copy_configs, @@ -176,8 +167,6 @@ def _copy_mamba_buffer_autotuned( dst_indexes: torch.Tensor, run_config: dict = None, ): - """Autotuned indexed copy implementation.""" - # Default heuristic when autotune is disabled or no config cached if not run_config: d_size = src_buffer.shape[2] # For memory-bound copy, larger BLOCK_D is better (reduces grid size) @@ -271,17 +260,6 @@ def copy_mamba_buffer( src_indexes: torch.Tensor, dst_indexes: torch.Tensor, ): - """ - Indexed 1:1 copy of Mamba recurrent state buffer slots. - - Copies slot src_indexes[i] -> dst_indexes[i] for all layers simultaneously. - - Args: - src_buffer: [layer_num, num_slots, ...] - dst_buffer: [layer_num, num_slots, ...] - src_indexes: source slot indices [num_pairs] - dst_indexes: destination slot indices [num_pairs] - """ assert src_buffer.shape == dst_buffer.shape assert src_indexes.shape == dst_indexes.shape and src_indexes.ndim == 1 @@ -296,17 +274,6 @@ def fork_mamba_buffer( src_indexes: torch.Tensor, dst_indexes: torch.Tensor, ): - """ - Fork Mamba recurrent state slots: one source -> N destinations. - - Used for MTP speculation where parent state is replicated to child slots. 
- - Args: - src_buffer: [layer_num, num_slots, ...] - dst_buffer: [layer_num, num_slots, ...] - src_indexes: source slot indices [num_src] - dst_indexes: destination slot indices [num_src, num_dst_per_src] - """ assert src_buffer.shape == dst_buffer.shape assert src_indexes.ndim == 1 assert dst_indexes.ndim == 2, f"dst_indexes must be 2D [num_src, num_dst_per_src], got {dst_indexes.shape}" diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py index 9b0933f22f..a33a737516 100644 --- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -56,7 +56,7 @@ def get_mamba_cache(self, layer_idx: int): ssm_state = self.ssm_state_cache.buffer[layer_idx] return conv_state, ssm_state - def copy_buffer_p2p(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor): + def copy_state_buffers(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor): copy_mamba_buffer( self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes ) @@ -64,7 +64,7 @@ def copy_buffer_p2p(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes ) - def copy_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): + def fork_state_buffers(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): fork_mamba_buffer( self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_index, dst_buffer_indexes ) @@ -72,7 +72,7 @@ def copy_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_index self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes ) - def copy_ssm_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): + def fork_ssm_buffers(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): """ Fork ONLY SSM states (not conv states) from source indices to destination indices. 
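
For orientation after the rename, a minimal, hypothetical usage sketch of the two kernel-level
entry points that back these methods (slot counts, index values, and the in-place src == dst
call are illustrative only; a CUDA device with Triton is assumed):

```python
import torch
from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import (
    copy_mamba_buffer,
    fork_mamba_buffer,
)

# Toy recurrent-state buffer: [layer_num=2, num_slots=8, d=16].
buf = torch.randn(2, 8, 16, device="cuda")

# 1:1 copy (the copy_state_buffers path): slot 3 -> 5 and slot 0 -> 6, across all layers.
copy_mamba_buffer(buf, buf, torch.tensor([3, 0], device="cuda"), torch.tensor([5, 6], device="cuda"))

# Fork (the fork_state_buffers path): replicate slot 1 into slots 2 and 4;
# dst_indexes is 2D with shape [num_src, num_dst_per_src].
fork_mamba_buffer(buf, buf, torch.tensor([1], device="cuda"), torch.tensor([[2, 4]], device="cuda"))
```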
diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 3a5e048fb9..f85fcec452 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -277,5 +277,5 @@ def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_re all_mtp_buffers = self.req_to_buffer_index[tgt_req_index[:, None], mtp_range[None, :]] # 将 shared buffer 广播到所有 MTP step - self.buffer_mem_manager.copy_buffer_broadcast(src_buffer_index, all_mtp_buffers) + self.buffer_mem_manager.fork_state_buffers(src_buffer_index, all_mtp_buffers) return diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py index 2a4fe06628..30765a0aa2 100644 --- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -77,7 +77,7 @@ def insert_for_hybrid_radix_cache(self, reqs): # Move to CUDA and convert to int64, ensure contiguous new_buffer_indexes_cuda = new_buffer_indexes.to(device="cuda", dtype=torch.int64).contiguous() - self.buffer_mem_manager.copy_buffer_p2p(cur_buffer_indexes, new_buffer_indexes_cuda) + self.buffer_mem_manager.copy_state_buffers(cur_buffer_indexes, new_buffer_indexes_cuda) for i, req in enumerate(reqs_to_insert): input_token_ids = req.get_input_token_ids() diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 3cabd97baa..2ea8f07cf6 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -285,8 +285,8 @@ def decode_mtp( # Destination: buffer[0] for each request dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] # P2P copy both conv_states and ssm_states - if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): - g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_state_buffers"): + g_infer_context.req_manager.buffer_mem_manager.copy_state_buffers( src_buffer_indexes, dst_buffer_indexes ) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index c5dd768224..5d0b6c701d 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -463,8 +463,8 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): actual_req_idxes, mtp_accept_len[mask] - 1 ] dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] - if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): - g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_state_buffers"): + g_infer_context.req_manager.buffer_mem_manager.copy_state_buffers( src_buffer_indexes, dst_buffer_indexes ) @@ -790,8 +790,8 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf actual_req_idxes, mtp_accept_len[mask] - 1 ] dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] - if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): - g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + 
if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_state_buffers"): + g_infer_context.req_manager.buffer_mem_manager.copy_state_buffers( src_buffer_indexes, dst_buffer_indexes ) From b81baaab6f481987531ab29c46e70d19b173dee4 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 2 Mar 2026 05:11:12 +0000 Subject: [PATCH 108/180] clean the code --- .../layer_weights/transformer_layer_weight.py | 6 ---- lightllm/models/qwen3_5_moe/model.py | 29 ------------------- 2 files changed, 35 deletions(-) diff --git a/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py index ca1f9d992e..75eb382fa9 100644 --- a/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py @@ -13,8 +13,6 @@ def split_fused_expert_weights(weights, layer_num, moe_intermediate_size): layer_prefix = f"model.layers.{layer_num}." keys = list(weights.keys()) - gate_up_count = 0 - down_count = 0 num_experts = 0 for k in keys: @@ -34,8 +32,6 @@ def split_fused_expert_weights(weights, layer_num, moe_intermediate_size): weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx] weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx] - gate_up_count += 1 - elif "mlp.experts.down_proj" in k: down_weight = weights.pop(k) # [num_experts, hidden_size, inter_size] num_experts = down_weight.shape[0] @@ -45,8 +41,6 @@ def split_fused_expert_weights(weights, layer_num, moe_intermediate_size): for expert_idx in range(num_experts): weights[f"{prefix}.{expert_idx}.down_proj.weight"] = down_weight[expert_idx] - down_count += 1 - class Qwen35NextFullAttentionTransformerLayerWeight(Qwen3NextFullAttentionTransformerLayerWeight): def load_hf_weights(self, weights): diff --git a/lightllm/models/qwen3_5_moe/model.py b/lightllm/models/qwen3_5_moe/model.py index 069992bb37..573d563edd 100644 --- a/lightllm/models/qwen3_5_moe/model.py +++ b/lightllm/models/qwen3_5_moe/model.py @@ -8,40 +8,11 @@ @ModelRegistry("qwen3_5_moe", is_multimodal=True) class Qwen3_5MOETpPartModel(Qwen3_5TpPartModel): - """ - Qwen3.5-MoE Multimodal Model (Mixture of Experts Variant) - - Extends Qwen3.5 with sparse expert routing: - - Same hybrid attention architecture as Qwen3.5 - - MoE layers replace dense MLP layers - - Expert routing handled by inherited MoE infrastructure - - This model combines: - - Hybrid attention from Qwen3Next (Gated Delta Networks + Full Attention) - - Multimodal capabilities from Qwen3VL (image/video processing) - - MoE sparse routing for efficient scaling - """ - def __init__(self, kvargs): - """ - Initialize Qwen3.5-MoE model. - - Args: - kvargs: Dictionary containing: - - weight_dir: Path to model weights - - max_total_token_num: Maximum total tokens - - Additional model configuration - """ super().__init__(kvargs) logger.info("Initialized Qwen3.5-MoE multimodal model with expert routing") def _init_custom(self): - """ - Initialize MoE-specific components. - - Sets up DeepEP communication group for expert parallelism - when the model has experts configured. 
- """ super()._init_custom() # Initialize DeepEP group for MoE models with num_experts if "num_experts" in self.config and self.config["num_experts"] > 0: From 52b422ac76934088707bed368d54c3703fb68145 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 2 Mar 2026 08:08:19 +0000 Subject: [PATCH 109/180] vlm tokenizer support token list --- lightllm/models/qwen2_vl/model.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 237c4ad897..421026d731 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -12,6 +12,7 @@ from .vision_process import smart_resize from lightllm.models.qwen2.model import Qwen2TpPartModel import os +from typing import Union, List # Warp of the origal tokenizer class QWen2VLTokenizer(BaseMultiModalTokenizer): @@ -52,9 +53,13 @@ def get_image_token_length(self, img: ImageItem): def get_audio_token_length(self, audio: AudioItem): raise NotImplementedError - def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): - - origin_ids = self.tokenizer.encode(prompt) + def encode(self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams = None, **kwargs): + if isinstance(prompt, str): + origin_ids = self.tokenizer.encode(prompt) + elif isinstance(prompt, List[int]): + origin_ids = prompt + else: + raise ValueError(f"Unsupported prompt type: {type(prompt)}") # -> origin_ids = [token for token in origin_ids if token != self.image_token_id] From aa442a42f70908d2093dcacf398dd5f2e49be998 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 2 Mar 2026 08:16:33 +0000 Subject: [PATCH 110/180] fix --- lightllm/models/qwen2_vl/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 421026d731..c94135573b 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -56,7 +56,7 @@ def get_audio_token_length(self, audio: AudioItem): def encode(self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams = None, **kwargs): if isinstance(prompt, str): origin_ids = self.tokenizer.encode(prompt) - elif isinstance(prompt, List[int]): + elif isinstance(prompt, list): origin_ids = prompt else: raise ValueError(f"Unsupported prompt type: {type(prompt)}") From 0fd0202e13d66eb68a14bf7dbd41c30823fca8da Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 2 Mar 2026 08:46:01 +0000 Subject: [PATCH 111/180] clean code --- .../triton_kernel/mamba_buffer_copy.py | 48 ++----------------- lightllm/common/req_manager.py | 6 +++ lightllm/models/qwen2_vl/qwen2_visual.py | 2 + .../qwen3_omni_visual.py | 2 + lightllm/models/qwen3_vl/qwen3_visual.py | 2 + .../server/router/model_infer/infer_batch.py | 21 ++++---- .../model_infer/mode_backend/base_backend.py | 6 --- 7 files changed, 25 insertions(+), 62 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py index 361c0565ae..bd2aaed530 100644 --- a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py +++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py @@ -3,8 +3,6 @@ import triton.language as tl from lightllm.common.triton_utils.autotuner import autotune -_MAX_GRID_DIM = 65535 - @triton.jit def _copy_buffer_kernel( @@ -84,15 +82,7 @@ def _get_buffer_copy_configs(): def _get_copy_static_key( src_buffer: 
torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, ): - """Static key for copy kernel cache: dtype, d_size, layer_num. - - Different models (35B vs 397B) have different optimal configs, so each - should get its own cache file. - """ d_size = ( src_buffer.shape[2] if src_buffer.ndim == 3 @@ -106,22 +96,11 @@ def _get_copy_static_key( } -def _get_copy_run_key( - src_buffer: torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes: torch.Tensor, -): +def _get_copy_run_key(src_buffer: torch.Tensor): return 0 -def _get_fork_static_key( - src_buffer: torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes_flat: torch.Tensor, - num_dst_per_src: int, -): +def _get_fork_static_key(src_buffer: torch.Tensor): d_size = ( src_buffer.shape[2] if src_buffer.ndim == 3 @@ -135,18 +114,11 @@ def _get_fork_static_key( } -def _get_fork_run_key( - src_buffer: torch.Tensor, - dst_buffer: torch.Tensor, - src_indexes: torch.Tensor, - dst_indexes_flat: torch.Tensor, - num_dst_per_src: int, -): +def _get_fork_run_key(src_buffer: torch.Tensor): return 0 def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor: - """Flatten dims after [layer_num, buffer_size] into one. Zero-copy for contiguous tensors.""" if buffer.ndim == 3: return buffer L, B = buffer.shape[:2] @@ -158,7 +130,6 @@ def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor: configs_gen_func=_get_buffer_copy_configs, static_key_func=_get_copy_static_key, run_key_func=_get_copy_run_key, - mutates_args=["dst_buffer"], ) def _copy_mamba_buffer_autotuned( src_buffer: torch.Tensor, @@ -169,7 +140,6 @@ def _copy_mamba_buffer_autotuned( ): if not run_config: d_size = src_buffer.shape[2] - # For memory-bound copy, larger BLOCK_D is better (reduces grid size) BLOCK_D = min(4096, triton.next_power_of_2(d_size)) num_warps = 4 if BLOCK_D >= 1024 else 2 run_config = {"BLOCK_D": BLOCK_D, "num_warps": num_warps, "num_stages": 1} @@ -182,9 +152,6 @@ def _copy_mamba_buffer_autotuned( num_blocks_d = triton.cdiv(d_size, BLOCK_D) - assert num_pairs <= _MAX_GRID_DIM, f"num_pairs={num_pairs} exceeds grid limit {_MAX_GRID_DIM}" - assert layer_num <= _MAX_GRID_DIM, f"layer_num={layer_num} exceeds grid limit {_MAX_GRID_DIM}" - grid = (num_pairs, layer_num, num_blocks_d) _copy_buffer_kernel[grid]( src_buffer, @@ -205,7 +172,6 @@ def _copy_mamba_buffer_autotuned( configs_gen_func=_get_buffer_copy_configs, static_key_func=_get_fork_static_key, run_key_func=_get_fork_run_key, - mutates_args=["dst_buffer"], ) def _fork_mamba_buffer_autotuned( src_buffer: torch.Tensor, @@ -215,8 +181,6 @@ def _fork_mamba_buffer_autotuned( num_dst_per_src: int, run_config: dict = None, ): - """Autotuned fork implementation.""" - # Default heuristic when autotune is disabled or no config cached if not run_config: d_size = src_buffer.shape[2] BLOCK_D = min(4096, triton.next_power_of_2(d_size)) @@ -232,9 +196,6 @@ def _fork_mamba_buffer_autotuned( num_blocks_d = triton.cdiv(d_size, BLOCK_D) total_pairs = num_src * num_dst_per_src - assert total_pairs <= _MAX_GRID_DIM, f"total_pairs={total_pairs} exceeds grid limit {_MAX_GRID_DIM}" - assert layer_num <= _MAX_GRID_DIM, f"layer_num={layer_num} exceeds grid limit {_MAX_GRID_DIM}" - grid = (total_pairs, layer_num, num_blocks_d) _fork_buffer_kernel[grid]( src_buffer, @@ -251,9 +212,6 @@ def _fork_mamba_buffer_autotuned( ) -# ─── Public API ─────────────────────────────────────────────────────────────── - - def copy_mamba_buffer( 
src_buffer: torch.Tensor, dst_buffer: torch.Tensor, diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index f85fcec452..8874e549e2 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -68,6 +68,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num + self.req_to_buffer_index = None def alloc(self): return self.req_list.alloc() @@ -94,6 +95,11 @@ def free_all(self): self.req_list = _ReqLinkedList(self.max_request_num) return + @property + def has_recurrent_state(self): + """Whether this model uses per-request recurrent state buffers (e.g. Mamba/linear attention).""" + return self.req_to_buffer_index is not None + def alloc_buffer_for_req(self, req_index: torch.Tensor): """Allocate buffers for requests. No-op for standard models without linear attention.""" pass diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index a29cb8758b..6076756043 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -57,6 +57,8 @@ def __init__( kernel_size = [temporal_patch_size, patch_size, patch_size] self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + # Convert weight to channels_last_3d for cuDNN optimization (~10% extra speedup) + self.proj.weight.data = self.proj.weight.data.contiguous(memory_format=torch.channels_last_3d) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view( diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py index c20c227996..0276724749 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -60,6 +60,8 @@ def __init__( kernel_size = [temporal_patch_size, patch_size, patch_size] self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + # Convert weight to channels_last_3d for cuDNN optimization (~10% extra speedup) + self.proj.weight.data = self.proj.weight.data.contiguous(memory_format=torch.channels_last_3d) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: target_dtype = self.proj.weight.dtype diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index 7fc8187ddc..f636715033 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -63,6 +63,8 @@ def __init__( kernel_size = [temporal_patch_size, patch_size, patch_size] self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + # Convert weight to channels_last_3d for cuDNN optimization (~10% extra speedup) + self.proj.weight.data = self.proj.weight.data.contiguous(memory_format=torch.channels_last_3d) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: target_dtype = self.proj.weight.dtype diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 57241de967..37e05bd2ff 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from typing import List, Dict, 
Tuple, Optional, Callable, Any -from lightllm.common.req_manager import ReqManager, ReqManagerForMamba +from lightllm.common.req_manager import ReqManager from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode @@ -38,7 +38,9 @@ class InferenceContext: overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream - use_mamba_model: bool = False + @property + def has_recurrent_state(self): + return self.req_manager is not None and self.req_manager.has_recurrent_state def register( self, @@ -47,7 +49,6 @@ def register( radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int, - use_mamba_model: bool = False, ): self.args = get_env_start_args() from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend @@ -63,12 +64,10 @@ def register( self.vocab_size = vocab_size - self.use_mamba_model = use_mamba_model - if self.use_mamba_model: + if self.has_recurrent_state: assert self.radix_cache is None or isinstance( self.radix_cache, HybridRadixCache - ), "Mamba model only support HybridRadixCache" - assert isinstance(self.req_manager, ReqManagerForMamba), "Mamba model only support ReqManagerForMamba" + ), "Recurrent state models only support HybridRadixCache" self.mtp_step = get_env_start_args().mtp_step return @@ -205,7 +204,7 @@ def free_a_req_mem_for_mamba(self, free_token_index: List, req: "InferReq") -> b def _free_req_mem_and_buffers(self, free_token_index: List, free_buffer_index: List, req: "InferReq"): """释放请求的 KV cache 和 buffer 内存""" - if self.use_mamba_model: + if self.has_recurrent_state: need_free_base_buffer = self.free_a_req_mem_for_mamba(free_token_index, req) req_to_buffer_index = self.req_manager.req_to_buffer_index if need_free_base_buffer: @@ -251,7 +250,7 @@ def _filter(self, finished_request_ids: List[int]): free_token_index = custom_cat(free_token_index) self.req_manager.free(free_req_index, free_token_index) - if self.use_mamba_model and len(free_buffer_index) != 0: + if len(free_buffer_index) != 0: self.req_manager.free_buffer(free_buffer_index) finished_req_ids_set = set(finished_request_ids) @@ -301,7 +300,7 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): free_token_index = custom_cat(free_token_index) self.req_manager.free_token(free_token_index) - if self.use_mamba_model and len(free_buffer_index) != 0: + if len(free_buffer_index) != 0: self.req_manager.free_buffer(free_buffer_index) g_infer_state_lock.release() @@ -513,7 +512,7 @@ def _match_radix_cache(self): self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 - if g_infer_context.use_mamba_model: + if g_infer_context.has_recurrent_state: MAMBA_PREFILL_BLOCK_SIZE = 128 MAMBA_MIN_INSERT_LEN = 1024 miss_prefix_len = miss_prefix_len - miss_prefix_len % MAMBA_PREFILL_BLOCK_SIZE diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 57a3508e93..92102a90d8 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -193,18 +193,12 @@ def init_model(self, kvargs): self.logger.info(f"loaded model class {self.model.__class__}") - # 
Check if the model uses Mamba (linear attention) layers - from lightllm.common.req_manager import ReqManagerForMamba - - use_mamba_model = isinstance(self.model.req_manager, ReqManagerForMamba) - g_infer_context.register( backend=self, req_manager=self.model.req_manager, radix_cache=self.radix_cache, shm_req_manager=self.shm_req_manager, vocab_size=self.model.vocab_size, - use_mamba_model=use_mamba_model, ) # 初始化 dp 模式使用的通信 tensor, 对于非dp模式,不会使用到 From b9a386e5b3aa3e3515f8699e6ec713c54bb274b0 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 9 Mar 2026 06:38:17 +0000 Subject: [PATCH 112/180] code simplify --- .../layer_weights/meta_weights/__init__.py | 3 +- .../layer_weights/meta_weights/norm_weight.py | 10 +- .../layer_weights/transformer_layer_weight.py | 2 +- .../layer_infer/transformer_layer_infer.py | 7 +- .../layer_weights/transformer_layer_weight.py | 9 +- lightllm/models/qwen3_5/model.py | 38 +- .../qwen3_5_moe/layer_infer/__init__.py | 0 .../layer_weights/transformer_layer_weight.py | 40 ++ lightllm/models/qwen3_5_moe/model.py | 15 +- .../layer_infer/shared_expert_mixin.py | 96 ---- .../layer_infer/transformer_layer_infer.py | 500 +++--------------- .../layer_weights/transformer_layer_weight.py | 248 +++------ lightllm/models/qwen3next/model.py | 28 +- 13 files changed, 209 insertions(+), 787 deletions(-) delete mode 100644 lightllm/models/qwen3_5_moe/layer_infer/__init__.py create mode 100644 lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py delete mode 100644 lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index dc0683294c..21b5b7959e 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -10,10 +10,11 @@ from .norm_weight import ( TpRMSNormWeight, RMSNormWeight, + GEMMANormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight, - QKRMSNORMWeightGEMMANormWeight, + QKGEMMANormWeight, ) from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index f69fe4e1ab..89a3d24119 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -71,6 +71,14 @@ def __call__( return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) +class GEMMANormWeight(RMSNormWeight): + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name in weights: + self.weight.copy_(weights[self.weight_name]) + self.weight += 1 + self.weight.load_ok = True + + class LayerNormWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): super().__init__(tp_rank=0, tp_world_size=1) @@ -278,7 +286,7 @@ def __call__( return self._forward(q=q, k=k, eps=eps) -class QKRMSNORMWeightGEMMANormWeight(QKRMSNORMWeight): +class QKGEMMANormWeight(QKRMSNORMWeight): def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.q_weight_name in weights: self.q_weight.copy_(weights[self.q_weight_name]) diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py 
b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 0566c9f1c6..e42a6191e6 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -31,7 +31,7 @@ def _parse_config(self): head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] self.head_dim = self.network_config_.get("head_dim", head_dim) self.n_embed = self.network_config_["hidden_size"] - self.n_inter = self.network_config_["intermediate_size"] + self.n_inter = self.network_config_.get("intermediate_size", -1) def _init_weight_names(self): self._q_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py index 3cd07f39ae..64ecf94edb 100644 --- a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -6,9 +6,8 @@ Qwen3NextFullAttentionTransformerLayerInfer, Qwen3NextGatedDeltaNetTransformerLayerInfer, ) -from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( - Qwen3NextFullAttentionTransformerLayerWeight, - Qwen3NextGatedDeltaNetTransformerLayerWeight, +from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( + Qwen35TransformerLayerWeight, ) from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -29,7 +28,7 @@ def _get_qkv( self, input: torch.Tensor, infer_state: LlamaInferStateInfo, - layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + layer_weight: Qwen35TransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) diff --git a/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py index 0605e7e3b7..9f91f3db8b 100644 --- a/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py @@ -2,19 +2,14 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( - Qwen3NextFullAttentionTransformerLayerWeight, - Qwen3NextGatedDeltaNetTransformerLayerWeight, + Qwen3NextTransformerLayerWeight, ) from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) -class Qwen35NextFullAttentionTransformerLayerWeight(Qwen3NextFullAttentionTransformerLayerWeight): - pass - - -class Qwen35NextGatedDeltaNetTransformerLayerWeight(Qwen3NextGatedDeltaNetTransformerLayerWeight): +class Qwen35TransformerLayerWeight(Qwen3NextTransformerLayerWeight): def _init_gdn_weight(self): # Initialize everything from parent first, then override only linear_in_proj. 
super()._init_gdn_weight() diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py index 33f398a9af..f29d50476b 100644 --- a/lightllm/models/qwen3_5/model.py +++ b/lightllm/models/qwen3_5/model.py @@ -3,8 +3,7 @@ from lightllm.models.registry import ModelRegistry from lightllm.models.qwen3next.model import Qwen3NextTpPartModel from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( - Qwen35NextFullAttentionTransformerLayerWeight, - Qwen35NextGatedDeltaNetTransformerLayerWeight, + Qwen35TransformerLayerWeight, ) from lightllm.models.qwen3_vl.model import QWen3VLTokenizer from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import ( @@ -54,7 +53,7 @@ class Qwen3_5TpPartModel(Qwen3NextTpPartModel): """ pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer - + transformer_weight_class = Qwen35TransformerLayerWeight pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight infer_state_class = Qwen35InferStateInfo @@ -76,18 +75,7 @@ def _init_config(self): repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) - # # Qwen3.5 MoE uses moe_intermediate_size instead of intermediate_size - # # Set intermediate_size for compatibility with base layer weight classes - # if "intermediate_size" not in self.config: - # if "moe_intermediate_size" in self.config: - # self.config["intermediate_size"] = self.config["moe_intermediate_size"] - # else: - # # Default fallback: 4x hidden_size (common in transformer architectures) - # self.config["intermediate_size"] = self.config.get("hidden_size", 4096) * 4 - # Qwen3.5 stores RoPE config under text_config.rope_parameters. - # Qwen3Next/llama infer path expects flattened keys like rope_theta and - # partial_rotary_factor on the main config dict. rope_parameters = self.config.get("rope_parameters") if isinstance(rope_parameters, dict): if "rope_theta" in rope_parameters and "rope_theta" not in self.config: @@ -110,28 +98,6 @@ def _init_config(self): # Required by parent class _init_mem_manager() in Qwen3NextTpPartModel self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) - def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) - num_full_attention_layers = self.config["full_attention_interval"] - self.trans_layers_weight = [ - ( - Qwen35NextFullAttentionTransformerLayerWeight( - i, - self.data_type, - network_config=self.config, - quant_cfg=self.quant_cfg, - ) - if (i + 1) % num_full_attention_layers == 0 - else Qwen35NextGatedDeltaNetTransformerLayerWeight( - i, - self.data_type, - network_config=self.config, - quant_cfg=self.quant_cfg, - ) - ) - for i in range(self.config["n_layer"]) - ] - def _init_infer_layer(self): """ Initialize inference layers for Qwen3.5 multimodal model. 
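For reference, a minimal sketch of the per-layer dispatch that this commit consolidates. The names below are hypothetical, not the LightLLM classes: instead of the model's `_init_weights` choosing between a full-attention and a GDN weight class per layer, the single `transformer_weight_class` keys off `full_attention_interval`, matching the `is_linear_attention` check that `Qwen3NextTransformerLayerWeight` gains later in this series.

class HybridTransformerLayerWeight:
    """Sketch: one weight class serving both layer types of a hybrid model."""

    def __init__(self, layer_num: int, network_config: dict):
        interval = network_config["full_attention_interval"]
        # Full attention every `interval`-th layer; all others are GDN layers.
        self.is_linear_attention = (layer_num + 1) % interval != 0
        if self.is_linear_attention:
            self._init_gdn_weight()        # conv1d / in_proj / gating params
        else:
            self._init_attention_weight()  # q/k/v/o projections

    def _init_gdn_weight(self):
        pass

    def _init_attention_weight(self):
        pass

With `full_attention_interval == 4` the resulting layer schedule is [GDN, GDN, GDN, FULL, GDN, GDN, GDN, FULL, ...], which is why the removed model-level selection used `(i + 1) % num_full_attention_layers == 0` for the full-attention branch.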
diff --git a/lightllm/models/qwen3_5_moe/layer_infer/__init__.py b/lightllm/models/qwen3_5_moe/layer_infer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..fe4b1883bd --- /dev/null +++ b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py @@ -0,0 +1,40 @@ +import torch +from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import Qwen35TransformerLayerWeight + + +class Qwen35MOETransformerLayerWeight(Qwen35TransformerLayerWeight): + def load_hf_weights(self, weights): + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + split_fused_expert_weights(weights, self.layer_num_, moe_intermediate_size) + return super().load_hf_weights(weights) + + +def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_size: int): + layer_prefix = f"model.layers.{layer_num}." + keys = list(weights.keys()) + num_experts = 0 + + for k in keys: + if not k.startswith(layer_prefix): + continue + + if "mlp.experts.gate_up_proj" in k: + fused_weight = weights.pop(k) # [num_experts, 2*inter_size, hidden_size] + num_experts = fused_weight.shape[0] + + prefix = k.rsplit(".gate_up_proj", 1)[0] + gate_weight = fused_weight[:, :moe_intermediate_size, :] + up_weight = fused_weight[:, moe_intermediate_size:, :] + + for expert_idx in range(num_experts): + weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx] + weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx] + + elif "mlp.experts.down_proj" in k: + down_weight = weights.pop(k) # [num_experts, hidden_size, inter_size] + num_experts = down_weight.shape[0] + + prefix = k.rsplit(".down_proj", 1)[0] + + for expert_idx in range(num_experts): + weights[f"{prefix}.{expert_idx}.down_proj.weight"] = down_weight[expert_idx] diff --git a/lightllm/models/qwen3_5_moe/model.py b/lightllm/models/qwen3_5_moe/model.py index 573d563edd..973274774f 100644 --- a/lightllm/models/qwen3_5_moe/model.py +++ b/lightllm/models/qwen3_5_moe/model.py @@ -1,19 +1,12 @@ from lightllm.models.registry import ModelRegistry from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel from lightllm.utils.log_utils import init_logger -from lightllm.distributed.communication_op import dist_group_manager - -logger = init_logger(__name__) +from lightllm.models.qwen3_5_moe.layer_weights.transformer_layer_weight import ( + Qwen35MOETransformerLayerWeight, +) @ModelRegistry("qwen3_5_moe", is_multimodal=True) class Qwen3_5MOETpPartModel(Qwen3_5TpPartModel): - def __init__(self, kvargs): - super().__init__(kvargs) - logger.info("Initialized Qwen3.5-MoE multimodal model with expert routing") - def _init_custom(self): - super()._init_custom() - # Initialize DeepEP group for MoE models with num_experts - if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + transformer_weight_class = Qwen35MOETransformerLayerWeight diff --git a/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py b/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py deleted file mode 100644 index be9000fcad..0000000000 --- a/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py +++ /dev/null @@ -1,96 +0,0 @@ -# lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py -import torch.nn.functional as 
F -from functools import partial -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd -import os - - -class SharedExpertFFNMixin: - """ - Mixin providing shared expert + MoE FFN implementations. - - Used by both full attention and GDN layers in Qwen3Next. - - Requirements: - - Class must have: embed_dim_, tp_world_size_, alloc_tensor() - - Class must have MoE config: is_moe, n_routed_experts, num_experts_per_tok, norm_topk_prob - """ - - def _bind_ffn(self): - """Bind FFN implementation based on MoE configuration.""" - if self.is_moe: - moe_mode = os.environ.get("MOE_MODE", "TP") - if moe_mode == "EP": - self._ffn = partial(SharedExpertFFNMixin._ffn_with_shared_expert_ep, self) - else: - self._ffn = partial(SharedExpertFFNMixin._ffn_with_shared_expert_tp, self) - else: - self._ffn = partial(SharedExpertFFNMixin._standard_ffn, self) - return - - def _ffn_core(self, input, layer_weight): - """Core FFN computation: gate_up -> silu_and_mul -> down.""" - input = input.view(-1, self.embed_dim_) - up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) - ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) - silu_and_mul_fwd(up_gate_out, ffn1_out) - ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) - return ffn2_out, input - - def _standard_ffn(self, input, infer_state, layer_weight): - """Standard FFN using shared expert weights (non-MoE layers).""" - ffn2_out, _ = self._ffn_core(input, layer_weight) - return ffn2_out - - def _compute_shared_expert(self, input, layer_weight): - """Compute shared expert FFN output with gating.""" - ffn2_out, input_view = self._ffn_core(input, layer_weight) - return F.sigmoid(layer_weight.shared_expert_gate.mm(input_view)) * ffn2_out, input_view - - def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): - """FFN with shared expert + MoE (tensor parallelism mode).""" - shared_expert_out, input = self._compute_shared_expert(input, layer_weight) - moe_out = self._moe_ffn(input, infer_state, layer_weight) - return shared_expert_out + moe_out - - def _ffn_with_shared_expert_ep(self, input, infer_state, layer_weight): - """FFN with shared expert + MoE (expert parallelism mode).""" - shared_expert_out, input = self._compute_shared_expert(input, layer_weight) - moe_out = self._moe_ffn_edp(input, infer_state, layer_weight) - return shared_expert_out + moe_out - - def _moe_ffn(self, input, infer_state, layer_weight): - """MoE FFN with tensor parallelism.""" - hidden_states = input.view(-1, self.embed_dim_) - num_tokens, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(hidden_states) - layer_weight.experts.experts( - hidden_states, - router_logits=router_logits, - top_k=self.num_experts_per_tok, - renormalize=self.norm_topk_prob, - use_grouped_topk=False, - topk_group=None, - num_expert_group=None, - ) - return hidden_states.view(num_tokens, hidden_dim) - - def _moe_ffn_edp(self, input, infer_state, layer_weight): - """MoE FFN with expert parallelism.""" - hidden_states = input - token_num, hidden_dim = hidden_states.shape - - router_logits = layer_weight.moe_gate.mm(hidden_states) - ep_output = layer_weight.experts.experts( - hidden_states, - router_logits=router_logits, - top_k=self.num_experts_per_tok, - renormalize=self.norm_topk_prob, - use_grouped_topk=False, - topk_group=None, - num_expert_group=None, - is_prefill=infer_state.is_prefill, - ) - - ep_output = ep_output.view(token_num, hidden_dim) - return ep_output diff --git 
a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index f121be7001..5732dc41e3 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -3,11 +3,9 @@ import torch.distributed as dist from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( - Qwen3NextFullAttentionTransformerLayerWeight, - Qwen3NextGatedDeltaNetTransformerLayerWeight, + Qwen3NextTransformerLayerWeight, ) from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl from lightllm.utils.log_utils import init_logger @@ -26,31 +24,13 @@ ) from lightllm.distributed import all_reduce from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward -from lightllm.models.qwen3next.triton_kernel.fused_add_gemma_rmsnorm import fused_add_gemma_rmsnorm -from lightllm.models.qwen3next.triton_kernel.fused_split_copy import fused_split_copy_qkvzba, fused_split_copy_qkv from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type from functools import partial logger = init_logger(__name__) -class GemmaRMSNormMixin: - """ - Mixin providing Gemma-style RMSNorm implementations. - - Requirements: - - Class must have: eps_, alloc_tensor() - """ - - def _gemma_norm_with_pool(self, input, norm_weight): - """Apply Gemma RMSNorm.""" - out = self.alloc_tensor(input.shape, input.dtype) - gemma_rmsnorm_forward(input, norm_weight, self.eps_, out=out) - return out - - -class Qwen3NextFullAttentionBaseLayerInfer(GemmaRMSNormMixin, LlamaTransformerLayerInfer): +class Qwen3NextFullAttentionBaseLayerInfer(LlamaTransformerLayerInfer): """ Base class for Qwen3Next full attention layers. Contains shared logic for both standard full attention and MTP layers. 
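The deleted `GemmaRMSNormMixin` (and the `gemma_rmsnorm` / `fused_add_gemma_rmsnorm` kernels dropped from the imports above) become unnecessary because Gemma-style RMSNorm differs from standard RMSNorm only in scaling by `(1 + weight)`: the new `GEMMANormWeight.load_hf_weights` folds the `+1` into the stored weight once at load time, after which the generic `att_norm_weight_(...)` / `ffn_norm_weight_(...)` call path serves both conventions. A minimal equivalence sketch in plain PyTorch (not the Triton kernels):

import torch

def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Standard RMSNorm: x / rms(x) * weight
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

x = torch.randn(4, 8)
w = torch.randn(8)

gemma_out = rmsnorm(x, 1.0 + w)  # Gemma convention: runtime scales by (1 + w)
w_folded = w + 1.0               # what GEMMANormWeight stores after `weight += 1`
assert torch.allclose(gemma_out, rmsnorm(x, w_folded))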
@@ -68,127 +48,47 @@ def __init__(self, layer_num, network_config): self.norm_topk_prob = network_config.get("norm_topk_prob", False) super().__init__(layer_num, network_config) - # Override head_dim which may be different in Qwen3Next self.head_dim_ = network_config.get( "head_dim", network_config["hidden_size"] // network_config["num_attention_heads"] ) - - # Pre-allocated decode buffers (mirrors GDN layer pattern) - start_args = get_env_start_args() - self._decode_buffers = {} - self._graph_max_batch_size = start_args.graph_max_batch_size - - # Pre-compute dims for decode buffer pre-allocation - self.shared_inter_size = network_config.get("shared_expert_intermediate_size", 0) - self.tp_gate_up_dim = 2 * self.shared_inter_size // self.tp_world_size_ if self.shared_inter_size > 0 else 0 - self.tp_q_gate_dim = (self.tp_q_head_num_ + self.tp_o_head_num_) * self.head_dim_ - self.tp_kv_dim = (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_ - return - def _get_decode_buffer(self, name, max_shape, dtype, device): - """Get or create a pre-allocated buffer for the decode path.""" - key = (name, dtype, device if isinstance(device, str) else str(device)) - if key not in self._decode_buffers: - self._decode_buffers[key] = torch.empty(max_shape, dtype=dtype, device=device) - return self._decode_buffers[key] - def _bind_func(self): super()._bind_func() self._bind_ffn() return - def _bind_norm(self): - """Use Gemma-style RMSNorm""" - self._att_norm = partial(Qwen3NextFullAttentionBaseLayerInfer._att_norm_impl, self) - self._ffn_norm = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_norm_impl, self) - return - def _bind_ffn(self): """Bind FFN implementation based on MoE configuration.""" if self.is_moe: moe_mode = os.environ.get("MOE_MODE", "TP") if moe_mode == "EP": - self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_ep, self) + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._moe_ffn_edp, self) else: - self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_tp, self) + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._moe_ffn, self) else: - self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._standard_ffn, self) + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn, self) return - def _ffn_core(self, input, layer_weight, is_decode=False): - """Core FFN computation: gate_up -> silu_and_mul -> down.""" + def _compute_shared_expert( + self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): input = input.view(-1, self.embed_dim_) - if is_decode and self.tp_gate_up_dim > 0: - up_gate_buf = self._get_decode_buffer( - "up_gate_out", - (self._graph_max_batch_size, self.tp_gate_up_dim), - input.dtype, - input.device, - )[: input.size(0)] - up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input, out=up_gate_buf) - else: - up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) - inter_dim = up_gate_out.size(1) // 2 - if is_decode: - ffn1_out = self._get_decode_buffer( - "ffn1_out", (self._graph_max_batch_size, inter_dim), input.dtype, input.device - )[: input.size(0)] - else: - ffn1_out = self.alloc_tensor((input.size(0), inter_dim), input.dtype) - silu_and_mul_fwd(up_gate_out, ffn1_out) - ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) - return ffn2_out, input - - def _standard_ffn(self, input, infer_state, layer_weight): - """Standard FFN using shared expert weights (non-MoE layers).""" - # For dense models without shared 
experts, return zeros (no FFN computation) - if not hasattr(layer_weight, "shared_expert_gate_up_proj") or layer_weight.shared_expert_gate_up_proj is None: - return torch.zeros_like(input) - ffn2_out, _ = self._ffn_core(input, layer_weight, is_decode=not infer_state.is_prefill) - return ffn2_out - - def _compute_shared_expert(self, input, layer_weight, is_decode=False): - """Compute shared expert FFN output with gating.""" - ffn2_out, input_view = self._ffn_core(input, layer_weight, is_decode=is_decode) - # Dense models don't have shared_expert_gate - if layer_weight.shared_expert_gate is not None: - gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() - ffn2_out.mul_(gate) - return ffn2_out, input_view - - def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): - """FFN with shared expert + MoE (tensor parallelism mode).""" - shared_expert_out, input = self._compute_shared_expert( - input, layer_weight, is_decode=not infer_state.is_prefill - ) - moe_out = self._moe_ffn(input, infer_state, layer_weight) - moe_out.add_(shared_expert_out) - return moe_out - - def _ffn_with_shared_expert_ep(self, input, infer_state, layer_weight): - """FFN with shared expert + MoE (expert parallelism mode).""" - shared_expert_out, input = self._compute_shared_expert( - input, layer_weight, is_decode=not infer_state.is_prefill - ) - moe_out = self._moe_ffn_edp(input, infer_state, layer_weight) - moe_out.add_(shared_expert_out) - return moe_out + shared_expert_out = super()._ffn(input, infer_state, layer_weight) + gate = layer_weight.ffn_gate.mm(input).sigmoid_() + shared_expert_out.mul_(gate) + return shared_expert_out - def _moe_ffn(self, input, infer_state, layer_weight): + def _moe_ffn( + self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): """MoE FFN with tensor parallelism.""" + + shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - if not infer_state.is_prefill: - router_buf = self._get_decode_buffer( - "router_logits", - (self._graph_max_batch_size, self.n_routed_experts), - hidden_states.dtype, - hidden_states.device, - )[:num_tokens] - router_logits = layer_weight.moe_gate.mm(hidden_states, out=router_buf) - else: - router_logits = layer_weight.moe_gate.mm(hidden_states) + router_logits = layer_weight.moe_gate.mm(hidden_states) layer_weight.experts.experts( hidden_states, router_logits=router_logits, @@ -198,10 +98,15 @@ def _moe_ffn(self, input, infer_state, layer_weight): topk_group=None, num_expert_group=None, ) - return hidden_states.view(num_tokens, hidden_dim) + hidden_states = hidden_states.view(num_tokens, hidden_dim) + hidden_states.add_(shared_expert_out) + return hidden_states - def _moe_ffn_edp(self, input, infer_state, layer_weight): + def _moe_ffn_edp( + self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): """MoE FFN with expert parallelism.""" + shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input token_num, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -216,29 +121,14 @@ def _moe_ffn_edp(self, input, infer_state, layer_weight): is_prefill=infer_state.is_prefill, ) ep_output = ep_output.view(token_num, hidden_dim) + ep_output.add_(shared_expert_out) return ep_output - def _att_norm_impl( - self, - input, - 
_infer_state: LlamaInferStateInfo, - layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, - ) -> torch.Tensor: - return self._gemma_norm_with_pool(input, layer_weight.att_norm_weight_.weight) - - def _ffn_norm_impl( - self, - input, - _infer_state: LlamaInferStateInfo, - layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, - ) -> torch.Tensor: - return self._gemma_norm_with_pool(input, layer_weight.ffn_norm_weight_.weight) - def _get_qkv( self, input: torch.Tensor, - infer_state: LlamaInferStateInfo, - layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: """ QKV projection with output gating, Q/K normalization, and partial rotary embedding. @@ -270,8 +160,8 @@ def _get_qkv( def _get_o( self, input, - infer_state: LlamaInferStateInfo, - layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, ) -> torch.Tensor: """Output projection with gating (in-place multiply to save one allocation).""" input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) @@ -280,36 +170,6 @@ def _get_o( o_tensor = layer_weight.o_proj.mm(input) return o_tensor - def token_forward(self, input_embdings, infer_state, layer_weight): - """Override token_forward to use pre-allocated decode buffers and fused kernels.""" - max_tokens = self._graph_max_batch_size - input1 = self._get_decode_buffer( - "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device - )[: input_embdings.shape[0]] - gemma_rmsnorm_forward(input_embdings, layer_weight.att_norm_weight_.weight, self.eps_, out=input1) - - o = self.token_attention_forward(input1, infer_state, layer_weight) - - # Fused residual add + FFN norm: saves 1 kernel launch + 1 read of input_embdings - input1 = self._get_decode_buffer( - "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device - )[: input_embdings.shape[0]] - fused_add_gemma_rmsnorm( - input_embdings, - o.view(-1, self.embed_dim_), - layer_weight.ffn_norm_weight_.weight, - self.eps_, - out=input1, - ) - o = None - - ffn_out = self._ffn(input1, infer_state, layer_weight) - input1 = None - if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - return input_embdings - class Qwen3NextFullAttentionTransformerLayerInfer(Qwen3NextFullAttentionBaseLayerInfer): """ @@ -320,29 +180,22 @@ class Qwen3NextFullAttentionTransformerLayerInfer(Qwen3NextFullAttentionBaseLaye pass -class Qwen3NextGatedDeltaNetTransformerLayerInfer(GemmaRMSNormMixin, TransformerLayerInferTpl): +class Qwen3NextGatedDeltaNetTransformerLayerInfer(LlamaTransformerLayerInfer): """ Linear attention (Gated Delta Networks) layer for Qwen3Next. 
""" def __init__(self, layer_num, network_config): - super().__init__(layer_num, network_config) - self.network_config_ = network_config - - # MoE configuration self.n_routed_experts = network_config.get("num_experts", 0) self.is_moe = ( network_config.get("num_experts", 0) > 0 and layer_num not in network_config.get("mlp_only_layers", []) and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 ) + super().__init__(layer_num, network_config) + # MoE configuration self.num_experts_per_tok = network_config.get("num_experts_per_tok", 1) self.norm_topk_prob = network_config.get("norm_topk_prob", False) - self.shared_inter_size = network_config.get("shared_expert_intermediate_size", 0) - - # Standard layer dimensions - self.eps_ = network_config["rms_norm_eps"] - self.embed_dim_ = network_config["hidden_size"] # Linear attention specific dimensions self.num_v_heads = network_config["linear_num_value_heads"] @@ -385,126 +238,45 @@ def __init__(self, layer_num, network_config): # GDN kernel output dtype is self.data_type # Conversion needed only if SSM state uses different dtype self.needs_ssm_dtype_conversion = get_llm_data_type() != self.ssm_state_dtype - - # Pre-allocated decode buffers to avoid repeated allocation during CUDA graph replay. - # Buffers are lazily allocated on first decode call, sized to graph_max_batch_size. - self._decode_buffers = {} - self._graph_max_batch_size = start_args.graph_max_batch_size - - # Pre-compute FFN dims for decode buffer pre-allocation - self.tp_gate_up_dim = 2 * self.shared_inter_size // self.tp_world_size_ if self.shared_inter_size > 0 else 0 - self._bind_func() return - def _get_decode_buffer(self, name, max_shape, dtype, device): - """Get or create a pre-allocated buffer for the decode path. - - On first call, allocates a buffer at max_shape. On subsequent calls, - returns the same buffer (caller should slice to actual batch size). 
- """ - key = (name, dtype, device if isinstance(device, str) else str(device)) - if key not in self._decode_buffers: - self._decode_buffers[key] = torch.empty(max_shape, dtype=dtype, device=device) - return self._decode_buffers[key] - def _bind_func(self): """Bind layer-specific implementations""" - self._bind_norm() self._bind_ffn() return - def _bind_norm(self): - """Use Gemma-style RMSNorm""" - self._att_norm = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._att_norm_impl, self) - self._ffn_norm = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._ffn_norm_impl, self) - return - def _bind_ffn(self): """Bind FFN implementation based on MoE configuration.""" if self.is_moe: moe_mode = os.environ.get("MOE_MODE", "TP") if moe_mode == "EP": - self._ffn = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._ffn_with_shared_expert_ep, self) + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._moe_ffn_edp, self) else: - self._ffn = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._ffn_with_shared_expert_tp, self) + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._moe_ffn, self) else: - self._ffn = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._standard_ffn, self) + self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn, self) return - def _ffn_core(self, input, layer_weight, is_decode=False): - """Core FFN computation: gate_up -> silu_and_mul -> down.""" + def _compute_shared_expert( + self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): input = input.view(-1, self.embed_dim_) - if is_decode and self.tp_gate_up_dim > 0: - up_gate_buf = self._get_decode_buffer( - "up_gate_out", - (self._graph_max_batch_size * self.mtp_size, self.tp_gate_up_dim), - input.dtype, - input.device, - )[: input.size(0)] - up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input, out=up_gate_buf) - else: - up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) - inter_dim = up_gate_out.size(1) // 2 - if is_decode: - ffn1_out = self._get_decode_buffer( - "ffn1_out", (self._graph_max_batch_size, inter_dim), input.dtype, input.device - )[: input.size(0)] - else: - ffn1_out = self.alloc_tensor((input.size(0), inter_dim), input.dtype) - silu_and_mul_fwd(up_gate_out, ffn1_out) - ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) - return ffn2_out, input - - def _standard_ffn(self, input, infer_state, layer_weight): - """Standard FFN using shared expert weights (non-MoE layers).""" - # For dense models without shared experts, return zeros (no FFN computation) - if not hasattr(layer_weight, "shared_expert_gate_up_proj") or layer_weight.shared_expert_gate_up_proj is None: - return torch.zeros_like(input) - ffn2_out, _ = self._ffn_core(input, layer_weight, is_decode=not infer_state.is_prefill) - return ffn2_out - - def _compute_shared_expert(self, input, layer_weight, is_decode=False): - """Compute shared expert FFN output with gating.""" - ffn2_out, input_view = self._ffn_core(input, layer_weight, is_decode=is_decode) - # Dense models don't have shared_expert_gate - if layer_weight.shared_expert_gate is not None: - gate = layer_weight.shared_expert_gate.mm(input_view).sigmoid_() - ffn2_out.mul_(gate) - return ffn2_out, input_view - - def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight): - """FFN with shared expert + MoE (tensor parallelism mode).""" - shared_expert_out, input = self._compute_shared_expert( - input, layer_weight, is_decode=not infer_state.is_prefill - ) - moe_out = 
self._moe_ffn(input, infer_state, layer_weight) - moe_out.add_(shared_expert_out) - return moe_out - - def _ffn_with_shared_expert_ep(self, input, infer_state, layer_weight): - """FFN with shared expert + MoE (expert parallelism mode).""" - shared_expert_out, input = self._compute_shared_expert( - input, layer_weight, is_decode=not infer_state.is_prefill - ) - moe_out = self._moe_ffn_edp(input, infer_state, layer_weight) - moe_out.add_(shared_expert_out) - return moe_out + shared_expert_out = super()._ffn(input, infer_state, layer_weight) + gate = layer_weight.ffn_gate.mm(input).sigmoid_() + shared_expert_out.mul_(gate) + return shared_expert_out - def _moe_ffn(self, input, infer_state, layer_weight): + def _moe_ffn( + self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): """MoE FFN with tensor parallelism.""" + + shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - if not infer_state.is_prefill: - router_buf = self._get_decode_buffer( - "router_logits", - (self._graph_max_batch_size * self.mtp_size, self.n_routed_experts), - hidden_states.dtype, - hidden_states.device, - )[:num_tokens] - router_logits = layer_weight.moe_gate.mm(hidden_states, out=router_buf) - else: - router_logits = layer_weight.moe_gate.mm(hidden_states) + router_logits = layer_weight.moe_gate.mm(hidden_states) layer_weight.experts.experts( hidden_states, router_logits=router_logits, @@ -514,10 +286,15 @@ def _moe_ffn(self, input, infer_state, layer_weight): topk_group=None, num_expert_group=None, ) - return hidden_states.view(num_tokens, hidden_dim) + hidden_states = hidden_states.view(num_tokens, hidden_dim) + hidden_states.add_(shared_expert_out) + return hidden_states - def _moe_ffn_edp(self, input, infer_state, layer_weight): + def _moe_ffn_edp( + self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): """MoE FFN with expert parallelism.""" + shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input token_num, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -532,130 +309,27 @@ def _moe_ffn_edp(self, input, infer_state, layer_weight): is_prefill=infer_state.is_prefill, ) ep_output = ep_output.view(token_num, hidden_dim) + ep_output.add_(shared_expert_out) return ep_output - def _att_norm_impl( - self, - input, - _infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, - ) -> torch.Tensor: - return self._gemma_norm_with_pool(input, layer_weight.att_norm_weight_.weight) - - def _ffn_norm_impl( - self, - input, - _infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, - ) -> torch.Tensor: - return self._gemma_norm_with_pool(input, layer_weight.ffn_norm_weight_.weight) - - def _get_qkv( - self, - _input: torch.Tensor, - _infer_state: Qwen3NextInferStateInfo, - _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Not used by GDN - QKV projection handled in gdn_forward. - - GDN uses a fused projection that includes z, b, a parameters - in addition to q, k, v, so the standard template flow doesn't apply. - This method exists to satisfy the template interface. 
- """ - pass # Implementation in gdn_forward - - def _tpsp_get_qkv( - self, - _input: torch.Tensor, - _infer_state: Qwen3NextInferStateInfo, - _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """TPSP mode not implemented for GDN layers.""" - pass # No TPSP support planned - - def _get_o( - self, - _input, - _infer_state: Qwen3NextInferStateInfo, - _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, - ) -> torch.Tensor: - """ - Not used by GDN - output projection handled in gdn_forward. - - Output computation is fused with GDN recurrence in gdn_forward. - """ - pass # Implementation in gdn_forward - - def _tpsp_get_o( - self, - _input, - _infer_state: Qwen3NextInferStateInfo, - _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, - ) -> torch.Tensor: - """TPSP mode not implemented for GDN layers.""" - pass # No TPSP support planned - - def _context_attention_kernel( - self, - _q: torch.Tensor, - _kv: torch.Tensor, - _infer_state: Qwen3NextInferStateInfo, - _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, - ) -> torch.Tensor: - """Not used by GDN - attention computed in gdn_forward.""" - pass # Implementation in gdn_forward - - def _token_attention_kernel( - self, - _q: torch.Tensor, - _infer_state: Qwen3NextInferStateInfo, - _layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, - ) -> torch.Tensor: - """Not used by GDN - attention computed in gdn_forward.""" - pass # Implementation in gdn_forward - def _gdn_layer_forward( self, input_embdings, infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + layer_weight: Qwen3NextTransformerLayerWeight, is_prefill: bool, ): """Unified forward for both prefill and decode in GDN layers.""" # Attention + GDN processing - if is_prefill: - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - else: - # Decode: use pre-allocated buffer to avoid alloc_tensor overhead - max_tokens = self._graph_max_batch_size * self.mtp_size - input1 = self._get_decode_buffer( - "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device - )[: input_embdings.shape[0]] - gemma_rmsnorm_forward(input_embdings, layer_weight.att_norm_weight_.weight, self.eps_, out=input1) - + input1 = layer_weight.att_norm_weight_(input=input_embdings, eps=self.eps_, alloc_func=self.alloc_tensor) gdn_out = self.gdn_forward(input1, infer_state, layer_weight, is_prefill=is_prefill) if self.tp_world_size_ > 1: all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) # FFN - if is_prefill: - input_embdings.add_(gdn_out.view(-1, self.embed_dim_)) - gdn_out = None - input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) - else: - # Decode: fused residual add + FFN norm saves 1 kernel + 1 read of input_embdings - input1 = self._get_decode_buffer( - "att_norm_out", (max_tokens, self.embed_dim_), input_embdings.dtype, input_embdings.device - )[: input_embdings.shape[0]] - fused_add_gemma_rmsnorm( - input_embdings, - gdn_out.view(-1, self.embed_dim_), - layer_weight.ffn_norm_weight_.weight, - self.eps_, - out=input1, - ) - gdn_out = None + input_embdings.add_(gdn_out.view(-1, self.embed_dim_)) + gdn_out = None + input1 = layer_weight.ffn_norm_weight_(input=input_embdings, eps=self.eps_, alloc_func=self.alloc_tensor) ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None @@ -668,7 +342,7 @@ def context_forward( self, input_embdings, infer_state: Qwen3NextInferStateInfo, 
- layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + layer_weight: Qwen3NextTransformerLayerWeight, ): """Override context_forward to use GDN logic instead of standard attention flow.""" return self._gdn_layer_forward(input_embdings, infer_state, layer_weight, is_prefill=True) @@ -677,7 +351,7 @@ def token_forward( self, input_embdings, infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + layer_weight: Qwen3NextTransformerLayerWeight, ): """Override token_forward to use GDN logic instead of standard attention flow.""" return self._gdn_layer_forward(input_embdings, infer_state, layer_weight, is_prefill=False) @@ -688,7 +362,7 @@ def overlap_tpsp_token_forward( input_embdings1, infer_state: Qwen3NextInferStateInfo, infer_state1: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + layer_weight: Qwen3NextTransformerLayerWeight, ): """Microbatch overlap for decode: process two half-batches sequentially. Enables --enable_decode_microbatch_overlap for GDN layers.""" @@ -702,7 +376,7 @@ def overlap_tpsp_context_forward( input_embdings1, infer_state: Qwen3NextInferStateInfo, infer_state1: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + layer_weight: Qwen3NextTransformerLayerWeight, ): """Microbatch overlap for context: process two half-batches sequentially.""" input_embdings = self.context_forward(input_embdings, infer_state, layer_weight) @@ -775,7 +449,7 @@ def context_attention_forward( self, input_embdings, infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + layer_weight: Qwen3NextTransformerLayerWeight, ): gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True) return gdn_out @@ -784,7 +458,7 @@ def token_attention_forward( self, input_embdings, infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + layer_weight: Qwen3NextTransformerLayerWeight, ): gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=False) return gdn_out @@ -797,7 +471,7 @@ def _gdn_prefill_kernel( g: torch.Tensor, beta: torch.Tensor, infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + layer_weight: Qwen3NextTransformerLayerWeight, ): """Prefill kernel for GDN forward pass.""" # Conv1D processing @@ -845,7 +519,7 @@ def _gdn_decode_kernel( a: torch.Tensor, b: torch.Tensor, infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + layer_weight: Qwen3NextTransformerLayerWeight, ): """Decode kernel for GDN forward pass (single-token, non-MTP mode). Uses fused gating: g/beta computed inline in the recurrent kernel.""" @@ -886,7 +560,7 @@ def _gdn_decode_mtp_kernel( g: torch.Tensor, beta: torch.Tensor, infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + layer_weight: Qwen3NextTransformerLayerWeight, ): """ Optimized decode kernel for GDN forward pass (MTP mode with multiple steps). 
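For reference, the att_norm_weight_ / ffn_norm_weight_ calls introduced in this diff stand in for the gemma_rmsnorm Triton kernel that later patches in this series delete. A minimal PyTorch sketch of the Gemma-style RMSNorm they compute, assuming the float32 accumulation and (1 + weight) scaling used by the deleted reference implementation (gemma_rmsnorm_ref is an illustrative name, not an API in this repo; eps corresponds to self.eps_ in the layer code):

    import torch

    def gemma_rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Gemma variant: normalize in float32, scale by (1 + w), cast back.
        # Unlike the Llama form (normalize, cast, then * w), Gemma applies
        # (x * (1 + w)) while still in float32.
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return (x * (1.0 + weight.float())).to(orig_dtype)
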
@@ -969,7 +643,7 @@ def gdn_forward( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + layer_weight: Qwen3NextTransformerLayerWeight, is_prefill: bool, ): assert isinstance(infer_state.mem_manager, Qwen3NextHybridMemManager) @@ -978,18 +652,7 @@ def gdn_forward( input = input.view(-1, self.embed_dim_) conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_) - if not is_prefill: - # Decode: pre-allocate GEMM output to avoid cache tensor manager overhead - in_proj_out_dim = self.tp_qkvz_dim + self.tp_ba_dim - in_proj_out = self._get_decode_buffer( - "in_proj_out", - (self._graph_max_batch_size * self.mtp_size, in_proj_out_dim), - input.dtype, - input.device, - )[: input.shape[0]] - mixed_qkvzba = layer_weight.linear_in_proj.mm(input, out=in_proj_out) - else: - mixed_qkvzba = layer_weight.linear_in_proj.mm(input) + mixed_qkvzba = layer_weight.linear_in_proj.mm(input) # mixed_qkv is now returned pre-concatenated (no torch.cat needed) mixed_qkv, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba, is_decode=not is_prefill) @@ -1014,18 +677,7 @@ def gdn_forward( num_tokens = z.shape[0] # batch (decode) or total_tokens (prefill/MTP) core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) - if not is_prefill: - # Decode: use pre-allocated buffer for norm output to avoid alloc_tensor - max_decode_tokens = self._graph_max_batch_size * self.mtp_size - flat_size = max_decode_tokens * self.tp_num_v_heads - norm_out = self._get_decode_buffer( - "gdn_norm_out", - (flat_size, self.head_v_dim), - core_attn_out.dtype, - core_attn_out.device, - )[: core_attn_out.shape[0]] - else: - norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device) + norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device) gated_rmsnorm_forward( core_attn_out, layer_weight.linear_norm.weight, diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index 15d3d954b5..be68e6aeb1 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -4,15 +4,17 @@ ROWMMWeight, COLMMWeight, RMSNormWeight, + GEMMANormWeight, TpParameterWeight, - KVROWNMMWeight, QKVROWNMMWeight, - QKRMSNORMWeightGEMMANormWeight, + QKGEMMANormWeight, ) -class Qwen3NextFullAttentionTransformerLayerWeight(Qwen3MOETransformerLayerWeight): +class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + num_full_attention_layers = network_config["full_attention_interval"] + self.is_linear_attention = (layer_num + 1) % num_full_attention_layers != 0 super().__init__(layer_num, data_type, network_config, quant_cfg) return @@ -40,37 +42,73 @@ def _init_qkv(self): ) def _init_weight(self): - super()._init_weight() - self._init_gate_shared_expert_weight() - return + if self.is_linear_attention: + self._init_gdn_weight() + else: + self._init_qkv() + self._init_o() + + if self.is_moe: + self._init_moe() + else: + self._init_ffn() + self._init_norm() - def _init_ffn(self): - # Qwen3Next architecture uses _init_gate_shared_expert_weight() for FFN-like component - # No standard MLP FFN weights needed for this architecture - pass + def _init_moe(self): + super()._init_moe() + 
self._init_gated_ffn() + return def _init_norm(self): hidden_size = self.network_config_["hidden_size"] - self.att_norm_weight_ = RMSNormWeight( + self.att_norm_weight_ = GEMMANormWeight( dim=hidden_size, weight_name=self._att_norm_weight_name, data_type=self.data_type_, ) - self.ffn_norm_weight_ = RMSNormWeight( + self.ffn_norm_weight_ = GEMMANormWeight( dim=hidden_size, weight_name=self._ffn_norm_weight_name, data_type=self.data_type_, ) - self.qk_norm_weight_ = QKRMSNORMWeightGEMMANormWeight( - dim=self.head_dim, - q_weight_name=self._q_norm_name, - k_weight_name=self._k_norm_name, + if not self.is_linear_attention: + self.qk_norm_weight_ = QKGEMMANormWeight( + dim=self.head_dim, + q_weight_name=self._q_norm_name, + k_weight_name=self._k_norm_name, + data_type=self.data_type_, + ) + + def _init_gated_ffn(self): + hidden_size = self.network_config_["hidden_size"] + if "shared_expert_intermediate_size" not in self.network_config_: + return + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + inter_size = self.network_config_["shared_expert_intermediate_size"] + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("gate_up_proj"), + ) + self.down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", data_type=self.data_type_, + quant_method=self.get_quant_method("down_proj"), + ) + self.ffn_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, ) - - def load_hf_weights(self, weights): - self._split_q_with_gate(weights) - super().load_hf_weights(weights) def _split_q_with_gate(self, weights): if self._q_weight_name in weights: @@ -82,69 +120,6 @@ def _split_q_with_gate(self, weights): weights[self._q_weight_name] = _q_proj weights[self._o_gate_weight_name] = _gate_proj - def _init_gate_shared_expert_weight(self): - hidden_size = self.network_config_["hidden_size"] - - # Check if this is a MoE model with shared_expert or a dense model - if "shared_expert_intermediate_size" in self.network_config_: - # MoE model with shared expert - prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" - inter_size = self.network_config_["shared_expert_intermediate_size"] - self.shared_expert_gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[inter_size, inter_size], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_gate_up_proj"), - ) - self.shared_expert_down_proj = COLMMWeight( - in_dim=inter_size, - out_dims=[hidden_size], - weight_names=f"{prefix}.down_proj.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_down_proj"), - ) - self.shared_expert_gate = ROWMMWeight( - in_dim=hidden_size, - out_dims=[1], - weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", - data_type=self.data_type_, - bias_names=None, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) - else: - # Dense model with standard MLP - prefix = f"model.layers.{self.layer_num_}.mlp" - inter_size = self.network_config_["intermediate_size"] - self.shared_expert_gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - 
out_dims=[inter_size, inter_size], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_gate_up_proj"), - ) - self.shared_expert_down_proj = COLMMWeight( - in_dim=inter_size, - out_dims=[hidden_size], - weight_names=f"{prefix}.down_proj.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_down_proj"), - ) - # No shared_expert_gate for dense models - self.shared_expert_gate = None - - -class Qwen3NextGatedDeltaNetTransformerLayerWeight(Qwen3MOETransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, quant_cfg=None): - self.is_moe = ( - network_config.get("num_experts", 0) > 0 - and layer_num not in network_config.get("mlp_only_layers", []) - and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 - ) - super().__init__(layer_num, data_type, network_config, quant_cfg) - def _parse_config(self): super()._parse_config() self.linear_num_v_heads = self.network_config_["linear_num_value_heads"] @@ -152,30 +127,6 @@ def _parse_config(self): self.linear_k_head_dim = self.network_config_["linear_key_head_dim"] self.linear_v_head_dim = self.network_config_["linear_value_head_dim"] - def _init_weight(self): - hidden_size = self.network_config_["hidden_size"] - self.att_norm_weight_ = RMSNormWeight( - dim=hidden_size, - weight_name=self._att_norm_weight_name, - data_type=self.data_type_, - ) - self._init_gdn_weight() - self.ffn_norm_weight_ = RMSNormWeight( - dim=hidden_size, - weight_name=self._ffn_norm_weight_name, - data_type=self.data_type_, - ) - if self.is_moe: - self._init_moe() - else: - self._init_ffn() - self._init_gate_shared_expert_weight() - - def _init_ffn(self): - # GatedDeltaNet architecture uses _init_gate_shared_expert_weight() for FFN-like component - # No standard MLP FFN weights needed for this architecture - pass - def _init_gdn_weight(self): prefix = f"model.layers.{self.layer_num_}.linear_attn" hidden_size = self.network_config_["hidden_size"] @@ -185,8 +136,6 @@ def _init_gdn_weight(self): kernel_size = self.network_config_.get("linear_conv_kernel_dim", 4) # Conv1d weight: after _preprocess_weight, shape is [channels, kernel_size]. - # ROWMMWeight row-slices out_dims (rows), matching TP split of channels dim. - # causal_conv1d_fn expects weight shape (dim, width) = (channels_per_tp, kernel_size). self.linear_conv1d = ROWMMWeight( in_dim=kernel_size, out_dims=[conv1d_channels], @@ -242,10 +191,6 @@ def _init_gdn_weight(self): data_type=self.data_type_, ) - def load_hf_weights(self, weights): - self._preprocess_weight(weights) - return super().load_hf_weights(weights) - def _preprocess_weight(self, weights): linear_conv1d_weight_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.weight" linear_conv1d_bias_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.bias" @@ -263,18 +208,6 @@ def _rearrange_gdn_in_proj_weights(self, weights): """Rearrange in_proj_qkvz and in_proj_ba weight rows from interleaved per-k-head layout to TP-aware grouped layout so that after ROWMMWeight's row-slicing, each rank's MM output is already [q_chunk, k_chunk, v_chunk, z_chunk, b_chunk, a_chunk]. - - This eliminates the expensive split+reshape+cat in _fix_query_key_value_ba_ordering - at inference time, replacing it with simple slicing. - - The key challenge is that ROWMMWeight slices each weight as a contiguous row chunk - (rows [start:end]). 
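The rearrangement sketched below mirrors the split/chunk/reassemble pattern that _parse_linear_conv1d applies to the conv1d channels; rearrange_rows_for_tp is an illustrative helper name, and the component row sizes are assumptions supplied by the caller:

    import torch

    def rearrange_rows_for_tp(weight: torch.Tensor, component_dims, tp_world_size: int) -> torch.Tensor:
        # Split the rows into per-component blocks (e.g. q, k, v), chunk each
        # block across TP ranks, then reassemble as [q_tp0, k_tp0, v_tp0,
        # q_tp1, ...] so a plain contiguous row slice per rank already yields
        # the grouped [q_chunk, k_chunk, v_chunk] layout.
        parts = torch.split(weight, component_dims, dim=0)
        chunks = [p.chunk(tp_world_size, dim=0) for p in parts]
        return torch.cat(
            [torch.cat([c[i] for c in chunks], dim=0) for i in range(tp_world_size)],
            dim=0,
        )
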
So we arrange the rows such that each TP chunk contains - the grouped layout for that rank: - 1. Deinterleave from per-k-head groups into per-component tensors - 2. Chunk each component by TP - 3. Reassemble as [q_tp0, k_tp0, v_tp0, z_tp0, q_tp1, k_tp1, ...] so row-slicing - gives each rank [q_chunk, k_chunk, v_chunk, z_chunk]. - Same pattern as _parse_linear_conv1d uses for conv1d weights. """ num_k = self.linear_num_k_heads k_dim = self.linear_k_head_dim @@ -324,64 +257,17 @@ def _rearrange_gdn_in_proj_weights(self, weights): def _parse_linear_conv1d(self, weight): qk_dim = self.linear_num_k_heads * self.linear_k_head_dim v_dim = self.linear_num_v_heads * self.linear_v_head_dim - q_bias, k_bias, v_bias = torch.split(weight, [qk_dim, qk_dim, v_dim], dim=0) - q_splits = q_bias.chunk(self.tp_world_size_, dim=0) - k_splits = k_bias.chunk(self.tp_world_size_, dim=0) - v_splits = v_bias.chunk(self.tp_world_size_, dim=0) + q, k, v = torch.split(weight, [qk_dim, qk_dim, v_dim], dim=0) + q_splits = q.chunk(self.tp_world_size_, dim=0) + k_splits = k.chunk(self.tp_world_size_, dim=0) + v_splits = v.chunk(self.tp_world_size_, dim=0) new_weight = torch.cat( [torch.cat([q_splits[i], k_splits[i], v_splits[i]], dim=0) for i in range(self.tp_world_size_)], dim=0 ) return new_weight - def _init_gate_shared_expert_weight(self): - hidden_size = self.network_config_["hidden_size"] - - # Check if this is a MoE model with shared_expert or a dense model - if "shared_expert_intermediate_size" in self.network_config_: - # MoE model with shared expert - prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" - inter_size = self.network_config_["shared_expert_intermediate_size"] - self.shared_expert_gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[inter_size, inter_size], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_gate_up_proj"), - ) - self.shared_expert_down_proj = COLMMWeight( - in_dim=inter_size, - out_dims=[hidden_size], - weight_names=f"{prefix}.down_proj.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_down_proj"), - ) - self.shared_expert_gate = ROWMMWeight( - in_dim=hidden_size, - out_dims=[1], - weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", - data_type=self.data_type_, - bias_names=None, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) - else: - # Dense model with standard MLP - prefix = f"model.layers.{self.layer_num_}.mlp" - inter_size = self.network_config_["intermediate_size"] - self.shared_expert_gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[inter_size, inter_size], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_gate_up_proj"), - ) - self.shared_expert_down_proj = COLMMWeight( - in_dim=inter_size, - out_dims=[hidden_size], - weight_names=f"{prefix}.down_proj.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_down_proj"), - ) - # No shared_expert_gate for dense models - self.shared_expert_gate = None + def load_hf_weights(self, weights): + self._split_q_with_gate(weights) + if self.is_linear_attention: + self._preprocess_weight(weights) + super().load_hf_weights(weights) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index b3f0f53cac..add3f06b9a 100644 --- a/lightllm/models/qwen3next/model.py +++ 
b/lightllm/models/qwen3next/model.py @@ -4,8 +4,7 @@ from lightllm.models.registry import ModelRegistry from lightllm.models.qwen3_moe.model import Qwen3MOEModel from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( - Qwen3NextFullAttentionTransformerLayerWeight, - Qwen3NextGatedDeltaNetTransformerLayerWeight, + Qwen3NextTransformerLayerWeight, ) from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( Qwen3NextFullAttentionTransformerLayerInfer, @@ -28,10 +27,11 @@ @ModelRegistry("qwen3_next") class Qwen3NextTpPartModel(Qwen3MOEModel): + transformer_weight_class = Qwen3NextTransformerLayerWeight + post_layer_infer_class = Qwen3NextPostLayerInfer infer_state_class = Qwen3NextInferStateInfo - is_hybrid_attention = True # Indicates model uses hybrid (full + linear) attention use_buffer_manager = True # Indicates model needs per-request buffer management for linear attention states radix_cache_class = HybridRadixCache @@ -195,28 +195,6 @@ def _init_req_manager(self): self.req_manager = ReqManagerForMamba(self.max_req_num, create_max_seq_len, self.mem_manager) - def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) - num_full_attention_layers = self.config["full_attention_interval"] - self.trans_layers_weight = [ - ( - Qwen3NextFullAttentionTransformerLayerWeight( - i, - self.data_type, - network_config=self.config, - quant_cfg=self.quant_cfg, - ) - if (i + 1) % num_full_attention_layers == 0 - else Qwen3NextGatedDeltaNetTransformerLayerWeight( - i, - self.data_type, - network_config=self.config, - quant_cfg=self.quant_cfg, - ) - ) - for i in range(self.config["n_layer"]) - ] - def _init_infer_layer(self): self.pre_infer = self.pre_layer_infer_class(network_config=self.config) self.post_infer = self.post_layer_infer_class(network_config=self.config) From 86f17b69c87b6643893cbc58c95dc2e5b0d1f597 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 9 Mar 2026 15:33:05 +0000 Subject: [PATCH 113/180] clean code --- .../layer_weights/transformer_layer_weight.py | 11 + .../qwen3next/layer_infer/post_layer_infer.py | 12 - .../layer_infer/transformer_layer_infer.py | 74 ++-- .../pre_and_post_layer_weight.py | 29 ++ lightllm/models/qwen3next/model.py | 5 +- .../triton_kernel/fused_split_copy.py | 400 ------------------ .../qwen3next/triton_kernel/gemma_rmsnorm.py | 141 ------ .../layer_infer/post_layer_infer.py | 16 - .../layer_infer/pre_layer_infer.py | 11 +- 9 files changed, 73 insertions(+), 626 deletions(-) delete mode 100644 lightllm/models/qwen3next/layer_infer/post_layer_infer.py create mode 100644 lightllm/models/qwen3next/layer_weights/pre_and_post_layer_weight.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fused_split_copy.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py delete mode 100644 lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py diff --git a/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py index 9f91f3db8b..da93133444 100644 --- a/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py @@ -10,6 +10,17 @@ class Qwen35TransformerLayerWeight(Qwen3NextTransformerLayerWeight): + def _init_weight_names(self): + super()._init_weight_names() + self._gate_weight_name = f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" + 
self._gate_bias_name = None + self._up_weight_name = f"model.layers.{self.layer_num_}.mlp.up_proj.weight" + self._up_bias_name = None + self._gate_up_weight_name = f"model.layers.{self.layer_num_}.mlp.gate_up_proj.weight" + self._gate_up_bias_name = None + self._down_weight_name = f"model.layers.{self.layer_num_}.mlp.down_proj.weight" + self._down_bias_name = None + def _init_gdn_weight(self): # Initialize everything from parent first, then override only linear_in_proj. super()._init_gdn_weight() diff --git a/lightllm/models/qwen3next/layer_infer/post_layer_infer.py b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py deleted file mode 100644 index 9dcab4e6fc..0000000000 --- a/lightllm/models/qwen3next/layer_infer/post_layer_infer.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch - -from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward - - -class Qwen3NextPostLayerInfer(LlamaPostLayerInfer): - def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: - out = self.alloc_tensor(input.shape, input.dtype) - gemma_rmsnorm_forward(input, layer_weight.final_norm_weight_.weight, self.eps_, out=out) - return out diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 5732dc41e3..cc3b1fe370 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -402,48 +402,35 @@ def _fix_query_key_value_ba_ordering(self, mixed_qkvzba, is_decode=False): z_end = qkv_dim + self.tp_value_dim b_end = z_end + self.tp_num_v_heads - if is_decode: - mixed_qkv = mixed_qkvzba[:, :qkv_dim].contiguous() - z = mixed_qkvzba[:, qkv_dim:z_end].contiguous().view(-1, self.tp_num_v_heads, self.head_v_dim) - b = mixed_qkvzba[:, z_end:b_end].contiguous() - a = mixed_qkvzba[:, b_end:].contiguous() - else: - mixed_qkv = mixed_qkvzba[:, :qkv_dim] - # .reshape() handles non-contiguous slices by copying when needed (unlike .view()) - z = mixed_qkvzba[:, qkv_dim:z_end].reshape(-1, self.tp_num_v_heads, self.head_v_dim) - # b and a must be contiguous: fused_gdn_gating_kernel uses raw pointer arithmetic - # (off = i_b * NUM_HEADS + head_off) that assumes contiguous layout. - # Non-contiguous slices have stride[0]=total_dim, causing wrong reads for i_b > 0. 
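Because the load-time rearrangement leaves the in-projection output grouped as [q | k | v | z | b | a] along the last dimension, the simplified split reduces to plain column slicing. A standalone sketch under that assumption (split_qkvzba and its dimension arguments are illustrative; value_dim must equal num_v_heads * head_v_dim for the view to hold):

    import torch

    def split_qkvzba(mixed: torch.Tensor, key_dim: int, value_dim: int,
                     num_v_heads: int, head_v_dim: int):
        # Grouped layout along dim -1: [q | k | v | z | b | a].
        qkv_dim = key_dim * 2 + value_dim
        z_end = qkv_dim + value_dim
        b_end = z_end + num_v_heads
        mixed_qkv = mixed[:, :qkv_dim]
        # Splitting the last dim of a column slice into (heads, head_dim) is a
        # valid view because that dim has stride 1, even when the slice itself
        # is non-contiguous across rows.
        z = mixed[:, qkv_dim:z_end].view(-1, num_v_heads, head_v_dim)
        b = mixed[:, z_end:b_end]
        a = mixed[:, b_end:]
        return mixed_qkv, z, b, a
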
- b = mixed_qkvzba[:, z_end:b_end].contiguous() - a = mixed_qkvzba[:, b_end:].contiguous() + mixed_qkv = mixed_qkvzba[:, :qkv_dim] + z = mixed_qkvzba[:, qkv_dim:z_end].view(-1, self.tp_num_v_heads, self.head_v_dim) + b = mixed_qkvzba[:, z_end:b_end] + a = mixed_qkvzba[:, b_end:] + return mixed_qkv, z, b, a + + def _split_qkvzba(self, mixed_qkvzba: torch.Tensor): + + qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim + z_end = qkv_dim + self.tp_value_dim + b_end = z_end + self.tp_num_v_heads + mixed_qkv = mixed_qkvzba[:, :qkv_dim] + z = mixed_qkvzba[:, qkv_dim:z_end].view(-1, self.tp_num_v_heads, self.head_v_dim) + b = mixed_qkvzba[:, z_end:b_end] + a = mixed_qkvzba[:, b_end:] return mixed_qkv, z, b, a - def _rearrange_mixed_qkv(self, mixed_qkv, decode=False): - if mixed_qkv is None: - return None, None, None - if decode: - query, key, value = torch.split( - mixed_qkv, - [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], - dim=-1, - ) - batch_size = mixed_qkv.shape[0] - query = query.contiguous().view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) - key = key.contiguous().view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) - value = value.contiguous().view(batch_size, 1, self.tp_num_v_heads, self.head_v_dim) - return query, key, value - else: - query, key, value = torch.split( - mixed_qkv, - [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], - dim=-1, - ) - seq_len = query.shape[0] - query = query.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) - key = key.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) - value = value.view(1, seq_len, self.tp_num_v_heads, self.head_v_dim) - return query, key, value + def _split_qkv(self, mixed_qkv: torch.Tensor): + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + seq_len = query.shape[0] + query = query.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) + key = key.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) + value = value.view(1, seq_len, self.tp_num_v_heads, self.head_v_dim) + return query, key, value def context_attention_forward( self, @@ -489,7 +476,7 @@ def _gdn_prefill_kernel( mixed_qkv = out_tensor.transpose(0, 1) # Recurrent processing - query, key, value = self._rearrange_mixed_qkv(mixed_qkv) + query, key, value = self._split_qkv(mixed_qkv) initial_state = ssm_states[infer_state.b_buffer_idx] # g and beta have shape (total_tokens, num_heads), need to unsqueeze to get (1, total_tokens, num_heads) core_attn_out, last_recurrent_state = chunk_gated_delta_rule( @@ -523,8 +510,6 @@ def _gdn_decode_kernel( ): """Decode kernel for GDN forward pass (single-token, non-MTP mode). 
Uses fused gating: g/beta computed inline in the recurrent kernel.""" - # Conv1D processing — mixed_qkv is pre-copied to contiguous buffer - # by _fix_query_key_value_ba_ordering (causal_conv1d_update requires contiguous input) mixed_qkv = causal_conv1d_update( mixed_qkv, conv_states, @@ -536,7 +521,7 @@ def _gdn_decode_kernel( # Recurrent processing with fused gating # FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally - query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) + query, key, value = self._split_qkv(mixed_qkv) core_attn_out, _ = fused_recurrent_gated_delta_rule( q=query, k=key, @@ -653,8 +638,7 @@ def gdn_forward( conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_) mixed_qkvzba = layer_weight.linear_in_proj.mm(input) - # mixed_qkv is now returned pre-concatenated (no torch.cat needed) - mixed_qkv, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba, is_decode=not is_prefill) + mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) # Dispatch to appropriate kernel if is_prefill: diff --git a/lightllm/models/qwen3next/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3next/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..daaf146907 --- /dev/null +++ b/lightllm/models/qwen3next/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,29 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, GEMMANormWeight + + +class Qwen3NextPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] + self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + ) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_ if tie_word_embeddings else None, + ) + self.final_norm_weight_ = GEMMANormWeight( + dim=hidden_size, + weight_name="model.norm.weight", + data_type=self.data_type_, + ) + return diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index add3f06b9a..24069b800d 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -6,11 +6,11 @@ from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( Qwen3NextTransformerLayerWeight, ) +from lightllm.models.qwen3next.layer_weights.pre_and_post_layer_weight import Qwen3NextPreAndPostLayerWeight from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( Qwen3NextFullAttentionTransformerLayerInfer, Qwen3NextGatedDeltaNetTransformerLayerInfer, ) -from lightllm.models.qwen3next.layer_infer.post_layer_infer import Qwen3NextPostLayerInfer from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -18,7 +18,6 @@ from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.common.req_manager import ReqManagerForMamba -from 
lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache logger = init_logger(__name__) @@ -27,9 +26,9 @@ @ModelRegistry("qwen3_next") class Qwen3NextTpPartModel(Qwen3MOEModel): + pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight transformer_weight_class = Qwen3NextTransformerLayerWeight - post_layer_infer_class = Qwen3NextPostLayerInfer infer_state_class = Qwen3NextInferStateInfo use_buffer_manager = True # Indicates model needs per-request buffer management for linear attention states diff --git a/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py b/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py deleted file mode 100644 index 5f4433fb34..0000000000 --- a/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py +++ /dev/null @@ -1,400 +0,0 @@ -""" -Fused Split-Copy Triton Kernels for GDN Decode Path - -Replaces multiple separate .copy_() calls with single kernel launches to reduce -kernel launch overhead in the decode hot path (36 GDN layers per step). - -Kernel 1 (fused_split_copy_qkvzba): 4 copies → 1 kernel - Splits GEMM output [batch, total_dim] into qkv, z, b, a destination buffers. - -Kernel 2 (fused_split_copy_qkv): 3 copies → 1 kernel - Splits conv1d output [batch, qkv_dim] into q, k, v destination buffers. - Handles non-contiguous source (stride(0) != total_dim from column slicing). -""" - -import torch -import triton -import triton.language as tl - - -# ============================================================================= -# Kernel 1: Fused split-copy for qkv, z, b, a from GEMM output -# ============================================================================= - - -@triton.jit -def _fused_split_copy_qkvzba_kernel( - # Source pointer (contiguous GEMM output) - src_ptr, - # Destination pointers (pre-allocated contiguous buffers) - dst_qkv_ptr, - dst_z_ptr, - dst_b_ptr, - dst_a_ptr, - # Row strides - src_stride0, - dst_qkv_stride0, - dst_z_stride0, - dst_b_stride0, - dst_a_stride0, - # Segment boundaries (cumulative): [0, qkv_dim) [qkv_dim, z_end) [z_end, b_end) [b_end, total_dim) - qkv_dim, - z_end, - b_end, - total_dim, - # Block size - BLOCK_N: tl.constexpr, -): - """ - One program per (row, column_block). Loads a BLOCK_N chunk from the source row, - then conditionally stores to the correct destination based on column position. 
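The baseline this kernel fused away is four separate strided copies, one launch each — the same plain-slicing path that later patches in this series revert to. A sketch of that eager baseline, assuming pre-allocated destination buffers of matching widths (split_copy_eager is an illustrative name):

    import torch

    def split_copy_eager(src, dst_qkv, dst_z, dst_b, dst_a, qkv_dim, z_end, b_end):
        # Four copies == four kernel launches; the Triton kernel above collapses
        # them into one launch over a (batch, cdiv(total_dim, BLOCK_N)) grid,
        # loading each tile once and masking the store by column range.
        dst_qkv.copy_(src[:, :qkv_dim])
        dst_z.copy_(src[:, qkv_dim:z_end])
        dst_b.copy_(src[:, z_end:b_end])
        dst_a.copy_(src[:, b_end:])
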
- - Grid: (batch, cdiv(total_dim, BLOCK_N)) - """ - row = tl.program_id(0) - col_block = tl.program_id(1) - - col_start = col_block * BLOCK_N - cols = col_start + tl.arange(0, BLOCK_N) - mask = cols < total_dim - - # Load source chunk - data = tl.load(src_ptr + row * src_stride0 + cols, mask=mask) - - # Store to qkv destination: columns [0, qkv_dim) - qkv_mask = mask & (cols < qkv_dim) - tl.store(dst_qkv_ptr + row * dst_qkv_stride0 + cols, data, mask=qkv_mask) - - # Store to z destination: columns [qkv_dim, z_end) - z_mask = mask & (cols >= qkv_dim) & (cols < z_end) - tl.store(dst_z_ptr + row * dst_z_stride0 + (cols - qkv_dim), data, mask=z_mask) - - # Store to b destination: columns [z_end, b_end) - b_mask = mask & (cols >= z_end) & (cols < b_end) - tl.store(dst_b_ptr + row * dst_b_stride0 + (cols - z_end), data, mask=b_mask) - - # Store to a destination: columns [b_end, total_dim) - a_mask = mask & (cols >= b_end) - tl.store(dst_a_ptr + row * dst_a_stride0 + (cols - b_end), data, mask=a_mask) - - -def fused_split_copy_qkvzba( - src: torch.Tensor, - dst_qkv: torch.Tensor, - dst_z: torch.Tensor, - dst_b: torch.Tensor, - dst_a: torch.Tensor, - qkv_dim: int, - z_dim: int, - b_dim: int, - a_dim: int, -): - """ - Fused split-copy from GEMM output into 4 contiguous destination buffers. - - Replaces: - conv_buf.copy_(mixed_qkvzba[:, :qkv_dim]) - z_buf.view(batch, -1).copy_(mixed_qkvzba[:, qkv_dim:z_end]) - b_buf.copy_(mixed_qkvzba[:, z_end:b_end]) - a_buf.copy_(mixed_qkvzba[:, b_end:]) - - Args: - src: [batch, total_dim] contiguous source (GEMM output) - dst_qkv: [batch, qkv_dim] contiguous destination for conv1d input - dst_z: [batch, z_dim] contiguous destination (z_buf viewed flat) - dst_b: [batch, b_dim] contiguous destination - dst_a: [batch, a_dim] contiguous destination - qkv_dim: width of qkv segment (tp_key_dim * 2 + tp_value_dim) - z_dim: width of z segment (tp_value_dim) - b_dim: width of b segment (tp_num_v_heads) - a_dim: width of a segment (tp_num_v_heads) - """ - total_dim = qkv_dim + z_dim + b_dim + a_dim - z_end = qkv_dim + z_dim - b_end = z_end + b_dim - - batch = src.shape[0] - BLOCK_N = 128 - num_col_blocks = triton.cdiv(total_dim, BLOCK_N) - - grid = (batch, num_col_blocks) - - _fused_split_copy_qkvzba_kernel[grid]( - src, - dst_qkv, - dst_z, - dst_b, - dst_a, - src.stride(0), - dst_qkv.stride(0), - dst_z.stride(0), - dst_b.stride(0), - dst_a.stride(0), - qkv_dim, - z_end, - b_end, - total_dim, - BLOCK_N=BLOCK_N, - num_warps=4, - ) - - -# ============================================================================= -# Kernel 2: Fused split-copy for q, k, v from conv1d output -# ============================================================================= - - -@triton.jit -def _fused_split_copy_qkv_kernel( - # Source pointer (may be non-contiguous column slice) - src_ptr, - # Destination pointers (contiguous buffers) - dst_q_ptr, - dst_k_ptr, - dst_v_ptr, - # Row strides - src_stride0, - dst_q_stride0, - dst_k_stride0, - dst_v_stride0, - # Segment boundaries: [0, q_dim) [q_dim, qk_end) [qk_end, total_dim) - q_dim, - qk_end, - total_dim, - # Block size - BLOCK_N: tl.constexpr, -): - """ - One program per (row, column_block). Loads a BLOCK_N chunk from the source row, - then conditionally stores to q, k, or v destination. - - Supports non-contiguous source via src_stride0 (stride may be > total_dim - when source is a column slice of a larger tensor). 
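A quick illustration of the non-contiguity that src_stride0 accounts for — the same property that later motivates passing explicit row strides (a.stride(0), b.stride(0)) into fused_gdn_gating: a column slice keeps its parent's row stride, so pointer arithmetic that assumes stride(0) equals the slice width reads the wrong rows:

    import torch

    x = torch.arange(12).reshape(3, 4)
    s = x[:, :2]                  # column slice of a wider tensor
    assert s.shape == (3, 2)
    assert s.stride(0) == 4       # row stride is the parent row width, not 2
    assert not s.is_contiguous()  # naive "row * slice_width" indexing is wrong
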
- - Grid: (batch, cdiv(total_dim, BLOCK_N)) - """ - row = tl.program_id(0) - col_block = tl.program_id(1) - - col_start = col_block * BLOCK_N - cols = col_start + tl.arange(0, BLOCK_N) - mask = cols < total_dim - - # Load source chunk (use src_stride0 for row advancement) - data = tl.load(src_ptr + row * src_stride0 + cols, mask=mask) - - # Store to q destination: columns [0, q_dim) - q_mask = mask & (cols < q_dim) - tl.store(dst_q_ptr + row * dst_q_stride0 + cols, data, mask=q_mask) - - # Store to k destination: columns [q_dim, qk_end) - k_mask = mask & (cols >= q_dim) & (cols < qk_end) - tl.store(dst_k_ptr + row * dst_k_stride0 + (cols - q_dim), data, mask=k_mask) - - # Store to v destination: columns [qk_end, total_dim) - v_mask = mask & (cols >= qk_end) - tl.store(dst_v_ptr + row * dst_v_stride0 + (cols - qk_end), data, mask=v_mask) - - -def fused_split_copy_qkv( - src: torch.Tensor, - dst_q: torch.Tensor, - dst_k: torch.Tensor, - dst_v: torch.Tensor, - q_dim: int, - k_dim: int, - v_dim: int, - src_stride0: int, -): - """ - Fused split-copy from conv1d output into 3 contiguous q/k/v buffers. - - Replaces: - q_split, k_split, v_split = torch.split(mixed_qkv, [...], dim=-1) - q_buf.view(batch, -1).copy_(q_split) - k_buf.view(batch, -1).copy_(k_split) - v_buf.view(batch, -1).copy_(v_split) - - Args: - src: [batch, total_dim] source tensor (may be non-contiguous if column slice) - dst_q: [batch, q_dim] contiguous destination - dst_k: [batch, k_dim] contiguous destination - dst_v: [batch, v_dim] contiguous destination - q_dim: width of q segment (tp_key_dim) - k_dim: width of k segment (tp_key_dim) - v_dim: width of v segment (tp_value_dim) - src_stride0: row stride of source (may be > q_dim+k_dim+v_dim) - """ - total_dim = q_dim + k_dim + v_dim - qk_end = q_dim + k_dim - - batch = src.shape[0] - BLOCK_N = 128 - num_col_blocks = triton.cdiv(total_dim, BLOCK_N) - - grid = (batch, num_col_blocks) - - _fused_split_copy_qkv_kernel[grid]( - src, - dst_q, - dst_k, - dst_v, - src_stride0, - dst_q.stride(0), - dst_k.stride(0), - dst_v.stride(0), - q_dim, - qk_end, - total_dim, - BLOCK_N=BLOCK_N, - num_warps=4, - ) - - -# ============================================================================= -# Test / Verification -# ============================================================================= - - -def test_fused_split_copy(): - """Verify fused kernels produce identical results to separate .copy_() calls.""" - torch.manual_seed(42) - device = "cuda" - dtype = torch.bfloat16 - - print("=" * 60) - print("Testing fused_split_copy_qkvzba") - print("=" * 60) - - # Typical dimensions for Qwen3-Coder-Next with TP=4 - # tp_key_dim=128, tp_value_dim=256, tp_num_v_heads=2 - qkv_dim = 128 + 128 + 256 # q + k + v = 512 - z_dim = 256 - b_dim = 2 - a_dim = 2 - total_dim = qkv_dim + z_dim + b_dim + a_dim # 772 - - for batch in [1, 4, 8, 32]: - src = torch.randn(batch, total_dim, dtype=dtype, device=device) - - # Reference: separate copies - ref_qkv = src[:, :qkv_dim].clone() - ref_z = src[:, qkv_dim : qkv_dim + z_dim].clone() - ref_b = src[:, qkv_dim + z_dim : qkv_dim + z_dim + b_dim].clone() - ref_a = src[:, qkv_dim + z_dim + b_dim :].clone() - - # Fused kernel - dst_qkv = torch.empty(batch, qkv_dim, dtype=dtype, device=device) - dst_z = torch.empty(batch, z_dim, dtype=dtype, device=device) - dst_b = torch.empty(batch, b_dim, dtype=dtype, device=device) - dst_a = torch.empty(batch, a_dim, dtype=dtype, device=device) - fused_split_copy_qkvzba(src, dst_qkv, dst_z, dst_b, dst_a, qkv_dim, z_dim, b_dim, a_dim) 
- - assert torch.equal(dst_qkv, ref_qkv), f"qkv mismatch at batch={batch}" - assert torch.equal(dst_z, ref_z), f"z mismatch at batch={batch}" - assert torch.equal(dst_b, ref_b), f"b mismatch at batch={batch}" - assert torch.equal(dst_a, ref_a), f"a mismatch at batch={batch}" - print(f" batch={batch:3d}: PASS") - - print() - print("=" * 60) - print("Testing fused_split_copy_qkv") - print("=" * 60) - - q_dim = 128 - k_dim = 128 - v_dim = 256 - qkv_dim = q_dim + k_dim + v_dim # 512 - - for batch in [1, 4, 8, 32]: - # Test with contiguous source - src = torch.randn(batch, qkv_dim, dtype=dtype, device=device) - - ref_q = src[:, :q_dim].clone() - ref_k = src[:, q_dim : q_dim + k_dim].clone() - ref_v = src[:, q_dim + k_dim :].clone() - - dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) - dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) - dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) - fused_split_copy_qkv(src, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src.stride(0)) - - assert torch.equal(dst_q, ref_q), f"q mismatch at batch={batch} (contiguous)" - assert torch.equal(dst_k, ref_k), f"k mismatch at batch={batch} (contiguous)" - assert torch.equal(dst_v, ref_v), f"v mismatch at batch={batch} (contiguous)" - print(f" batch={batch:3d} (contiguous src): PASS") - - # Test with non-contiguous source (column slice of wider tensor) - wider = torch.randn(batch, qkv_dim + 64, dtype=dtype, device=device) - src_nc = wider[:, :qkv_dim] # Non-contiguous: stride(0) = qkv_dim + 64 - assert src_nc.stride(0) == qkv_dim + 64, "expected non-contiguous slice" - - ref_q = src_nc[:, :q_dim].clone() - ref_k = src_nc[:, q_dim : q_dim + k_dim].clone() - ref_v = src_nc[:, q_dim + k_dim :].clone() - - dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) - dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) - dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) - fused_split_copy_qkv(src_nc, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src_nc.stride(0)) - - assert torch.equal(dst_q, ref_q), f"q mismatch at batch={batch} (non-contiguous)" - assert torch.equal(dst_k, ref_k), f"k mismatch at batch={batch} (non-contiguous)" - assert torch.equal(dst_v, ref_v), f"v mismatch at batch={batch} (non-contiguous)" - print(f" batch={batch:3d} (non-contiguous src): PASS") - - print() - print("=" * 60) - print("Testing edge cases") - print("=" * 60) - - # Edge case: different dimension ratios (small q/k, large v) - q_dim, k_dim, v_dim = 32, 32, 512 - qkv_dim = q_dim + k_dim + v_dim - batch = 2 - src = torch.randn(batch, qkv_dim, dtype=dtype, device=device) - - dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) - dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) - dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) - fused_split_copy_qkv(src, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src.stride(0)) - - assert torch.equal(dst_q, src[:, :q_dim]) - assert torch.equal(dst_k, src[:, q_dim : q_dim + k_dim]) - assert torch.equal(dst_v, src[:, q_dim + k_dim :]) - print(" asymmetric dims (32, 32, 512): PASS") - - # Edge case: float32 dtype - src_f32 = torch.randn(4, 772, dtype=torch.float32, device=device) - dst_qkv = torch.empty(4, 512, dtype=torch.float32, device=device) - dst_z = torch.empty(4, 256, dtype=torch.float32, device=device) - dst_b = torch.empty(4, 2, dtype=torch.float32, device=device) - dst_a = torch.empty(4, 2, dtype=torch.float32, device=device) - fused_split_copy_qkvzba(src_f32, dst_qkv, dst_z, dst_b, dst_a, 512, 256, 2, 2) - - 
assert torch.equal(dst_qkv, src_f32[:, :512]) - assert torch.equal(dst_z, src_f32[:, 512:768]) - assert torch.equal(dst_b, src_f32[:, 768:770]) - assert torch.equal(dst_a, src_f32[:, 770:]) - print(" float32 dtype: PASS") - - # Edge case: float16 dtype - src_f16 = torch.randn(4, 772, dtype=torch.float16, device=device) - dst_qkv = torch.empty(4, 512, dtype=torch.float16, device=device) - dst_z = torch.empty(4, 256, dtype=torch.float16, device=device) - dst_b = torch.empty(4, 2, dtype=torch.float16, device=device) - dst_a = torch.empty(4, 2, dtype=torch.float16, device=device) - fused_split_copy_qkvzba(src_f16, dst_qkv, dst_z, dst_b, dst_a, 512, 256, 2, 2) - - assert torch.equal(dst_qkv, src_f16[:, :512]) - assert torch.equal(dst_z, src_f16[:, 512:768]) - assert torch.equal(dst_b, src_f16[:, 768:770]) - assert torch.equal(dst_a, src_f16[:, 770:]) - print(" float16 dtype: PASS") - - print() - print("All tests passed!") - - -if __name__ == "__main__": - test_fused_split_copy() diff --git a/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py deleted file mode 100644 index 0a2b4bd662..0000000000 --- a/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py +++ /dev/null @@ -1,141 +0,0 @@ -import torch - -import triton -import triton.language as tl - -from lightllm.common.triton_utils.autotuner import autotune - - -@triton.jit -def _gemma_rmsnorm_fwd_kernel( - x_ptr, - w_ptr, - y_ptr, - x_stride0, - x_stride1, - y_stride0, - y_stride1, - N: tl.constexpr, - EPS: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - row = tl.program_id(0) - x_ptr = x_ptr + row * x_stride0 - y_ptr = y_ptr + row * y_stride0 - - _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(x_ptr + cols * x_stride1, mask=cols < N, other=0.0).to(tl.float32) - _sum += x * x - - var = tl.sum(_sum, axis=0) / N - rstd = 1 / tl.sqrt(var + EPS) - # Normalize and apply linear transformation - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(w_ptr + cols, mask=mask).to(tl.float32) - x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) - x_hat = x * rstd - w = w + 1.0 - y = x_hat * w - # Write output - tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask) - - -def _get_gemma_rmsnorm_configs(): - """Generate configurations for autotuning gemma RMSNorm kernel.""" - configs = [] - for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]: - for num_warps in [1, 2, 4, 8]: - # num_stages has minimal impact on this simple kernel, use 1 - configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": 1}) - return configs - - -def _get_gemma_rmsnorm_static_key(x: torch.Tensor, w: torch.Tensor): - """Generate static key for caching autotuned configurations.""" - N = x.shape[-1] - return { - "x_dtype": str(x.dtype), - "weight_dtype": str(w.dtype), - "N": N, - } - - -@autotune( - kernel_name="gemma_rmsnorm_forward:v1", - configs_gen_func=_get_gemma_rmsnorm_configs, - static_key_func=_get_gemma_rmsnorm_static_key, - run_key_func=lambda x: x.shape[-1], -) -def gemma_rmsnorm_forward(x, w, eps, out=None, run_config: dict = None): - # Inplace gemma RMS Norm - # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - N = x.shape[-1] - y = torch.empty_like(x) if out is None else out - 
x_arg = x.view(-1, N) - y_arg = y.view(-1, N) - - M, _ = x_arg.shape - - # Default heuristic when autotune is disabled or no config provided - if not run_config: - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_SIZE: - raise RuntimeError("This gemma rmsnorm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1} - - BLOCK_SIZE = run_config["BLOCK_SIZE"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - _gemma_rmsnorm_fwd_kernel[(M,)]( - x_arg, - w, - y_arg, - x_stride0=x.stride(0), - x_stride1=x.stride(1), - y_stride0=y.stride(0), - y_stride1=y.stride(1), - N=N, - EPS=eps, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - num_stages=num_stages, - ) - - return y - - -def _gemma_rmsnorm_fwd_torch(x, weight, eps): - original_dtype = x.dtype - x = x.to(torch.float32) - x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - x = x * (1.0 + weight.float()) - return x.to(original_dtype) - - -def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"): - # create data - x_shape = (M, N) - w_shape = (x_shape[-1],) - weight = torch.rand(w_shape, dtype=dtype, device="cuda") - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - # forward pass - y_tri = gemma_rmsnorm_forward(x, weight, eps) - y_ref = _gemma_rmsnorm_fwd_torch(x, weight, eps) - - # compare - print("type:", y_tri.dtype, y_ref.dtype) - print("max delta:", torch.max(torch.abs(y_tri - y_ref))) - # Use appropriate tolerance based on dtype - atol = 1e-2 if dtype == torch.float32 else 5e-2 - assert torch.allclose(y_tri, y_ref, atol=atol, rtol=0) - return diff --git a/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py deleted file mode 100644 index 2918fca79c..0000000000 --- a/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch -from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight -from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward - - -class Qwen3NextMTPPostLayerInfer(LlamaPostLayerInfer): - """ - Qwen3Next MTP Post Layer Inference. - Uses gemma_rmsnorm for normalization (same as Qwen3Next). 
- """ - - def _norm(self, input, infer_state, layer_weight: Qwen3NextMTPPreAndPostLayerWeight) -> torch.Tensor: - out = self.alloc_tensor(input.shape, input.dtype) - gemma_rmsnorm_forward(input, layer_weight.final_norm_weight_.weight, self.eps_, out=out) - return out diff --git a/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py index 4fc207648c..ef3fe38153 100644 --- a/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py @@ -3,7 +3,6 @@ from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward class Qwen3NextMTPPreLayerInfer(LlamaPreLayerInfer): @@ -33,16 +32,10 @@ def _mtp_forward( assert input_embdings.shape[0] == tgt_embdings.shape[0] # Normalize embedding - input_embdings_normed = self.alloc_tensor(input_embdings.shape, input_embdings.dtype) - gemma_rmsnorm_forward( - input_embdings, layer_weight.pre_fc_norm_embedding_weight_.weight, self.eps_, out=input_embdings_normed - ) + input_embdings_normed = layer_weight.pre_fc_norm_embedding_weight_(input=input_embdings, eps=self.eps_) # Normalize hidden state - tgt_embdings_normed = self.alloc_tensor(tgt_embdings.shape, tgt_embdings.dtype) - gemma_rmsnorm_forward( - tgt_embdings, layer_weight.pre_fc_norm_hidden_weight_.weight, self.eps_, out=tgt_embdings_normed - ) + tgt_embdings_normed = layer_weight.pre_fc_norm_hidden_weight_(input=tgt_embdings, eps=self.eps_) # Concat normalized embedding and hidden cat_embdings = torch.cat((input_embdings_normed, tgt_embdings_normed), dim=-1) From a1849e61a3709519231a16a4c704428fc8db2f4f Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 13 Mar 2026 13:20:59 +0000 Subject: [PATCH 114/180] fix --- lightllm/server/visualserver/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 202c2fc453..782aaa7a75 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -192,7 +192,7 @@ async def loop_for_netio_req(self): self.waiting_reqs.append(recv_req) else: assert False, f"Error Req Inf {recv_req}" - self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) + self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256)) except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 self.visual_recv_max_count = 64 From 61f74acf80e0ed067ac91a31aa19c770533e3945 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 16 Mar 2026 07:14:49 +0000 Subject: [PATCH 115/180] remove contiguous --- .../layer_infer/transformer_layer_infer.py | 70 ++++++++----------- .../triton_kernel/fused_gdn_gating.py | 10 ++- 2 files changed, 38 insertions(+), 42 deletions(-) diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index cc3b1fe370..849af38bd5 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -385,52 +385,41 @@ def overlap_tpsp_context_forward( # ==================== GDN Helper Methods 
==================== - def _fix_query_key_value_ba_ordering(self, mixed_qkvzba, is_decode=False): - """ - Extract q, k, v, z, b, a from the MM output. - - After weight rearrangement at load time, the MM output is already in grouped layout: - [all_q | all_k | all_v | all_z | all_b | all_a] - so this is just simple slicing — no split+reshape+cat needed. - - Note: - Decode fast-path fused split-copy kernels are intentionally avoided here. - The explicit contiguous slicing path is slower but is more robust and - matches the reference behavior used in vLLM. - """ - qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim - z_end = qkv_dim + self.tp_value_dim - b_end = z_end + self.tp_num_v_heads - - mixed_qkv = mixed_qkvzba[:, :qkv_dim] - z = mixed_qkvzba[:, qkv_dim:z_end].view(-1, self.tp_num_v_heads, self.head_v_dim) - b = mixed_qkvzba[:, z_end:b_end] - a = mixed_qkvzba[:, b_end:] - return mixed_qkv, z, b, a - - def _split_qkvzba(self, mixed_qkvzba: torch.Tensor): - + def _split_qkvzba(self, mixed_qkvzba, is_decode=False): qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim z_end = qkv_dim + self.tp_value_dim b_end = z_end + self.tp_num_v_heads - mixed_qkv = mixed_qkvzba[:, :qkv_dim] z = mixed_qkvzba[:, qkv_dim:z_end].view(-1, self.tp_num_v_heads, self.head_v_dim) b = mixed_qkvzba[:, z_end:b_end] a = mixed_qkvzba[:, b_end:] return mixed_qkv, z, b, a - def _split_qkv(self, mixed_qkv: torch.Tensor): - query, key, value = torch.split( - mixed_qkv, - [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], - dim=-1, - ) - seq_len = query.shape[0] - query = query.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) - key = key.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) - value = value.view(1, seq_len, self.tp_num_v_heads, self.head_v_dim) - return query, key, value + def _rearrange_mixed_qkv(self, mixed_qkv, decode=False): + if mixed_qkv is None: + return None, None, None + if decode: + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + batch_size = mixed_qkv.shape[0] + query = query.view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) + key = key.view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) + value = value.view(batch_size, 1, self.tp_num_v_heads, self.head_v_dim) + return query, key, value + else: + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + seq_len = query.shape[0] + query = query.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) + key = key.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) + value = value.view(1, seq_len, self.tp_num_v_heads, self.head_v_dim) + return query, key, value def context_attention_forward( self, @@ -476,7 +465,7 @@ def _gdn_prefill_kernel( mixed_qkv = out_tensor.transpose(0, 1) # Recurrent processing - query, key, value = self._split_qkv(mixed_qkv) + query, key, value = self._rearrange_mixed_qkv(mixed_qkv) initial_state = ssm_states[infer_state.b_buffer_idx] # g and beta have shape (total_tokens, num_heads), need to unsqueeze to get (1, total_tokens, num_heads) core_attn_out, last_recurrent_state = chunk_gated_delta_rule( @@ -521,7 +510,7 @@ def _gdn_decode_kernel( # Recurrent processing with fused gating # FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally - query, key, value = self._split_qkv(mixed_qkv) + query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) core_attn_out, _ = fused_recurrent_gated_delta_rule( q=query, k=key, @@ -638,7 +627,8 @@ def 
gdn_forward( conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_) mixed_qkvzba = layer_weight.linear_in_proj.mm(input) - mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) + # mixed_qkv is now returned pre-concatenated (no torch.cat needed) + mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill) # Dispatch to appropriate kernel if is_prefill: diff --git a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py index c816a20013..88febaffc6 100644 --- a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py +++ b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py @@ -18,6 +18,8 @@ def fused_gdn_gating_kernel( a, b, dt_bias, + stride_a_row, + stride_b_row, NUM_HEADS: tl.constexpr, beta: tl.constexpr, threshold: tl.constexpr, @@ -26,10 +28,12 @@ def fused_gdn_gating_kernel( i_b, i_d = tl.program_id(0), tl.program_id(1) head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) off = i_b * NUM_HEADS + head_off + off_a = i_b * stride_a_row + head_off + off_b = i_b * stride_b_row + head_off mask = head_off < NUM_HEADS blk_A_log = tl.load(A_log + head_off, mask=mask) - blk_a = tl.load(a + off, mask=mask) - blk_b = tl.load(b + off, mask=mask) + blk_a = tl.load(a + off_a, mask=mask) + blk_b = tl.load(b + off_b, mask=mask) blk_bias = tl.load(dt_bias + head_off, mask=mask) x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) @@ -78,6 +82,8 @@ def fused_gdn_gating( a, b, dt_bias, + a.stride(0), + b.stride(0), num_heads, beta, threshold, From bf0f2543b64c8d255f2c9ac0184824bbaa4ff797 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 16 Mar 2026 08:42:50 +0000 Subject: [PATCH 116/180] remove gemma rms norm config --- ...at16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ------- ...at16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ------- ...at16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ------- ...at16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 7 ------- ...torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 ------- ...torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 ------- ...torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 ------- ...torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 ------- 8 files changed, 56 deletions(-) delete mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json delete mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json delete mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json delete mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json delete mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json delete 
mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json delete mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json delete mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json deleted file mode 100644 index 864d1d3f18..0000000000 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "2048": { - "BLOCK_SIZE": 4096, - "num_stages": 1, - "num_warps": 8 - } -} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json deleted file mode 100644 index bcf56e01f7..0000000000 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "256": { - "BLOCK_SIZE": 128, - "num_stages": 1, - "num_warps": 1 - } -} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json deleted file mode 100644 index ba1dc8a75d..0000000000 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "3072": { - "BLOCK_SIZE": 2048, - "num_stages": 1, - "num_warps": 8 - } -} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json deleted file mode 100644 index 6f109e1c6e..0000000000 --- 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "5120": { - "BLOCK_SIZE": 32768, - "num_stages": 1, - "num_warps": 8 - } -} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json deleted file mode 100644 index 198a196dfb..0000000000 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "2048": { - "BLOCK_SIZE": 1024, - "num_stages": 1, - "num_warps": 4 - } -} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json deleted file mode 100644 index 537c7a90eb..0000000000 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "256": { - "BLOCK_SIZE": 512, - "num_stages": 1, - "num_warps": 1 - } -} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json deleted file mode 100644 index 9a6dcb6fbf..0000000000 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "4096": { - "BLOCK_SIZE": 1024, - "num_stages": 1, - "num_warps": 8 - } -} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json deleted file mode 100644 index df501847ec..0000000000 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "5120": { - "BLOCK_SIZE": 1024, - "num_stages": 1, - "num_warps": 8 - } -} \ No newline at end of file From 76782c2439d16465d680ba103e13356d5ac8eb24 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 17 Mar 2026 08:38:06 +0000 Subject: [PATCH 117/180] clean code --- lightllm/common/req_manager.py | 21 ----- 
lightllm/models/qwen3next/mem_manager.py | 81 ++++++++++++++++++ lightllm/models/qwen3next/model.py | 83 +++---------------- .../router/dynamic_prompt/radix_cache.py | 4 +- .../server/router/model_infer/infer_batch.py | 51 +++++++----- .../mode_backend/chunked_prefill/impl.py | 10 +-- 6 files changed, 128 insertions(+), 122 deletions(-) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 8874e549e2..bbe2bb4a3b 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -100,18 +100,6 @@ def has_recurrent_state(self): """Whether this model uses per-request recurrent state buffers (e.g. Mamba/linear attention).""" return self.req_to_buffer_index is not None - def alloc_buffer_for_req(self, req_index: torch.Tensor): - """Allocate buffers for requests. No-op for standard models without linear attention.""" - pass - - def free_buffer(self, free_buffer_indexes): - """Free buffer memory. No-op for standard models without linear attention.""" - pass - - def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor): - """Copy buffer state between requests. No-op for standard models without linear attention.""" - pass - class ReqSamplingParamsManager: """ @@ -276,12 +264,3 @@ def alloc_buffer_for_req(self, req_index: torch.Tensor): if not buffer_indexes.is_cuda: buffer_indexes = buffer_indexes.cuda() self.req_to_buffer_index[req_index] = buffer_indexes.view(num_reqs, num_buffers_per_req) - - def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor): - # 获取目标请求的所有 MTP buffer (从 buffer[0] 到 buffer[mtp_step]) - mtp_range = torch.arange(0, self.mtp_step + 1, dtype=torch.int32, device="cuda") - all_mtp_buffers = self.req_to_buffer_index[tgt_req_index[:, None], mtp_range[None, :]] - - # 将 shared buffer 广播到所有 MTP step - self.buffer_mem_manager.fork_state_buffers(src_buffer_index, all_mtp_buffers) - return diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py index 7ac7149a06..709d8dcf4a 100644 --- a/lightllm/models/qwen3next/mem_manager.py +++ b/lightllm/models/qwen3next/mem_manager.py @@ -3,11 +3,92 @@ from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager +from lightllm.server.core.objs.start_args_type import StartArgs logger = init_logger(__name__) class Qwen3NextHybridMemManager(MemoryManager): + @staticmethod + def calculate_mamba_cache_size( + start_args: StartArgs, + max_total_token_num: int, + mem_fraction: float, + config: dict, + head_linear_k_dim: int, + num_linear_k_heads: int, + head_linear_v_dim: int, + num_linear_v_heads: int, + tp_world_size: int, + data_type: torch.dtype, + ) -> int: + """Calculate mamba cache size based on available memory and mamba_cache_ratio.""" + from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory + import torch.distributed as dist + + use_ratio = max_total_token_num is None and start_args.mamba_cache_size is None + + world_size = dist.get_world_size() + total_memory = get_total_gpu_memory() + available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) + + conv_kernel_size = config["linear_conv_kernel_dim"] + conv_dim = ( + head_linear_k_dim * num_linear_k_heads * 2 + head_linear_v_dim * num_linear_v_heads + ) // tp_world_size + + num_linear_layers = 
config["n_layer"] - (config["n_layer"] // config["full_attention_interval"]) + + conv_cell_size = num_linear_layers * conv_dim * (conv_kernel_size - 1) * torch._utils._element_size(data_type) + + ssm_dtype = torch.bfloat16 if start_args.mamba_ssm_data_type == "bfloat16" else torch.float32 + ssm_cell_size = ( + num_linear_layers + * (num_linear_v_heads // tp_world_size) + * head_linear_k_dim + * head_linear_v_dim + * torch._utils._element_size(ssm_dtype) + ) + + total_cell_size = conv_cell_size + ssm_cell_size + + if use_ratio: + # mamba_cache_ratio = mamba_memory / total_cache_memory + mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5 + mamba_memory_gb = available_memory * mamba_cache_ratio + else: + mamba_memory_gb = available_memory + mamba_cache_ratio = None + + mamba_cache_size = int(mamba_memory_gb * 1024 ** 3 / total_cell_size) + + if mamba_cache_size < start_args.running_max_req_size: + ratio = mamba_cache_ratio if mamba_cache_ratio is not None else 0.5 + raise ValueError( + f"Insufficient memory for mamba cache allocation!\n\n" + f"Calculated mamba_cache_size ({mamba_cache_size}) < " + f"running_max_req_size ({start_args.running_max_req_size})\n\n" + f"Memory budget:\n" + f" Available for mamba cache: {mamba_memory_gb:.2f} GB\n" + f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n" + f" Calculated buffers: {mamba_cache_size}\n" + f" Required buffers: {start_args.running_max_req_size}\n\n" + f"Solutions:\n" + f" 1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n" + f" 2. Increase --mamba_cache_ratio from {ratio} to " + f"{start_args.running_max_req_size / mamba_cache_size * ratio:.3f} or higher\n" + f" 3. Increase --mem_fraction to leave more memory for caches\n" + ) + + logger.info( + f"Mamba cache allocation:\n" + f" Available memory: {mamba_memory_gb:.2f} GB\n" + f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n" + f" Calculated mamba_cache_size: {mamba_cache_size}" + ) + + return mamba_cache_size + def __init__( self, full_attn_cache_size, diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 24069b800d..eac603becb 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -50,76 +50,6 @@ def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch def autotune_layers(self): return self.config["full_attention_interval"] - def _calculate_mamba_cache_size(self, start_args: StartArgs) -> int: - """Calculate mamba cache size based on available memory and mamba_cache_ratio.""" - from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory - import torch.distributed as dist - - use_ratio = self.max_total_token_num is None and start_args.mamba_cache_size is None - - world_size = dist.get_world_size() - total_memory = get_total_gpu_memory() - available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - self.mem_fraction) - - conv_kernel_size = self.config["linear_conv_kernel_dim"] - conv_dim = ( - self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads - ) // self.tp_world_size_ - - num_linear_layers = self.config["n_layer"] - (self.config["n_layer"] // self.config["full_attention_interval"]) - - conv_cell_size = ( - num_linear_layers * conv_dim * (conv_kernel_size - 1) * torch._utils._element_size(self.data_type) - ) - - ssm_dtype = torch.bfloat16 if start_args.mamba_ssm_data_type == "bfloat16" else torch.float32 - 
ssm_cell_size = ( - num_linear_layers - * (self.num_linear_v_heads // self.tp_world_size_) - * self.head_linear_k_dim - * self.head_linear_v_dim - * torch._utils._element_size(ssm_dtype) - ) - - total_cell_size = conv_cell_size + ssm_cell_size - - if use_ratio: - # mamba_cache_ratio = mamba_memory / total_cache_memory - mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5 - mamba_memory_gb = available_memory * mamba_cache_ratio - else: - mamba_memory_gb = available_memory - mamba_cache_ratio = None - - mamba_cache_size = int(mamba_memory_gb * 1024 ** 3 / total_cell_size) - - if mamba_cache_size < start_args.running_max_req_size: - ratio = mamba_cache_ratio if mamba_cache_ratio is not None else 0.5 - raise ValueError( - f"Insufficient memory for mamba cache allocation!\n\n" - f"Calculated mamba_cache_size ({mamba_cache_size}) < " - f"running_max_req_size ({start_args.running_max_req_size})\n\n" - f"Memory budget:\n" - f" Available for mamba cache: {mamba_memory_gb:.2f} GB\n" - f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n" - f" Calculated buffers: {mamba_cache_size}\n" - f" Required buffers: {start_args.running_max_req_size}\n\n" - f"Solutions:\n" - f" 1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n" - f" 2. Increase --mamba_cache_ratio from {ratio} to " - f"{start_args.running_max_req_size / mamba_cache_size * ratio:.3f} or higher\n" - f" 3. Increase --mem_fraction to leave more memory for caches\n" - ) - - logger.info( - f"Mamba cache allocation:\n" - f" Available memory: {mamba_memory_gb:.2f} GB\n" - f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n" - f" Calculated mamba_cache_size: {mamba_cache_size}" - ) - - return mamba_cache_size - def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) @@ -142,7 +72,18 @@ def _init_mem_manager(self): self.head_linear_v_dim = self.config["linear_value_head_dim"] if mamba_cache_size is None: - mamba_cache_size = self._calculate_mamba_cache_size(start_args) + mamba_cache_size = Qwen3NextHybridMemManager.calculate_mamba_cache_size( + start_args=start_args, + max_total_token_num=self.max_total_token_num, + mem_fraction=self.mem_fraction, + config=self.config, + head_linear_k_dim=self.head_linear_k_dim, + num_linear_k_heads=self.num_linear_k_heads, + head_linear_v_dim=self.head_linear_v_dim, + num_linear_v_heads=self.num_linear_v_heads, + tp_world_size=self.tp_world_size_, + data_type=self.data_type, + ) else: if mamba_cache_size < start_args.running_max_req_size: raise ValueError( diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 4403dba517..b95213fd4c 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -112,10 +112,10 @@ class RadixCache: unique_name 主要用于解决单机,多实列部署时的shm冲突 """ - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None, kv_cache_mem_manager=None): + def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager=None): from lightllm.common.kv_cache_mem_manager import MemoryManager - self.mem_manager: MemoryManager = kv_cache_mem_manager if kv_cache_mem_manager is not None else mem_manager + self.mem_manager: MemoryManager = kv_cache_mem_manager self._key_dtype = torch.int64 self._value_dtype = torch.int64 diff --git a/lightllm/server/router/model_infer/infer_batch.py 
b/lightllm/server/router/model_infer/infer_batch.py index d8cc2daeb1..a2e2bfc975 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Callable, Any -from lightllm.common.req_manager import ReqManager +from lightllm.common.req_manager import ReqManager, ReqManagerForMamba from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode @@ -23,6 +23,9 @@ logger = init_logger(__name__) +# Cache for mtp_range tensors to avoid repeated allocation +_mtp_range_cache: Dict[int, torch.Tensor] = {} + @dataclass class InferenceContext: @@ -64,11 +67,8 @@ def register( self.vocab_size = vocab_size - if self.has_recurrent_state: - assert self.radix_cache is None or isinstance( - self.radix_cache, HybridRadixCache - ), "Recurrent state models only support HybridRadixCache" - self.mtp_step = get_env_start_args().mtp_step + self.mtp_step = get_env_start_args().mtp_step + return def init_cpu_embed_cache_client(self): @@ -85,26 +85,30 @@ def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream: self.cpu_kv_cache_stream = torch.cuda.Stream() return self.cpu_kv_cache_stream - def _alloc_and_copy_req_buffers(self, req_objs: List["InferReq"]) -> None: - """Allocate and copy buffers for requests. Delegates to req_manager which handles model-specific logic.""" + def _alloc_and_copy_req_buffers( + self, req_manager: ReqManagerForMamba, radix_cache: HybridRadixCache, req_objs: List["InferReq"] + ) -> None: if not req_objs: return - if self.radix_cache is not None and hasattr(self.radix_cache, "free_radix_cache_to_get_enough_buffer"): - self.radix_cache.free_radix_cache_to_get_enough_buffer(len(req_objs) * (self.mtp_step + 1)) + if radix_cache is not None: + radix_cache.free_radix_cache_to_get_enough_buffer(len(req_objs) * (self.mtp_step + 1)) - request_indices_gpu = torch.tensor([r.req_idx for r in req_objs], device="cuda", dtype=torch.int64) - self.req_manager.alloc_buffer_for_req(request_indices_gpu) + req_idx_gpu = torch.tensor([r.req_idx for r in req_objs], device="cuda", dtype=torch.int64) + req_manager.alloc_buffer_for_req(req_idx_gpu) - if self.radix_cache is None: - return + if radix_cache is not None: + fork_req_ids = [r.req_idx for r in req_objs if r.shared_kv_node is not None] + if fork_req_ids: + src_buf_ids = [r.shared_kv_node.buffer_idx for r in req_objs if r.shared_kv_node is not None] + req_tensor = torch.tensor(fork_req_ids, device="cuda", dtype=torch.int32) + src_tensor = torch.tensor(src_buf_ids, device="cuda", dtype=torch.int32) - copy_data = [(r.req_idx, r.shared_kv_node.buffer_idx) for r in req_objs if r.shared_kv_node is not None] - if copy_data: - copy_indices, copy_buffers = zip(*copy_data) - copy_indices_tensor = torch.tensor(copy_indices, device="cuda", dtype=torch.int64) - copy_buffers_tensor = torch.tensor(copy_buffers, device="cuda", dtype=torch.int64) - self.req_manager.copy_buffer_from_another_buffer(copy_buffers_tensor, copy_indices_tensor) + mtp_step = req_manager.mtp_step + if mtp_step not in _mtp_range_cache: + _mtp_range_cache[mtp_step] = torch.arange(0, mtp_step + 1, dtype=torch.int32, device="cuda") + dst_buffers = req_manager.req_to_buffer_index[req_tensor[:, None], _mtp_range_cache[mtp_step][None, :]] + 
req_manager.buffer_mem_manager.fork_state_buffers(src_tensor, dst_buffers) def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]: req_objs = [] @@ -144,7 +148,8 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: slave_req: InferReq = slave_req slave_req.related_master_req = master_req - self._alloc_and_copy_req_buffers(req_objs) + if isinstance(self.req_manager, ReqManagerForMamba): + self._alloc_and_copy_req_buffers(self.req_manager, self.radix_cache, req_objs) return req_objs @@ -250,7 +255,7 @@ def _filter(self, finished_request_ids: List[int]): free_token_index = custom_cat(free_token_index) self.req_manager.free(free_req_index, free_token_index) - if len(free_buffer_index) != 0: + if len(free_buffer_index) != 0 and isinstance(self.req_manager, ReqManagerForMamba): self.req_manager.free_buffer(free_buffer_index) finished_req_ids_set = set(finished_request_ids) @@ -300,7 +305,7 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): free_token_index = custom_cat(free_token_index) self.req_manager.free_token(free_token_index) - if len(free_buffer_index) != 0: + if len(free_buffer_index) != 0 and isinstance(self.req_manager, ReqManagerForMamba): self.req_manager.free_buffer(free_buffer_index) g_infer_state_lock.release() diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 2ea8f07cf6..e7ca588235 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -51,12 +51,12 @@ def __init__(self) -> None: self.classed_req_strict_prefill = False return - def _maybe_insert_hybrid_radix_cache(self, run_reqs: List[InferReq]): + def _maybe_insert_hybrid_radix_cache(self, radix_cache: HybridRadixCache, run_reqs: List[InferReq]): # Insert hybrid radix cache entries if applicable, use for hybrid attention models. 
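
In the rewritten `_alloc_and_copy_req_buffers` above, a (num_reqs, 1) column of forked request indices is broadcast against a cached (1, mtp_step + 1) step row, so a single advanced-indexing expression gathers every MTP buffer slot of every forked request before one `fork_state_buffers` call; `_mtp_range_cache` just avoids rebuilding the step row on each batch. A small sketch of the gather with a toy table:

    import torch

    mtp_step = 1
    # Toy req -> buffer table: 6 requests, (mtp_step + 1) buffer slots each
    req_to_buffer_index = torch.arange(12).view(6, 2)
    req_tensor = torch.tensor([1, 4])              # requests forked from a shared radix node
    mtp_range = torch.arange(0, mtp_step + 1)      # cached per mtp_step in the patch

    dst_buffers = req_to_buffer_index[req_tensor[:, None], mtp_range[None, :]]
    print(dst_buffers)
    # tensor([[2, 3],
    #         [8, 9]])  -> one fork_state_buffers(src_tensor, dst_buffers) covers all slots
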
- if self.use_buffer_manager and self.radix_cache is not None: + if self.use_buffer_manager and radix_cache is not None: torch.cuda.synchronize() g_infer_state_lock.acquire() - self.radix_cache.insert_for_hybrid_radix_cache(run_reqs) + radix_cache.insert_for_hybrid_radix_cache(run_reqs) g_infer_state_lock.release() def infer_loop(self): @@ -146,7 +146,7 @@ def prefill_normal( nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) - self._maybe_insert_hybrid_radix_cache(run_reqs) + self._maybe_insert_hybrid_radix_cache(self.radix_cache, run_reqs) # 第四阶段 event_pack.notify_pre_post_handle() @@ -231,7 +231,7 @@ def prefill_mtp( nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) - self._maybe_insert_hybrid_radix_cache(run_reqs) + self._maybe_insert_hybrid_radix_cache(self.radix_cache, run_reqs) # 第四阶段 event_pack.notify_pre_post_handle() From fdd20528e37584a63a30620c6b0ca4f91ace46d4 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 17 Mar 2026 08:43:12 +0000 Subject: [PATCH 118/180] add get_radix_class --- lightllm/common/basemodel/basemodel.py | 4 ++-- lightllm/models/qwen3next/model.py | 3 ++- .../server/router/model_infer/mode_backend/base_backend.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 1d36c72d0b..2463bbb8e8 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -54,8 +54,8 @@ class TpPartBaseModel: # infer state class infer_state_class = InferStateInfo - # radix cache class - radix_cache_class = RadixCache + def get_radix_class(self): + return RadixCache def __init__(self, kvargs): self.args = get_env_start_args() diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index eac603becb..806b81927f 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -33,7 +33,8 @@ class Qwen3NextTpPartModel(Qwen3MOEModel): use_buffer_manager = True # Indicates model needs per-request buffer management for linear attention states - radix_cache_class = HybridRadixCache + def get_radix_class(self): + return HybridRadixCache def __init__(self, kvargs) -> None: self.mem_manager: Qwen3NextHybridMemManager = None diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 92102a90d8..1f7a31351d 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -175,7 +175,7 @@ def init_model(self, kvargs): self.use_buffer_manager = getattr(self.model, "use_buffer_manager", False) - radix_cache_class = self.model.radix_cache_class + radix_cache_class = self.model.get_radix_class() self.radix_cache = ( radix_cache_class( get_unique_server_name(), From 733e851097a34ed9e9051c704ab0fc270cca68c4 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 17 Mar 2026 15:09:53 +0000 Subject: [PATCH 119/180] fix acc of mamba cache --- lightllm/models/qwen3next/model.py | 5 ++- lightllm/server/api_cli.py | 2 +- .../dynamic_prompt/hybrid_radix_cache.py | 37 +------------------ .../router/dynamic_prompt/radix_cache.py | 2 +- .../server/router/model_infer/infer_batch.py | 31 +--------------- .../mode_backend/chunked_prefill/impl.py | 12 ------ 6 files changed, 8 insertions(+), 81 deletions(-) diff --git a/lightllm/models/qwen3next/model.py 
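
Patch 118 replaces the `radix_cache_class` class attribute with a `get_radix_class()` method. The backend wiring stays generic either way; one plausible benefit (an assumption, the commit message does not spell it out) is that a method can defer imports and branch on runtime configuration, which a class-body attribute cannot. A minimal sketch of the pattern with stand-in classes:

    class RadixCache: ...                    # stand-ins for the real cache classes
    class HybridRadixCache(RadixCache): ...

    class TpPartBaseModel:
        def get_radix_class(self):
            return RadixCache                # default for standard attention models

    class Qwen3NextTpPartModel(TpPartBaseModel):
        def get_radix_class(self):
            # Unlike a class attribute, a method could also branch on runtime
            # config here, e.g. fall back to RadixCache when buffers are disabled.
            return HybridRadixCache

    # Backend side, unchanged in shape:
    # radix_cache_class = self.model.get_radix_class()
    # self.radix_cache = radix_cache_class(...)
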
b/lightllm/models/qwen3next/model.py index 24069b800d..d4ef13b0ef 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -93,12 +93,13 @@ def _calculate_mamba_cache_size(self, start_args: StartArgs) -> int: mamba_cache_size = int(mamba_memory_gb * 1024 ** 3 / total_cell_size) - if mamba_cache_size < start_args.running_max_req_size: + if mamba_cache_size < start_args.running_max_req_size * 2: ratio = mamba_cache_ratio if mamba_cache_ratio is not None else 0.5 raise ValueError( f"Insufficient memory for mamba cache allocation!\n\n" + f"mamba_cache_size should be at least running_max_req_size * 2\n" f"Calculated mamba_cache_size ({mamba_cache_size}) < " - f"running_max_req_size ({start_args.running_max_req_size})\n\n" + f"running_max_req_size * 2 ({start_args.running_max_req_size * 2})\n\n" f"Memory budget:\n" f" Available for mamba cache: {mamba_memory_gb:.2f} GB\n" f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n" diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 47111f76bc..512a882439 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -167,7 +167,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time" + "--running_max_req_size", type=int, default=256, help="the max size for forward requests in the same time" ) parser.add_argument("--nnodes", type=int, default=1, help="the number of nodes") parser.add_argument("--node_rank", type=int, default=0, help="the rank of the current node") diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py index 30765a0aa2..08f6ba3fff 100644 --- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -59,46 +59,13 @@ def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_toke self.evict_tree_set.add(parent_node) return - def insert_for_hybrid_radix_cache(self, reqs): - from lightllm.server.router.model_infer.infer_batch import g_infer_context - - reqs_to_insert = [req for req in reqs if req.cur_kv_len < req.get_cur_total_len()] - - if len(reqs_to_insert) == 0: - return - - self.free_radix_cache_to_get_enough_buffer(len(reqs_to_insert)) - req_idxes = torch.tensor([req.req_idx for req in reqs_to_insert], dtype=torch.int64, device="cuda") - req_to_buffer_index = g_infer_context.req_manager.req_to_buffer_index - # Make contiguous and convert to int64 for Triton kernel compatibility - cur_buffer_indexes = req_to_buffer_index[req_idxes, 0].contiguous().to(torch.int64) - - new_buffer_indexes = self.buffer_mem_manager.alloc(len(reqs_to_insert)) - # Move to CUDA and convert to int64, ensure contiguous - new_buffer_indexes_cuda = new_buffer_indexes.to(device="cuda", dtype=torch.int64).contiguous() - - self.buffer_mem_manager.copy_state_buffers(cur_buffer_indexes, new_buffer_indexes_cuda) - - for i, req in enumerate(reqs_to_insert): - input_token_ids = req.get_input_token_ids() - key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") - value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu() - prefix_len, new_shared_kv_node = super().insert(key, value) - old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len - self.dec_node_ref_counter(req.shared_kv_node) - 
self.add_node_ref_counter(new_shared_kv_node) - self.add_buffer_idx_to_node(new_shared_kv_node, new_buffer_indexes[i].item()) - req.extra_need_to_free_token_index.append( - g_infer_context.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len] - ) - req.shared_kv_node = new_shared_kv_node - def match_prefix(self, key, update_refs=False): assert len(key) != 0 ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) miss_prefix_len = 0 evict_token_list = [] + kv_len = tree_node.node_prefix_total_len while tree_node != self.root_node and tree_node.buffer_idx is None: if tree_node.is_leaf(): self.evict_tree_set.discard(tree_node) @@ -129,7 +96,7 @@ def match_prefix(self, key, update_refs=False): self.mem_manager.free(evict_token_value) if tree_node == self.root_node: - return None, miss_prefix_len, None + return None, kv_len - miss_prefix_len, None update_node = tree_node while update_node != self.root_node: diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 4403dba517..a05d50e454 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -499,7 +499,7 @@ def _print_helper(self, node: TreeNode, indent): " " * indent, f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \ time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \ - node_value_len: {node.node_value_len}", + node_value_len: {node.node_value_len} buffer_idx: {node.buffer_idx}", ) for _, child in node.children.items(): self._print_helper(child, indent=indent + 2) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index d8cc2daeb1..501f17f839 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -149,10 +149,6 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs def free_a_req_mem(self, free_token_index: List, req: "InferReq"): - # If no KV cache has been allocated yet, there's nothing to free - if req.cur_kv_len == 0: - return - if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) else: @@ -171,10 +167,6 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): req.shared_kv_node = None def free_a_req_mem_for_mamba(self, free_token_index: List, req: "InferReq") -> bool: - # 返回该请求的 mamba buffer 是否需要手动释放 - if req.cur_kv_len == 0: - return True - if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) else: @@ -190,10 +182,6 @@ def free_a_req_mem_for_mamba(self, free_token_index: List, req: "InferReq") -> b self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None - if len(req.extra_need_to_free_token_index) > 0: - free_token_index.extend(req.extra_need_to_free_token_index) - req.extra_need_to_free_token_index = [] - if node.buffer_idx is None: req_to_buffer_index = self.req_manager.req_to_buffer_index buffer_idx = req_to_buffer_index[req.req_idx, 0].item() @@ -447,11 +435,6 @@ def __init__( self.nixl_pd_task_failed_num: int = 0 self.nixl_trans_device_id: int = -1 - # 在开启radix cache的情况下,用于标记命中情况,用于插入算法 - self.mamba_model_match_len = 0 - self.mamba_buffer_insert_len = 0 - self.extra_need_to_free_token_index = [] 
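
Patch 119 also tightens the sizing check to `running_max_req_size * 2`. A plausible reading (an assumption, the commit does not spell it out) is that a request's working buffer and a radix-cache-pinned copy of its state can be live at the same time, so budgeting one buffer per running request is not enough. The same patch drops the default `--running_max_req_size` from 1000 to 256, which keeps the required mamba budget realistic; a quick check using the ~74 MiB/buffer figure from the hypothetical sizing sketch earlier:

    per_buffer_mib = 74                      # illustrative, from the sizing sketch above
    for running_max_req_size in (1000, 256):
        required_gib = running_max_req_size * 2 * per_buffer_mib / 1024
        print(running_max_req_size, f"-> {required_gib:.0f} GiB")
    # 1000 -> 145 GiB, 256 -> 37 GiB
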
- # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED @@ -509,7 +492,7 @@ def _match_radix_cache(self): input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 - share_node, miss_prefix_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) + share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) if share_node is not None: self.shared_kv_node = share_node ready_cache_len = share_node.node_prefix_total_len @@ -518,13 +501,6 @@ def _match_radix_cache(self): self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 - if g_infer_context.has_recurrent_state: - MAMBA_PREFILL_BLOCK_SIZE = 128 - MAMBA_MIN_INSERT_LEN = 1024 - miss_prefix_len = miss_prefix_len - miss_prefix_len % MAMBA_PREFILL_BLOCK_SIZE - if miss_prefix_len > MAMBA_MIN_INSERT_LEN: - self.mamba_buffer_insert_len = miss_prefix_len - self.shm_req.shm_cur_kv_len = self.cur_kv_len return @@ -579,11 +555,6 @@ def get_chuncked_input_token_ids(self): def get_chuncked_input_token_len(self): chunked_start = self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) - - if self.mamba_buffer_insert_len > 0: - chunked_end = min(self.get_cur_total_len(), chunked_start + self.mamba_buffer_insert_len) - self.mamba_buffer_insert_len = 0 - return chunked_end def set_next_gen_token_id(self, next_token_id: int, logprob: float, output_len: int): diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 2ea8f07cf6..85d1e01b9c 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -51,14 +51,6 @@ def __init__(self) -> None: self.classed_req_strict_prefill = False return - def _maybe_insert_hybrid_radix_cache(self, run_reqs: List[InferReq]): - # Insert hybrid radix cache entries if applicable, use for hybrid attention models. 
- if self.use_buffer_manager and self.radix_cache is not None: - torch.cuda.synchronize() - g_infer_state_lock.acquire() - self.radix_cache.insert_for_hybrid_radix_cache(run_reqs) - g_infer_state_lock.release() - def infer_loop(self): torch.cuda.set_device(get_current_device_id()) try: @@ -146,8 +138,6 @@ def prefill_normal( nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) - self._maybe_insert_hybrid_radix_cache(run_reqs) - # 第四阶段 event_pack.notify_pre_post_handle() return @@ -231,8 +221,6 @@ def prefill_mtp( nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) - self._maybe_insert_hybrid_radix_cache(run_reqs) - # 第四阶段 event_pack.notify_pre_post_handle() return From 90120b0ebfb615ebdd0f95f81db7ea983d3510c9 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 17 Mar 2026 17:28:05 +0000 Subject: [PATCH 120/180] fix warmup --- lightllm/common/mamba_cache_mem_manager/cache_manager.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py index a33a737516..fe5ac093e0 100644 --- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -105,6 +105,12 @@ def free(self, free_index: Union[torch.Tensor, List[int]]): super().free(free_index) return + def free_all(self): + self.conv_state_cache.buffer.fill_(0) + self.ssm_state_cache.buffer.fill_(0) + super().free_all() + return + class ReadOnlyStaticsMambaCacheManager: """ From 13edba269faec29c85d614752e3cac4ea90a2918 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 18 Mar 2026 09:51:55 +0000 Subject: [PATCH 121/180] simplify the qwen3next layer_infer --- lightllm/models/__init__.py | 1 - .../layer_infer/transformer_layer_infer.py | 13 +- lightllm/models/qwen3_5/model.py | 30 +- .../layer_infer/transformer_layer_infer.py | 395 +---- .../layer_weights/transformer_layer_weight.py | 8 +- lightllm/models/qwen3next/model.py | 22 +- .../qwen3next/triton_kernel/gdn_decode_mtp.py | 1333 ----------------- lightllm/models/qwen3next_mtp/__init__.py | 3 - .../qwen3next_mtp/layer_infer/__init__.py | 0 .../layer_infer/pre_layer_infer.py | 61 - .../layer_infer/transformer_layer_infer.py | 30 - .../qwen3next_mtp/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 47 - .../layer_weights/transformer_layer_weight.py | 141 -- lightllm/models/qwen3next_mtp/model.py | 101 -- .../model_infer/mode_backend/base_backend.py | 3 - 16 files changed, 79 insertions(+), 2109 deletions(-) delete mode 100644 lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py delete mode 100644 lightllm/models/qwen3next_mtp/__init__.py delete mode 100644 lightllm/models/qwen3next_mtp/layer_infer/__init__.py delete mode 100644 lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py delete mode 100644 lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py delete mode 100644 lightllm/models/qwen3next_mtp/layer_weights/__init__.py delete mode 100644 lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py delete mode 100644 lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py delete mode 100644 lightllm/models/qwen3next_mtp/model.py diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index a7e4cd58b7..2caee91709 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -8,7 +8,6 @@ from lightllm.models.qwen3.model import 
Qwen3TpPartModel from lightllm.models.qwen3_moe.model import Qwen3MOEModel from lightllm.models.qwen3next.model import Qwen3NextTpPartModel -from lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel from lightllm.models.internlm.model import InternlmTpPartModel from lightllm.models.stablelm.model import StablelmTpPartModel from lightllm.models.internlm2.model import Internlm2TpPartModel diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py index 64ecf94edb..d0657bcbe8 100644 --- a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -3,8 +3,7 @@ from typing import Tuple from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( - Qwen3NextFullAttentionTransformerLayerInfer, - Qwen3NextGatedDeltaNetTransformerLayerInfer, + Qwen3NextTransformerLayerInfer, ) from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( Qwen35TransformerLayerWeight, @@ -16,7 +15,7 @@ logger = init_logger(__name__) -class Qwen35FullAttentionTransformerLayerInfer(Qwen3NextFullAttentionTransformerLayerInfer): +class Qwen35TransformerLayerInfer(Qwen3NextTransformerLayerInfer): def __init__(self, layer_num, network_config): super().__init__(layer_num, network_config) # Initialize mrope section from config @@ -57,11 +56,3 @@ def _get_qkv( partial_rotary_factor=self.partial_rotary_factor, ) return q, cache_kv - - -class Qwen35GatedDeltaNetTransformerLayerInfer(Qwen3NextGatedDeltaNetTransformerLayerInfer): - def __init__(self, layer_num, network_config): - super().__init__(layer_num, network_config) - rope_scaling = network_config.get("rope_scaling", {}) - mrope_section = rope_scaling.get("mrope_section", [11, 11, 10]) - self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py index f29d50476b..63503c77ba 100644 --- a/lightllm/models/qwen3_5/model.py +++ b/lightllm/models/qwen3_5/model.py @@ -13,8 +13,7 @@ Qwen3VLPreAndPostLayerWeight, ) from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import ( - Qwen35FullAttentionTransformerLayerInfer, - Qwen35GatedDeltaNetTransformerLayerInfer, + Qwen35TransformerLayerInfer, ) from lightllm.models.qwen3_5.infer_struct import Qwen35InferStateInfo from lightllm.common.build_utils import repair_config @@ -52,10 +51,12 @@ class Qwen3_5TpPartModel(Qwen3NextTpPartModel): - Multimodal embeddings merged with text embeddings """ - pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer transformer_weight_class = Qwen35TransformerLayerWeight pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight + pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer + transformer_layer_infer_class = Qwen35TransformerLayerInfer + infer_state_class = Qwen35InferStateInfo def _init_config(self): @@ -97,26 +98,3 @@ def _init_config(self): # Calculate num_kv_heads for KV cache memory management # Required by parent class _init_mem_manager() in Qwen3NextTpPartModel self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) - - def _init_infer_layer(self): - """ - Initialize inference layers for Qwen3.5 multimodal model. - - Uses mrope-enabled transformer layers to properly handle image/video - tokens with 3D position encoding (temporal, height, width). - - This overrides the parent class to use Qwen35* layer classes instead - of Qwen3Next* layer classes. 
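
After this refactor a single `Qwen3NextTransformerLayerInfer` serves both layer kinds: the constructor derives `is_linear_attention_layer` from the layer index (every `full_attention_interval`-th layer keeps full attention) and the attention entry points branch on it, so models no longer assemble two layer-infer classes by hand. A toy sketch of the dispatch convention, assuming a hypothetical interval of 4:

    class Layer:
        def __init__(self, layer_num, interval):
            # Layers are 0-indexed; the 4th, 8th, ... layers keep full attention
            self.is_linear_attention_layer = (layer_num + 1) % interval != 0

        def attention(self, x):
            # Linear layers run the GDN path; full layers defer to the standard
            # attention implementation (super() in the real class)
            return "gdn" if self.is_linear_attention_layer else "full_attn"

    print([Layer(i, 4).attention(None) for i in range(8)])
    # ['gdn', 'gdn', 'gdn', 'full_attn', 'gdn', 'gdn', 'gdn', 'full_attn']
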
- """ - self.pre_infer = self.pre_layer_infer_class(network_config=self.config) - self.post_infer = self.post_layer_infer_class(network_config=self.config) - num_full_attention_layers = self.config["full_attention_interval"] - - self.layers_infer = [ - ( - Qwen35FullAttentionTransformerLayerInfer(i, network_config=self.config) - if (i + 1) % num_full_attention_layers == 0 - else Qwen35GatedDeltaNetTransformerLayerInfer(i, network_config=self.config) - ) - for i in range(self.config["n_layer"]) - ] diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 849af38bd5..6e2f8d7c9c 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -17,11 +17,6 @@ from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule -from lightllm.models.qwen3next.triton_kernel.gdn_decode_mtp import ( - copy_conv_states, - copy_ssm_states, - copy_states_fused, -) from lightllm.distributed import all_reduce from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type @@ -30,12 +25,7 @@ logger = init_logger(__name__) -class Qwen3NextFullAttentionBaseLayerInfer(LlamaTransformerLayerInfer): - """ - Base class for Qwen3Next full attention layers. - Contains shared logic for both standard full attention and MTP layers. - """ - +class Qwen3NextTransformerLayerInfer(LlamaTransformerLayerInfer): def __init__(self, layer_num, network_config): self.partial_rotary_factor = network_config.get("partial_rotary_factor", 1.0) self.n_routed_experts = network_config.get("num_experts", 0) @@ -51,6 +41,10 @@ def __init__(self, layer_num, network_config): self.head_dim_ = network_config.get( "head_dim", network_config["hidden_size"] // network_config["num_attention_heads"] ) + num_full_attention_layers = network_config["full_attention_interval"] + self.is_linear_attention_layer = (layer_num + 1) % num_full_attention_layers != 0 + if self.is_linear_attention_layer: + self._init_linear_layer_metadata(layer_num, network_config) return def _bind_func(self): @@ -63,11 +57,11 @@ def _bind_ffn(self): if self.is_moe: moe_mode = os.environ.get("MOE_MODE", "TP") if moe_mode == "EP": - self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._moe_ffn_edp, self) + self._ffn = partial(Qwen3NextTransformerLayerInfer._moe_ffn_edp, self) else: - self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._moe_ffn, self) + self._ffn = partial(Qwen3NextTransformerLayerInfer._moe_ffn, self) else: - self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn, self) + self._ffn = partial(Qwen3NextTransformerLayerInfer._ffn, self) return def _compute_shared_expert( @@ -170,32 +164,7 @@ def _get_o( o_tensor = layer_weight.o_proj.mm(input) return o_tensor - -class Qwen3NextFullAttentionTransformerLayerInfer(Qwen3NextFullAttentionBaseLayerInfer): - """ - Full attention layer for Qwen3Next that uses the abstracted attention backend. - Inherits from Qwen3NextFullAttentionBaseLayerInfer to get shared Qwen3Next logic. - """ - - pass - - -class Qwen3NextGatedDeltaNetTransformerLayerInfer(LlamaTransformerLayerInfer): - """ - Linear attention (Gated Delta Networks) layer for Qwen3Next. 
- """ - - def __init__(self, layer_num, network_config): - self.n_routed_experts = network_config.get("num_experts", 0) - self.is_moe = ( - network_config.get("num_experts", 0) > 0 - and layer_num not in network_config.get("mlp_only_layers", []) - and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 - ) - super().__init__(layer_num, network_config) - # MoE configuration - self.num_experts_per_tok = network_config.get("num_experts_per_tok", 1) - self.norm_topk_prob = network_config.get("norm_topk_prob", False) + def _init_linear_layer_metadata(self, layer_num, network_config): # Linear attention specific dimensions self.num_v_heads = network_config["linear_num_value_heads"] @@ -215,20 +184,9 @@ def __init__(self, layer_num, network_config): self.tp_key_dim = self.key_dim // self.tp_world_size_ self.tp_value_dim = self.value_dim // self.tp_world_size_ - # Template required dimensions (not used for GDN but required by interface) - self.tp_q_head_num_ = self.tp_num_k_heads - self.tp_k_head_num_ = self.tp_num_k_heads - self.tp_v_head_num_ = self.tp_num_v_heads - self.tp_o_head_num_ = self.tp_num_v_heads - self.head_dim_ = self.head_v_dim - assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads" self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads - # MTP configuration - self.mtp_step = get_env_start_args().mtp_step - self.mtp_size = self.mtp_step + 1 - # SSM state dtype optimization ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} start_args = get_env_start_args() @@ -238,152 +196,84 @@ def __init__(self, layer_num, network_config): # GDN kernel output dtype is self.data_type # Conversion needed only if SSM state uses different dtype self.needs_ssm_dtype_conversion = get_llm_data_type() != self.ssm_state_dtype - self._bind_func() - return - - def _bind_func(self): - """Bind layer-specific implementations""" - self._bind_ffn() - return - - def _bind_ffn(self): - """Bind FFN implementation based on MoE configuration.""" - if self.is_moe: - moe_mode = os.environ.get("MOE_MODE", "TP") - if moe_mode == "EP": - self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._moe_ffn_edp, self) - else: - self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._moe_ffn, self) - else: - self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn, self) return - def _compute_shared_expert( - self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight - ): - input = input.view(-1, self.embed_dim_) - shared_expert_out = super()._ffn(input, infer_state, layer_weight) - gate = layer_weight.ffn_gate.mm(input).sigmoid_() - shared_expert_out.mul_(gate) - return shared_expert_out - - def _moe_ffn( - self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight - ): - """MoE FFN with tensor parallelism.""" - - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) - - hidden_states = input.view(-1, self.embed_dim_) - num_tokens, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(hidden_states) - layer_weight.experts.experts( - hidden_states, - router_logits=router_logits, - top_k=self.num_experts_per_tok, - renormalize=self.norm_topk_prob, - use_grouped_topk=False, - topk_group=None, - num_expert_group=None, - ) - hidden_states = hidden_states.view(num_tokens, hidden_dim) - hidden_states.add_(shared_expert_out) - return hidden_states - - def _moe_ffn_edp( - self, input: 
torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight - ): - """MoE FFN with expert parallelism.""" - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) - hidden_states = input - token_num, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(hidden_states) - ep_output = layer_weight.experts.experts( - hidden_states, - router_logits=router_logits, - top_k=self.num_experts_per_tok, - renormalize=self.norm_topk_prob, - use_grouped_topk=False, - topk_group=None, - num_expert_group=None, - is_prefill=infer_state.is_prefill, - ) - ep_output = ep_output.view(token_num, hidden_dim) - ep_output.add_(shared_expert_out) - return ep_output + # ==================== GDN Helper Methods ==================== - def _gdn_layer_forward( + def context_attention_forward( self, input_embdings, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, - is_prefill: bool, ): - """Unified forward for both prefill and decode in GDN layers.""" - # Attention + GDN processing - input1 = layer_weight.att_norm_weight_(input=input_embdings, eps=self.eps_, alloc_func=self.alloc_tensor) - gdn_out = self.gdn_forward(input1, infer_state, layer_weight, is_prefill=is_prefill) - if self.tp_world_size_ > 1: - all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + if not self.is_linear_attention_layer: + return super().context_attention_forward(input_embdings, infer_state, layer_weight) - # FFN - input_embdings.add_(gdn_out.view(-1, self.embed_dim_)) - gdn_out = None - input1 = layer_weight.ffn_norm_weight_(input=input_embdings, eps=self.eps_, alloc_func=self.alloc_tensor) - - ffn_out = self._ffn(input1, infer_state, layer_weight) - input1 = None + gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True) if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - return input_embdings + all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return gdn_out - def context_forward( + def token_attention_forward( self, input_embdings, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ): - """Override context_forward to use GDN logic instead of standard attention flow.""" - return self._gdn_layer_forward(input_embdings, infer_state, layer_weight, is_prefill=True) + if not self.is_linear_attention_layer: + return super().token_attention_forward(input_embdings, infer_state, layer_weight) + gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=False) + if self.tp_world_size_ > 1: + all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return gdn_out - def token_forward( + def gdn_forward( self, - input_embdings, + input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, + is_prefill: bool, ): - """Override token_forward to use GDN logic instead of standard attention flow.""" - return self._gdn_layer_forward(input_embdings, infer_state, layer_weight, is_prefill=False) + assert isinstance(infer_state.mem_manager, Qwen3NextHybridMemManager) - def overlap_tpsp_token_forward( - self, - input_embdings, - input_embdings1, - infer_state: Qwen3NextInferStateInfo, - infer_state1: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextTransformerLayerWeight, - ): - 
"""Microbatch overlap for decode: process two half-batches sequentially. - Enables --enable_decode_microbatch_overlap for GDN layers.""" - input_embdings = self.token_forward(input_embdings, infer_state, layer_weight) - input_embdings1 = self.token_forward(input_embdings1, infer_state1, layer_weight) - return input_embdings, input_embdings1 + # Common preprocessing + input = input.view(-1, self.embed_dim_) + conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_) - def overlap_tpsp_context_forward( - self, - input_embdings, - input_embdings1, - infer_state: Qwen3NextInferStateInfo, - infer_state1: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextTransformerLayerWeight, - ): - """Microbatch overlap for context: process two half-batches sequentially.""" - input_embdings = self.context_forward(input_embdings, infer_state, layer_weight) - input_embdings1 = self.context_forward(input_embdings1, infer_state1, layer_weight) - return input_embdings, input_embdings1 + mixed_qkvzba = layer_weight.linear_in_proj.mm(input) + # mixed_qkv is now returned pre-concatenated (no torch.cat needed) + mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill) - # ==================== GDN Helper Methods ==================== + # Dispatch to appropriate kernel + if is_prefill: + # Prefill: compute g/beta upfront (chunk kernel doesn't support fused gating) + g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) + core_attn_out = self._gdn_prefill_kernel( + mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight + ) + else: + # Decode (non-MTP): fuse gating into recurrent kernel to save 2 kernel launches + core_attn_out = self._gdn_decode_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight) + + # Common postprocessing + num_tokens = z.shape[0] # batch (decode) or total_tokens (prefill/MTP) + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device) + gated_rmsnorm_forward( + core_attn_out, + layer_weight.linear_norm.weight, + None, # RMSNormWeight has no bias + self.eps_, + z, + out=norm_out, + ) + # Merge head and value dims in a single view: (num_tokens * HV, V) → (num_tokens, HV * V) + core_attn_out = norm_out.view(num_tokens, -1) + + output = layer_weight.linear_out_proj.mm(core_attn_out) + # Note: all_reduce is handled by context_forward/token_forward callers + return output def _split_qkvzba(self, mixed_qkvzba, is_decode=False): qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim @@ -421,24 +311,6 @@ def _rearrange_mixed_qkv(self, mixed_qkv, decode=False): value = value.view(1, seq_len, self.tp_num_v_heads, self.head_v_dim) return query, key, value - def context_attention_forward( - self, - input_embdings, - infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextTransformerLayerWeight, - ): - gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True) - return gdn_out - - def token_attention_forward( - self, - input_embdings, - infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextTransformerLayerWeight, - ): - gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=False) - return gdn_out - def _gdn_prefill_kernel( self, mixed_qkv: torch.Tensor, @@ -525,144 +397,3 @@ def _gdn_decode_kernel( b_raw=b, ) return core_attn_out - - def _gdn_decode_mtp_kernel( - self, - 
mixed_qkv: torch.Tensor, - conv_states: torch.Tensor, - ssm_states: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextTransformerLayerWeight, - ): - """ - Optimized decode kernel for GDN forward pass (MTP mode with multiple steps). - - Key optimizations: - 1. Uses pre-allocated work buffer to avoid per-step .contiguous() allocations - 2. Uses optimized flat Triton kernels for state copying - 3. Direct slice assignment for output instead of .copy_() - - Note: Sequential processing is required because each MTP step depends on - the previous step's final state (both conv and SSM states). - """ - total_tokens = mixed_qkv.shape[0] - batch_size = total_tokens // self.mtp_size - - # Pre-allocate output tensor - core_attn_out = torch.empty( - (total_tokens, 1, self.tp_num_v_heads, self.head_v_dim), - dtype=mixed_qkv.dtype, - device=mixed_qkv.device, - ) - - # Pre-allocate work buffer for conv1d input (avoids per-step .contiguous()) - qkv_work_buffer = torch.empty( - (batch_size, mixed_qkv.shape[-1]), - dtype=mixed_qkv.dtype, - device=mixed_qkv.device, - ) - - # Process each MTP step sequentially (required due to state dependencies) - for step_idx in range(self.mtp_size): - cur_buffer_idx = infer_state.mtp_buffer_idx_list[step_idx] - - # ========== Conv1D processing ========== - # Copy strided data to contiguous work buffer - qkv_work_buffer.copy_(mixed_qkv[step_idx :: self.mtp_size]) - - # causal_conv1d_update operates in-place on contiguous input - causal_conv1d_update( - qkv_work_buffer, - conv_states, - layer_weight.linear_conv1d.mm_param.weight, - bias=layer_weight.linear_conv1d.bias, - activation=self.activation, - conv_state_indices=cur_buffer_idx, - ) - - # ========== Recurrent processing ========== - query_i, key_i, value_i = self._rearrange_mixed_qkv(qkv_work_buffer, decode=True) - g_i = g[step_idx :: self.mtp_size].unsqueeze(1) - beta_i = beta[step_idx :: self.mtp_size].unsqueeze(1) - - core_attn_out_i, _ = fused_recurrent_gated_delta_rule( - q=query_i, - k=key_i, - v=value_i, - g=g_i, - beta=beta_i, - initial_state=ssm_states, - inplace_final_state=True, - ssm_state_indices=cur_buffer_idx, - use_qk_l2norm_in_kernel=True, - ) - - # Direct slice assignment (no .copy_() needed) - core_attn_out[step_idx :: self.mtp_size] = core_attn_out_i - - # ========== State propagation to next step ========== - if step_idx < self.mtp_step: - next_buffer_idx = infer_state.mtp_buffer_idx_list[step_idx + 1] - if conv_states.is_contiguous() and ssm_states.is_contiguous(): - copy_states_fused(conv_states, ssm_states, cur_buffer_idx, next_buffer_idx) - else: - copy_conv_states(conv_states, cur_buffer_idx, next_buffer_idx) - copy_ssm_states(ssm_states, cur_buffer_idx, next_buffer_idx) - - return core_attn_out - - def gdn_forward( - self, - input: torch.Tensor, - infer_state: Qwen3NextInferStateInfo, - layer_weight: Qwen3NextTransformerLayerWeight, - is_prefill: bool, - ): - assert isinstance(infer_state.mem_manager, Qwen3NextHybridMemManager) - - # Common preprocessing - input = input.view(-1, self.embed_dim_) - conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_) - - mixed_qkvzba = layer_weight.linear_in_proj.mm(input) - # mixed_qkv is now returned pre-concatenated (no torch.cat needed) - mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill) - - # Dispatch to appropriate kernel - if is_prefill: - # Prefill: compute g/beta upfront (chunk kernel doesn't support fused gating) - g, beta = 
fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) - core_attn_out = self._gdn_prefill_kernel( - mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight - ) - elif self.mtp_step == 0: - # Decode (non-MTP): fuse gating into recurrent kernel to save 2 kernel launches - core_attn_out = self._gdn_decode_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight) - else: - # Decode (MTP): compute g/beta upfront (multiple recurrent calls per step) - g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) - core_attn_out = self._gdn_decode_mtp_kernel( - mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight - ) - - # Common postprocessing - num_tokens = z.shape[0] # batch (decode) or total_tokens (prefill/MTP) - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device) - gated_rmsnorm_forward( - core_attn_out, - layer_weight.linear_norm.weight, - None, # RMSNormWeight has no bias - self.eps_, - z, - out=norm_out, - ) - # Merge head and value dims in a single view: (num_tokens * HV, V) → (num_tokens, HV * V) - core_attn_out = norm_out.view(num_tokens, -1) - - output = layer_weight.linear_out_proj.mm(core_attn_out) - # Note: all_reduce is handled by context_forward/token_forward callers - return output diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index be68e6aeb1..31dae85ec8 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -14,7 +14,7 @@ class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): num_full_attention_layers = network_config["full_attention_interval"] - self.is_linear_attention = (layer_num + 1) % num_full_attention_layers != 0 + self.is_linear_attention_layer = (layer_num + 1) % num_full_attention_layers != 0 super().__init__(layer_num, data_type, network_config, quant_cfg) return @@ -42,7 +42,7 @@ def _init_qkv(self): ) def _init_weight(self): - if self.is_linear_attention: + if self.is_linear_attention_layer: self._init_gdn_weight() else: self._init_qkv() @@ -71,7 +71,7 @@ def _init_norm(self): weight_name=self._ffn_norm_weight_name, data_type=self.data_type_, ) - if not self.is_linear_attention: + if not self.is_linear_attention_layer: self.qk_norm_weight_ = QKGEMMANormWeight( dim=self.head_dim, q_weight_name=self._q_norm_name, @@ -268,6 +268,6 @@ def _parse_linear_conv1d(self, weight): def load_hf_weights(self, weights): self._split_q_with_gate(weights) - if self.is_linear_attention: + if self.is_linear_attention_layer: self._preprocess_weight(weights) super().load_hf_weights(weights) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index ab3fb3933c..50461bd770 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -8,8 +8,7 @@ ) from lightllm.models.qwen3next.layer_weights.pre_and_post_layer_weight import Qwen3NextPreAndPostLayerWeight from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( - Qwen3NextFullAttentionTransformerLayerInfer, - Qwen3NextGatedDeltaNetTransformerLayerInfer, + Qwen3NextTransformerLayerInfer, ) 
from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo from lightllm.utils.log_utils import init_logger @@ -26,9 +25,14 @@ @ModelRegistry("qwen3_next") class Qwen3NextTpPartModel(Qwen3MOEModel): + # weight class pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight transformer_weight_class = Qwen3NextTransformerLayerWeight + # infer class + transformer_layer_infer_class = Qwen3NextTransformerLayerInfer + + # infer state class infer_state_class = Qwen3NextInferStateInfo use_buffer_manager = True # Indicates model needs per-request buffer management for linear attention states @@ -135,17 +139,3 @@ def _init_req_manager(self): create_max_seq_len = max(create_max_seq_len, self.max_seq_length) self.req_manager = ReqManagerForMamba(self.max_req_num, create_max_seq_len, self.mem_manager) - - def _init_infer_layer(self): - self.pre_infer = self.pre_layer_infer_class(network_config=self.config) - self.post_infer = self.post_layer_infer_class(network_config=self.config) - num_full_attention_layers = self.config["full_attention_interval"] - - self.layers_infer = [ - ( - Qwen3NextFullAttentionTransformerLayerInfer(i, network_config=self.config) - if (i + 1) % num_full_attention_layers == 0 - else Qwen3NextGatedDeltaNetTransformerLayerInfer(i, network_config=self.config) - ) - for i in range(self.config["n_layer"]) - ] diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py deleted file mode 100644 index 5a39debaa9..0000000000 --- a/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py +++ /dev/null @@ -1,1333 +0,0 @@ -""" -Optimized GDN Decode MTP (Multi-Token Prediction) Kernel - -This module provides an optimized Triton kernel for GDN decode with MTP support, -eliminating the need for sequential Python loops and reducing memory operations. - -Key optimizations: -1. Fused data reorganization from interleaved to batched layout -2. Parallel processing of all batch items with proper state indexing -3. Auto-tuned configurations for different batch sizes and model dimensions -""" - -import torch -import triton -import triton.language as tl -from lightllm.common.triton_utils.autotuner import autotune - - -@triton.jit -def _reorganize_mtp_data_kernel( - # Input pointers (interleaved layout: [step0_batch0, step0_batch1, ..., step1_batch0, ...]) - src_ptr, - # Output pointers (batched layout: [batch0_step0, batch0_step1, ..., batch1_step0, ...]) - dst_ptr, - # Dimensions - batch_size, - mtp_size, - dim_size, - # Strides - src_stride_token, - src_stride_dim, - dst_stride_token, - dst_stride_dim, - # Block sizes - BLOCK_DIM: tl.constexpr, -): - """ - Reorganize data from interleaved MTP layout to batched layout. - - Input layout: [step0_batch0, step0_batch1, ..., step0_batchN, step1_batch0, ...] - Output layout: [batch0_step0, batch0_step1, ..., batch0_stepM, batch1_step0, ...] - - This enables efficient processing with the recurrent kernel. 
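-
-    Worked index example (illustrative, not from the source): with
-    batch_size=2 and mtp_size=2, source token 2, i.e. (step 1, batch 0),
-    maps to destination token 0 * 2 + 1 = 1.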
- """ - batch_idx = tl.program_id(0) - step_idx = tl.program_id(1) - block_dim_idx = tl.program_id(2) - - # Calculate source and destination token indices - src_token_idx = step_idx * batch_size + batch_idx - dst_token_idx = batch_idx * mtp_size + step_idx - - # Calculate dimension offsets - dim_start = block_dim_idx * BLOCK_DIM - dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) - mask = dim_offsets < dim_size - - # Load from source (interleaved layout) - src_offset = src_token_idx * src_stride_token + dim_offsets * src_stride_dim - data = tl.load(src_ptr + src_offset, mask=mask, other=0.0) - - # Store to destination (batched layout) - dst_offset = dst_token_idx * dst_stride_token + dim_offsets * dst_stride_dim - tl.store(dst_ptr + dst_offset, data, mask=mask) - - -@triton.jit -def _reorganize_mtp_data_back_kernel( - # Input pointers (batched layout): [batch_size, mtp_size, num_heads, head_dim] - src_ptr, - # Output pointers (interleaved layout): [total_tokens, 1, num_heads, head_dim] - dst_ptr, - # Dimensions - batch_size, - mtp_size, - num_heads, - head_dim, - # Strides for src: [batch_size, mtp_size, num_heads, head_dim] - src_stride_batch, - src_stride_mtp, - src_stride_head, - src_stride_dim, - # Strides for dst: [total_tokens, 1, num_heads, head_dim] - dst_stride_token, - dst_stride_seq, - dst_stride_head, - dst_stride_dim, - # Block sizes - BLOCK_HEAD: tl.constexpr, - BLOCK_DIM: tl.constexpr, -): - """ - Reorganize output data from batched layout back to interleaved layout. - - Input shape: [batch_size, mtp_size, num_heads, head_dim] - Output shape: [batch_size * mtp_size, 1, num_heads, head_dim] (interleaved) - - Mapping: src[b, s, h, d] -> dst[s * batch_size + b, 0, h, d] - """ - batch_idx = tl.program_id(0) - step_idx = tl.program_id(1) - block_idx = tl.program_id(2) - - # Decompose block_idx into head and dim blocks - num_dim_blocks = tl.cdiv(head_dim, BLOCK_DIM) - block_head_idx = block_idx // num_dim_blocks - block_dim_idx = block_idx % num_dim_blocks - - # Calculate destination token index (interleaved) - dst_token_idx = step_idx * batch_size + batch_idx - - # Calculate offsets - head_start = block_head_idx * BLOCK_HEAD - dim_start = block_dim_idx * BLOCK_DIM - - head_offsets = head_start + tl.arange(0, BLOCK_HEAD) - dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) - - head_mask = head_offsets < num_heads - dim_mask = dim_offsets < head_dim - mask = head_mask[:, None] & dim_mask[None, :] - - # Load from source (batched layout): [batch_size, mtp_size, num_heads, head_dim] - src_base = src_ptr + batch_idx * src_stride_batch + step_idx * src_stride_mtp - src_offset = head_offsets[:, None] * src_stride_head + dim_offsets[None, :] * src_stride_dim - data = tl.load(src_base + src_offset, mask=mask, other=0.0) - - # Store to destination (interleaved layout): [total_tokens, 1, num_heads, head_dim] - # The seq dimension (1) is skipped since it's always 0 - dst_base = dst_ptr + dst_token_idx * dst_stride_token - dst_offset = head_offsets[:, None] * dst_stride_head + dim_offsets[None, :] * dst_stride_dim - tl.store(dst_base + dst_offset, data, mask=mask) - - -def _get_reorganize_mtp_configs(): - """Generate candidate configurations for MTP data reorganization.""" - configs = [] - for block_dim in [64, 128, 256, 512]: - for num_warps in [2, 4, 8]: - for num_stages in [2, 3, 4]: - configs.append( - { - "BLOCK_DIM": block_dim, - "num_warps": num_warps, - "num_stages": num_stages, - } - ) - return configs - - -def _get_reorganize_static_key(src: torch.Tensor, mtp_size: int): - """Static 
key based on tensor properties.""" - return { - "dtype": str(src.dtype), - "mtp_size": mtp_size, - } - - -def _get_reorganize_run_key(src: torch.Tensor, mtp_size: int): - """Run key based on batch size and dimension.""" - total_tokens = src.shape[0] - batch_size = total_tokens // mtp_size - dim_size = src.shape[-1] - return f"{batch_size}_{dim_size}" - - -@autotune( - kernel_name="gdn_decode_mtp_reorganize:v1", - configs_gen_func=_get_reorganize_mtp_configs, - static_key_func=_get_reorganize_static_key, - run_key_func=_get_reorganize_run_key, - mutates_args=["dst"], -) -def reorganize_mtp_to_batched( - src: torch.Tensor, - dst: torch.Tensor, - mtp_size: int, - run_config: dict = None, -): - """ - Reorganize data from interleaved MTP layout to batched layout. - - Args: - src: Input tensor with interleaved layout [total_tokens, dim] - Layout: [step0_batch0, step0_batch1, ..., step1_batch0, ...] - dst: Output tensor with batched layout [total_tokens, dim] - Layout: [batch0_step0, batch0_step1, ..., batch1_step0, ...] - mtp_size: Number of MTP steps - run_config: Auto-tuned configuration - """ - total_tokens = src.shape[0] - batch_size = total_tokens // mtp_size - dim_size = src.shape[-1] - - if run_config is None: - BLOCK_DIM = triton.next_power_of_2(min(dim_size, 256)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_DIM = run_config["BLOCK_DIM"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_dim = triton.cdiv(dim_size, BLOCK_DIM) - - grid = (batch_size, mtp_size, num_blocks_dim) - - _reorganize_mtp_data_kernel[grid]( - src, - dst, - batch_size, - mtp_size, - dim_size, - src.stride(0), - src.stride(-1) if src.ndim > 1 else 1, - dst.stride(0), - dst.stride(-1) if dst.ndim > 1 else 1, - BLOCK_DIM=BLOCK_DIM, - num_warps=num_warps, - num_stages=num_stages, - ) - - -def _get_reorganize_back_configs(): - """Generate candidate configurations for MTP output reorganization.""" - configs = [] - for block_head in [4, 8, 16, 32]: - for block_dim in [32, 64, 128]: - for num_warps in [2, 4, 8]: - for num_stages in [2, 3]: - if block_head * block_dim <= 4096: # Limit shared memory - configs.append( - { - "BLOCK_HEAD": block_head, - "BLOCK_DIM": block_dim, - "num_warps": num_warps, - "num_stages": num_stages, - } - ) - return configs - - -def _get_reorganize_back_static_key( - src: torch.Tensor, - batch_size: int, - mtp_size: int, - num_heads: int, - head_dim: int, -): - """Static key for output reorganization.""" - return { - "dtype": str(src.dtype), - "mtp_size": mtp_size, - "num_heads": num_heads, - "head_dim": head_dim, - } - - -def _get_reorganize_back_run_key( - src: torch.Tensor, - batch_size: int, - mtp_size: int, - num_heads: int, - head_dim: int, -): - """Run key for output reorganization.""" - return batch_size - - -@autotune( - kernel_name="gdn_decode_mtp_reorganize_back:v1", - configs_gen_func=_get_reorganize_back_configs, - static_key_func=_get_reorganize_back_static_key, - run_key_func=_get_reorganize_back_run_key, - mutates_args=["dst"], -) -def reorganize_mtp_output_to_interleaved( - src: torch.Tensor, - dst: torch.Tensor, - batch_size: int, - mtp_size: int, - num_heads: int, - head_dim: int, - run_config: dict = None, -): - """ - Reorganize output from batched layout back to interleaved layout. 
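-
-    Worked example (illustrative): with batch_size=2 and mtp_size=2,
-    src[0, 1] (batch 0, step 1) is written to dst[1 * 2 + 0] = dst[2],
-    per the mapping noted below.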
- - Args: - src: Input tensor [batch_size, mtp_size, num_heads, head_dim] (4D) - dst: Output tensor [batch_size * mtp_size, 1, num_heads, head_dim] (4D) - batch_size: Number of batch items - mtp_size: Number of MTP steps - num_heads: Number of attention heads - head_dim: Head dimension - run_config: Auto-tuned configuration - - Mapping: src[b, s, h, d] -> dst[s * batch_size + b, 0, h, d] - """ - if run_config is None: - BLOCK_HEAD = min(triton.next_power_of_2(num_heads), 16) - BLOCK_DIM = min(triton.next_power_of_2(head_dim), 64) - num_warps = 4 - num_stages = 2 - else: - BLOCK_HEAD = run_config["BLOCK_HEAD"] - BLOCK_DIM = run_config["BLOCK_DIM"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_head_blocks = triton.cdiv(num_heads, BLOCK_HEAD) - num_dim_blocks = triton.cdiv(head_dim, BLOCK_DIM) - num_blocks_total = num_head_blocks * num_dim_blocks - - grid = (batch_size, mtp_size, num_blocks_total) - - # src is 4D: [batch_size, mtp_size, num_heads, head_dim] - # dst is 4D: [total_tokens, 1, num_heads, head_dim] - _reorganize_mtp_data_back_kernel[grid]( - src, - dst, - batch_size, - mtp_size, - num_heads, - head_dim, - src.stride(0), # batch stride - src.stride(1), # mtp stride - src.stride(2), # head stride - src.stride(3), # dim stride - dst.stride(0), # token stride - dst.stride(1), # seq stride (=1) - dst.stride(2), # head stride - dst.stride(3), # dim stride - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_DIM=BLOCK_DIM, - num_warps=num_warps, - num_stages=num_stages, - ) - - -@triton.jit -def _prepare_mtp_indices_kernel( - # Input indices (per-step buffer indices) - buffer_idx_ptr, - # Output 2D indices for recurrent kernel - output_idx_ptr, - # Dimensions - batch_size, - mtp_size, - # Strides - input_stride, - output_stride_batch, - output_stride_step, -): - """ - Prepare 2D indices for the fused recurrent kernel. - - Input: mtp_size tensors of shape [batch_size] (buffer indices for each step) - Output: 2D tensor [batch_size, mtp_size] for ssm_state_indices - """ - batch_idx = tl.program_id(0) - step_idx = tl.program_id(1) - - # Load the buffer index for this batch and step - buffer_idx = tl.load(buffer_idx_ptr + step_idx * input_stride + batch_idx) - - # Store to the 2D output - output_offset = batch_idx * output_stride_batch + step_idx * output_stride_step - tl.store(output_idx_ptr + output_offset, buffer_idx) - - -def prepare_mtp_state_indices( - mtp_buffer_idx_list: list, - batch_size: int, - device: torch.device, -) -> torch.Tensor: - """ - Prepare 2D state indices for the fused recurrent kernel. 
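-
-    Worked example (hypothetical buffer ids): for two steps with per-step
-    indices [3, 7] and [4, 8], stacking gives [[3, 7], [4, 8]] and the
-    transpose yields [[3, 4], [7, 8]], so row b holds the buffer ids used
-    by batch item b across steps.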
- - Args: - mtp_buffer_idx_list: List of buffer index tensors, one per MTP step - batch_size: Number of batch items - device: Target device - - Returns: - 2D tensor of shape [batch_size, mtp_size] for ssm_state_indices - """ - - # Stack indices to create [mtp_size, batch_size] tensor - stacked_indices = torch.stack(mtp_buffer_idx_list, dim=0) - - # Transpose to get [batch_size, mtp_size] - return stacked_indices.T.contiguous() - - -@triton.jit -def _fused_conv1d_mtp_step_kernel( - # Input/output data - mixed_qkv_ptr, - # Conv state buffer - conv_states_ptr, - # Conv weight and bias - conv_weight_ptr, - conv_bias_ptr, - # Buffer indices (one per MTP step, each [batch_size]) - buffer_indices_ptr, - next_buffer_indices_ptr, - # Dimensions - batch_size, - dim_size, - conv_width, - # Step info - step_idx, - mtp_size, - is_last_step: tl.constexpr, - # Strides - qkv_stride_token, - qkv_stride_dim, - state_stride_buffer, - state_stride_dim, - state_stride_width, - weight_stride_dim, - weight_stride_width, - # Block sizes - BLOCK_DIM: tl.constexpr, - ACTIVATION_SILU: tl.constexpr, -): - """ - Fused kernel for conv1d update in MTP decode. - - Handles one MTP step for all batch items: - 1. Reads current conv state - 2. Updates with new input - 3. Computes conv1d output - 4. Optionally copies state to next MTP step - """ - batch_idx = tl.program_id(0) - block_dim_idx = tl.program_id(1) - - # Calculate token index in interleaved layout - token_idx = step_idx * batch_size + batch_idx - - # Load buffer indices - cur_buffer_idx = tl.load(buffer_indices_ptr + batch_idx).to(tl.int64) - - # Calculate dimension offsets - dim_start = block_dim_idx * BLOCK_DIM - dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) - dim_mask = dim_offsets < dim_size - - # Load input value - input_offset = token_idx * qkv_stride_token + dim_offsets * qkv_stride_dim - input_val = tl.load(mixed_qkv_ptr + input_offset, mask=dim_mask, other=0.0) - - # Load conv bias - bias_val = tl.load(conv_bias_ptr + dim_offsets, mask=dim_mask, other=0.0) - - # Compute conv1d output and update state - output_val = bias_val - state_base = conv_states_ptr + cur_buffer_idx * state_stride_buffer - - # Process each position in the conv window - for w in range(conv_width): - # Load weight for this position - weight_offset = dim_offsets * weight_stride_dim + w * weight_stride_width - weight_val = tl.load(conv_weight_ptr + weight_offset, mask=dim_mask, other=0.0) - - if w < conv_width - 1: - # Load from state buffer - state_offset = dim_offsets * state_stride_dim + w * state_stride_width - state_val = tl.load(state_base + state_offset, mask=dim_mask, other=0.0) - output_val += state_val * weight_val - else: - # Use current input for the last position - output_val += input_val * weight_val - - # Update conv state (shift and insert new value) - for w in range(conv_width - 2, -1, -1): - if w == conv_width - 2: - # Insert new input at the end - state_offset = dim_offsets * state_stride_dim + w * state_stride_width - tl.store(state_base + state_offset, input_val, mask=dim_mask) - else: - # Shift state - src_offset = dim_offsets * state_stride_dim + (w + 1) * state_stride_width - dst_offset = dim_offsets * state_stride_dim + w * state_stride_width - val = tl.load(state_base + src_offset, mask=dim_mask, other=0.0) - tl.store(state_base + dst_offset, val, mask=dim_mask) - - # Apply activation (SiLU) - if ACTIVATION_SILU: - output_val = output_val * tl.sigmoid(output_val) - - # Store output - tl.store(mixed_qkv_ptr + input_offset, output_val, mask=dim_mask) - - # 
Copy state to next step if not last - if not is_last_step: - next_buffer_idx = tl.load(next_buffer_indices_ptr + batch_idx).to(tl.int64) - next_state_base = conv_states_ptr + next_buffer_idx * state_stride_buffer - - for w in range(conv_width - 1): - state_offset = dim_offsets * state_stride_dim + w * state_stride_width - val = tl.load(state_base + state_offset, mask=dim_mask, other=0.0) - tl.store(next_state_base + state_offset, val, mask=dim_mask) - - -def _get_conv1d_mtp_configs(): - """Generate candidate configurations for conv1d MTP kernel.""" - configs = [] - for block_dim in [64, 128, 256, 512]: - for num_warps in [2, 4, 8]: - for num_stages in [1, 2, 3]: - configs.append( - { - "BLOCK_DIM": block_dim, - "num_warps": num_warps, - "num_stages": num_stages, - } - ) - return configs - - -def _get_conv1d_mtp_static_key( - mixed_qkv: torch.Tensor, - conv_states: torch.Tensor, - conv_weight: torch.Tensor, - mtp_size: int, -): - """Static key for conv1d MTP kernel.""" - return { - "dtype": str(mixed_qkv.dtype), - "dim_size": mixed_qkv.shape[-1], - "conv_width": conv_weight.shape[-1], - "mtp_size": mtp_size, - } - - -def _get_conv1d_mtp_run_key( - mixed_qkv: torch.Tensor, - conv_states: torch.Tensor, - conv_weight: torch.Tensor, - mtp_size: int, -): - """Run key for conv1d MTP kernel.""" - total_tokens = mixed_qkv.shape[0] - batch_size = total_tokens // mtp_size - return batch_size - - -@autotune( - kernel_name="gdn_conv1d_mtp:v1", - configs_gen_func=_get_conv1d_mtp_configs, - static_key_func=_get_conv1d_mtp_static_key, - run_key_func=_get_conv1d_mtp_run_key, - mutates_args=["mixed_qkv", "conv_states"], -) -def fused_conv1d_mtp_update( - mixed_qkv: torch.Tensor, - conv_states: torch.Tensor, - conv_weight: torch.Tensor, - conv_bias: torch.Tensor, - mtp_buffer_idx_list: list, - mtp_size: int, - activation_silu: bool = True, - run_config: dict = None, -): - """ - Fused conv1d update for all MTP steps. 
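-
-    Per-step sketch (as implemented by the kernel above): read the conv
-    state at the step's buffer index, compute the causal conv1d output
-    (optionally with SiLU), shift the state window in place, then copy the
-    state to the next step's buffer unless this is the last step.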
- - Args: - mixed_qkv: Input tensor [batch_size * mtp_size, dim] (interleaved) - conv_states: Conv state buffer [num_buffers, dim, conv_width-1] - conv_weight: Conv weights [dim, conv_width] - conv_bias: Conv bias [dim] - mtp_buffer_idx_list: List of buffer index tensors per step - mtp_size: Number of MTP steps - activation_silu: Whether to apply SiLU activation - run_config: Auto-tuned configuration - """ - total_tokens = mixed_qkv.shape[0] - batch_size = total_tokens // mtp_size - dim_size = mixed_qkv.shape[-1] - conv_width = conv_weight.shape[-1] - - if run_config is None: - BLOCK_DIM = triton.next_power_of_2(min(dim_size, 256)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_DIM = run_config["BLOCK_DIM"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_dim = triton.cdiv(dim_size, BLOCK_DIM) - - for step_idx in range(mtp_size): - is_last_step = step_idx == mtp_size - 1 - cur_indices = mtp_buffer_idx_list[step_idx] - next_indices = mtp_buffer_idx_list[step_idx + 1] if not is_last_step else cur_indices - - grid = (batch_size, num_blocks_dim) - - _fused_conv1d_mtp_step_kernel[grid]( - mixed_qkv, - conv_states, - conv_weight, - conv_bias, - cur_indices, - next_indices, - batch_size, - dim_size, - conv_width, - step_idx, - mtp_size, - is_last_step, - mixed_qkv.stride(0), - mixed_qkv.stride(-1) if mixed_qkv.ndim > 1 else 1, - conv_states.stride(0), - conv_states.stride(1), - conv_states.stride(2), - conv_weight.stride(0), - conv_weight.stride(1), - BLOCK_DIM=BLOCK_DIM, - ACTIVATION_SILU=activation_silu, - num_warps=num_warps, - num_stages=num_stages, - ) - - -@triton.jit -def _copy_ssm_state_kernel( - # SSM state buffer - ssm_states_ptr, - # Buffer indices - src_indices_ptr, - dst_indices_ptr, - # Dimensions - batch_size, - num_heads, - key_dim, - value_dim, - # Strides - state_stride_buffer, - state_stride_head, - state_stride_key, - state_stride_value, - # Block sizes - BLOCK_KEY: tl.constexpr, - BLOCK_VALUE: tl.constexpr, -): - """ - Copy SSM states from source indices to destination indices. 
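-
-    Grid: (batch_size, num_heads, num_key_blocks * num_value_blocks);
-    block_idx is decomposed into key/value tile positions below.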
- """ - batch_idx = tl.program_id(0) - head_idx = tl.program_id(1) - block_idx = tl.program_id(2) - - # Calculate block positions - num_value_blocks = tl.cdiv(value_dim, BLOCK_VALUE) - block_key_idx = block_idx // num_value_blocks - block_value_idx = block_idx % num_value_blocks - - key_start = block_key_idx * BLOCK_KEY - value_start = block_value_idx * BLOCK_VALUE - - key_offsets = key_start + tl.arange(0, BLOCK_KEY) - value_offsets = value_start + tl.arange(0, BLOCK_VALUE) - - key_mask = key_offsets < key_dim - value_mask = value_offsets < value_dim - mask = key_mask[:, None] & value_mask[None, :] - - # Load indices - src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) - dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) - - # Calculate offsets - src_base = ssm_states_ptr + src_idx * state_stride_buffer + head_idx * state_stride_head - dst_base = ssm_states_ptr + dst_idx * state_stride_buffer + head_idx * state_stride_head - - offsets = key_offsets[:, None] * state_stride_key + value_offsets[None, :] * state_stride_value - - # Copy data - data = tl.load(src_base + offsets, mask=mask, other=0.0) - tl.store(dst_base + offsets, data, mask=mask) - - -@triton.jit -def _copy_conv_state_kernel( - # Conv state buffer [num_buffers, dim, conv_width-1] - conv_states_ptr, - # Buffer indices - src_indices_ptr, - dst_indices_ptr, - # Dimensions - batch_size, - dim_size, - width_size, - num_width_blocks, # Precomputed to avoid runtime division - # Strides - state_stride_buffer, - state_stride_dim, - state_stride_width, - # Block sizes - BLOCK_DIM: tl.constexpr, - BLOCK_WIDTH: tl.constexpr, -): - """ - Copy conv states from source indices to destination indices. - - Conv state shape: [num_buffers, dim, conv_width-1] - """ - batch_idx = tl.program_id(0) - block_idx = tl.program_id(1) - - # Calculate block positions using precomputed num_width_blocks - block_dim_idx = block_idx // num_width_blocks - block_width_idx = block_idx % num_width_blocks - - dim_start = block_dim_idx * BLOCK_DIM - width_start = block_width_idx * BLOCK_WIDTH - - dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) - width_offsets = width_start + tl.arange(0, BLOCK_WIDTH) - - dim_mask = dim_offsets < dim_size - width_mask = width_offsets < width_size - mask = dim_mask[:, None] & width_mask[None, :] - - # Load indices - src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) - dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) - - # Calculate offsets - src_base = conv_states_ptr + src_idx * state_stride_buffer - dst_base = conv_states_ptr + dst_idx * state_stride_buffer - - offsets = dim_offsets[:, None] * state_stride_dim + width_offsets[None, :] * state_stride_width - - # Copy data - data = tl.load(src_base + offsets, mask=mask, other=0.0) - tl.store(dst_base + offsets, data, mask=mask) - - -def _get_conv_copy_configs(): - """Generate candidate configurations for conv state copy.""" - configs = [] - for block_dim in [64, 128, 256]: - for block_width in [2, 4, 8]: - for num_warps in [2, 4]: - configs.append( - { - "BLOCK_DIM": block_dim, - "BLOCK_WIDTH": block_width, - "num_warps": num_warps, - "num_stages": 2, - } - ) - return configs - - -def _get_conv_copy_static_key( - conv_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Static key for conv copy.""" - return { - "dtype": str(conv_states.dtype), - "dim_size": conv_states.shape[1], - "width_size": conv_states.shape[2], - } - - -def _get_conv_copy_run_key( - conv_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Run key for conv 
copy.""" - return src_indices.shape[0] - - -@autotune( - kernel_name="gdn_conv_state_copy:v1", - configs_gen_func=_get_conv_copy_configs, - static_key_func=_get_conv_copy_static_key, - run_key_func=_get_conv_copy_run_key, - mutates_args=["conv_states"], -) -def copy_conv_states( - conv_states: torch.Tensor, - src_indices: torch.Tensor, - dst_indices: torch.Tensor, - run_config: dict = None, -): - """ - Copy conv states from source indices to destination indices. - - Args: - conv_states: Conv state buffer [num_buffers, dim, conv_width-1] - src_indices: Source buffer indices [batch_size] - dst_indices: Destination buffer indices [batch_size] - run_config: Auto-tuned configuration - """ - batch_size = src_indices.shape[0] - dim_size = conv_states.shape[1] - width_size = conv_states.shape[2] - - if run_config is None: - BLOCK_DIM = triton.next_power_of_2(min(dim_size, 128)) - BLOCK_WIDTH = triton.next_power_of_2(min(width_size, 4)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_DIM = run_config["BLOCK_DIM"] - BLOCK_WIDTH = run_config["BLOCK_WIDTH"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_dim_blocks = triton.cdiv(dim_size, BLOCK_DIM) - num_width_blocks = triton.cdiv(width_size, BLOCK_WIDTH) - num_blocks_total = num_dim_blocks * num_width_blocks - - grid = (batch_size, num_blocks_total) - - _copy_conv_state_kernel[grid]( - conv_states, - src_indices, - dst_indices, - batch_size, - dim_size, - width_size, - num_width_blocks, # Pass precomputed value - conv_states.stride(0), - conv_states.stride(1), - conv_states.stride(2), - BLOCK_DIM=BLOCK_DIM, - BLOCK_WIDTH=BLOCK_WIDTH, - num_warps=num_warps, - num_stages=num_stages, - ) - - -def _get_ssm_copy_configs(): - """Generate candidate configurations for SSM state copy.""" - configs = [] - for block_key in [16, 32, 64]: - for block_value in [16, 32, 64, 128]: - for num_warps in [2, 4, 8]: - if block_key * block_value <= 4096: - configs.append( - { - "BLOCK_KEY": block_key, - "BLOCK_VALUE": block_value, - "num_warps": num_warps, - "num_stages": 2, - } - ) - return configs - - -def _get_ssm_copy_static_key( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Static key for SSM copy.""" - return { - "dtype": str(ssm_states.dtype), - "num_heads": ssm_states.shape[1], - "key_dim": ssm_states.shape[2], - "value_dim": ssm_states.shape[3], - } - - -def _get_ssm_copy_run_key( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Run key for SSM copy.""" - return src_indices.shape[0] - - -@autotune( - kernel_name="gdn_ssm_state_copy:v1", - configs_gen_func=_get_ssm_copy_configs, - static_key_func=_get_ssm_copy_static_key, - run_key_func=_get_ssm_copy_run_key, - mutates_args=["ssm_states"], -) -def copy_ssm_states( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, - dst_indices: torch.Tensor, - run_config: dict = None, -): - """ - Copy SSM states from source indices to destination indices. 
- - Args: - ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] - src_indices: Source buffer indices [batch_size] - dst_indices: Destination buffer indices [batch_size] - run_config: Auto-tuned configuration - """ - batch_size = src_indices.shape[0] - num_heads = ssm_states.shape[1] - key_dim = ssm_states.shape[2] - value_dim = ssm_states.shape[3] - - if run_config is None: - BLOCK_KEY = triton.next_power_of_2(min(key_dim, 32)) - BLOCK_VALUE = triton.next_power_of_2(min(value_dim, 64)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_KEY = run_config["BLOCK_KEY"] - BLOCK_VALUE = run_config["BLOCK_VALUE"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_key_blocks = triton.cdiv(key_dim, BLOCK_KEY) - num_value_blocks = triton.cdiv(value_dim, BLOCK_VALUE) - num_blocks_total = num_key_blocks * num_value_blocks - - grid = (batch_size, num_heads, num_blocks_total) - - _copy_ssm_state_kernel[grid]( - ssm_states, - src_indices, - dst_indices, - batch_size, - num_heads, - key_dim, - value_dim, - ssm_states.stride(0), - ssm_states.stride(1), - ssm_states.stride(2), - ssm_states.stride(3), - BLOCK_KEY=BLOCK_KEY, - BLOCK_VALUE=BLOCK_VALUE, - num_warps=num_warps, - num_stages=num_stages, - ) - - -# ============================================================================= -# Optimized Flat Copy Kernels (for contiguous memory) -# ============================================================================= -# These kernels leverage the fact that both conv_states and ssm_states are -# contiguous in memory, allowing us to flatten the inner dimensions and use -# efficient 1D vectorized copy patterns. - - -@triton.jit -def _copy_state_flat_kernel( - # State buffer pointer (flattened view) - state_ptr, - # Buffer indices - src_indices_ptr, - dst_indices_ptr, - # Dimensions - batch_size, - flat_size, # Total elements per buffer entry (flattened inner dims) - # Strides - stride_buffer, # Stride to next buffer entry (in elements) - # Block size - BLOCK_SIZE: tl.constexpr, -): - """ - Optimized flat copy kernel for contiguous state buffers. - - Instead of using 2D/3D block patterns with stride calculations, this kernel - treats each buffer entry as a flat 1D array and uses vectorized loads/stores - for efficient memory transfer. 
- - Grid: (batch_size, num_blocks) where num_blocks = ceil(flat_size / BLOCK_SIZE) - """ - batch_idx = tl.program_id(0) - block_idx = tl.program_id(1) - - # Calculate element range for this block - elem_start = block_idx * BLOCK_SIZE - elem_offsets = elem_start + tl.arange(0, BLOCK_SIZE) - elem_mask = elem_offsets < flat_size - - # Load buffer indices for this batch item - src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) - dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) - - # Calculate source and destination base pointers - src_base = state_ptr + src_idx * stride_buffer - dst_base = state_ptr + dst_idx * stride_buffer - - # Vectorized copy - data = tl.load(src_base + elem_offsets, mask=elem_mask, other=0.0) - tl.store(dst_base + elem_offsets, data, mask=elem_mask) - - -@triton.jit -def _copy_states_fused_kernel( - # Conv state buffer (flattened view) - conv_state_ptr, - # SSM state buffer (flattened view) - ssm_state_ptr, - # Buffer indices - src_indices_ptr, - dst_indices_ptr, - # Dimensions - batch_size, - conv_flat_size, # Total elements per conv buffer entry - ssm_flat_size, # Total elements per ssm buffer entry - # Strides (in elements) - conv_stride_buffer, - ssm_stride_buffer, - # Block sizes - CONV_BLOCK_SIZE: tl.constexpr, - SSM_BLOCK_SIZE: tl.constexpr, -): - """ - Fused kernel to copy both conv_states and ssm_states in a single launch. - - This reduces kernel launch overhead by processing both state copies together. - Each thread block handles one batch item and copies both states sequentially. - - Grid: (batch_size, max(conv_blocks, ssm_blocks)) - """ - batch_idx = tl.program_id(0) - block_idx = tl.program_id(1) - - # Load buffer indices (same for both conv and ssm) - src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) - dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) - - # ========== Copy Conv State ========== - conv_num_blocks = tl.cdiv(conv_flat_size, CONV_BLOCK_SIZE) - if block_idx < conv_num_blocks: - conv_elem_start = block_idx * CONV_BLOCK_SIZE - conv_elem_offsets = conv_elem_start + tl.arange(0, CONV_BLOCK_SIZE) - conv_mask = conv_elem_offsets < conv_flat_size - - conv_src_base = conv_state_ptr + src_idx * conv_stride_buffer - conv_dst_base = conv_state_ptr + dst_idx * conv_stride_buffer - - conv_data = tl.load(conv_src_base + conv_elem_offsets, mask=conv_mask, other=0.0) - tl.store(conv_dst_base + conv_elem_offsets, conv_data, mask=conv_mask) - - # ========== Copy SSM State ========== - ssm_num_blocks = tl.cdiv(ssm_flat_size, SSM_BLOCK_SIZE) - if block_idx < ssm_num_blocks: - ssm_elem_start = block_idx * SSM_BLOCK_SIZE - ssm_elem_offsets = ssm_elem_start + tl.arange(0, SSM_BLOCK_SIZE) - ssm_mask = ssm_elem_offsets < ssm_flat_size - - ssm_src_base = ssm_state_ptr + src_idx * ssm_stride_buffer - ssm_dst_base = ssm_state_ptr + dst_idx * ssm_stride_buffer - - ssm_data = tl.load(ssm_src_base + ssm_elem_offsets, mask=ssm_mask, other=0.0) - tl.store(ssm_dst_base + ssm_elem_offsets, ssm_data, mask=ssm_mask) - - -def _get_flat_copy_configs(): - """Generate candidate configurations for flat copy kernel.""" - configs = [] - # Larger block sizes for better memory throughput on contiguous data - for block_size in [256, 512, 1024, 2048]: - for num_warps in [4, 8]: - configs.append( - { - "BLOCK_SIZE": block_size, - "num_warps": num_warps, - "num_stages": 2, - } - ) - return configs - - -def _get_conv_flat_copy_static_key( - conv_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Static key for conv flat copy.""" - return { - 
"dtype": str(conv_states.dtype), - "flat_size": conv_states.shape[1] * conv_states.shape[2], - } - - -def _get_conv_flat_copy_run_key( - conv_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Run key for conv flat copy.""" - return src_indices.shape[0] - - -@autotune( - kernel_name="gdn_conv_state_flat_copy:v1", - configs_gen_func=_get_flat_copy_configs, - static_key_func=_get_conv_flat_copy_static_key, - run_key_func=_get_conv_flat_copy_run_key, - mutates_args=["conv_states"], -) -def copy_conv_states_flat( - conv_states: torch.Tensor, - src_indices: torch.Tensor, - dst_indices: torch.Tensor, - run_config: dict = None, -): - """ - Optimized flat copy for conv states leveraging contiguous memory. - - Args: - conv_states: Conv state buffer [num_buffers, dim, conv_width-1] (MUST be contiguous) - src_indices: Source buffer indices [batch_size] - dst_indices: Destination buffer indices [batch_size] - run_config: Auto-tuned configuration - """ - assert conv_states.is_contiguous(), "conv_states must be contiguous for flat copy" - - batch_size = src_indices.shape[0] - # Flatten inner dimensions - flat_size = conv_states.shape[1] * conv_states.shape[2] - stride_buffer = conv_states.stride(0) - - if run_config is None: - BLOCK_SIZE = 1024 - num_warps = 4 - num_stages = 2 - else: - BLOCK_SIZE = run_config["BLOCK_SIZE"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks = triton.cdiv(flat_size, BLOCK_SIZE) - grid = (batch_size, num_blocks) - - _copy_state_flat_kernel[grid]( - conv_states, - src_indices, - dst_indices, - batch_size, - flat_size, - stride_buffer, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - num_stages=num_stages, - ) - - -def _get_ssm_flat_copy_static_key( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Static key for ssm flat copy.""" - return { - "dtype": str(ssm_states.dtype), - "flat_size": ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3], - } - - -def _get_ssm_flat_copy_run_key( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Run key for ssm flat copy.""" - return src_indices.shape[0] - - -@autotune( - kernel_name="gdn_ssm_state_flat_copy:v1", - configs_gen_func=_get_flat_copy_configs, - static_key_func=_get_ssm_flat_copy_static_key, - run_key_func=_get_ssm_flat_copy_run_key, - mutates_args=["ssm_states"], -) -def copy_ssm_states_flat( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, - dst_indices: torch.Tensor, - run_config: dict = None, -): - """ - Optimized flat copy for SSM states leveraging contiguous memory. 
- - Args: - ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] (MUST be contiguous) - src_indices: Source buffer indices [batch_size] - dst_indices: Destination buffer indices [batch_size] - run_config: Auto-tuned configuration - """ - assert ssm_states.is_contiguous(), "ssm_states must be contiguous for flat copy" - - batch_size = src_indices.shape[0] - # Flatten inner dimensions (num_heads * key_dim * value_dim) - flat_size = ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3] - stride_buffer = ssm_states.stride(0) - - if run_config is None: - BLOCK_SIZE = 1024 - num_warps = 4 - num_stages = 2 - else: - BLOCK_SIZE = run_config["BLOCK_SIZE"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks = triton.cdiv(flat_size, BLOCK_SIZE) - grid = (batch_size, num_blocks) - - _copy_state_flat_kernel[grid]( - ssm_states, - src_indices, - dst_indices, - batch_size, - flat_size, - stride_buffer, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - num_stages=num_stages, - ) - - -def _get_fused_copy_configs(): - """Generate candidate configurations for fused copy kernel.""" - configs = [] - # Use power-of-2 block sizes for both conv and ssm - for conv_block in [256, 512, 1024]: - for ssm_block in [256, 512, 1024]: - for num_warps in [4, 8]: - configs.append( - { - "CONV_BLOCK_SIZE": conv_block, - "SSM_BLOCK_SIZE": ssm_block, - "num_warps": num_warps, - "num_stages": 2, - } - ) - return configs - - -def _get_fused_copy_static_key( - conv_states: torch.Tensor, - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Static key for fused copy.""" - return { - "conv_dtype": str(conv_states.dtype), - "ssm_dtype": str(ssm_states.dtype), - "conv_flat_size": conv_states.shape[1] * conv_states.shape[2], - "ssm_flat_size": ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3], - } - - -def _get_fused_copy_run_key( - conv_states: torch.Tensor, - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Run key for fused copy.""" - return src_indices.shape[0] - - -@autotune( - kernel_name="gdn_states_fused_copy:v1", - configs_gen_func=_get_fused_copy_configs, - static_key_func=_get_fused_copy_static_key, - run_key_func=_get_fused_copy_run_key, - mutates_args=["conv_states", "ssm_states"], -) -def copy_states_fused( - conv_states: torch.Tensor, - ssm_states: torch.Tensor, - src_indices: torch.Tensor, - dst_indices: torch.Tensor, - run_config: dict = None, -): - """ - Fused copy for both conv and SSM states in a single kernel launch. - - This reduces kernel launch overhead by processing both state copies together. 
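-
-    The second grid axis is sized to max(conv_blocks, ssm_blocks); each
-    program copies its conv tile and its ssm tile whenever the block index
-    is in range for the corresponding buffer.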
- - Args: - conv_states: Conv state buffer [num_buffers, dim, conv_width-1] (MUST be contiguous) - ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] (MUST be contiguous) - src_indices: Source buffer indices [batch_size] - dst_indices: Destination buffer indices [batch_size] - run_config: Auto-tuned configuration - """ - assert conv_states.is_contiguous(), "conv_states must be contiguous for fused copy" - assert ssm_states.is_contiguous(), "ssm_states must be contiguous for fused copy" - - batch_size = src_indices.shape[0] - - # Flatten inner dimensions - conv_flat_size = conv_states.shape[1] * conv_states.shape[2] - ssm_flat_size = ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3] - - conv_stride_buffer = conv_states.stride(0) - ssm_stride_buffer = ssm_states.stride(0) - - if run_config is None: - CONV_BLOCK_SIZE = 512 - SSM_BLOCK_SIZE = 512 - num_warps = 4 - num_stages = 2 - else: - CONV_BLOCK_SIZE = run_config["CONV_BLOCK_SIZE"] - SSM_BLOCK_SIZE = run_config["SSM_BLOCK_SIZE"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - # Grid covers both conv and ssm blocks - conv_num_blocks = triton.cdiv(conv_flat_size, CONV_BLOCK_SIZE) - ssm_num_blocks = triton.cdiv(ssm_flat_size, SSM_BLOCK_SIZE) - max_blocks = max(conv_num_blocks, ssm_num_blocks) - grid = (batch_size, max_blocks) - - _copy_states_fused_kernel[grid]( - conv_states, - ssm_states, - src_indices, - dst_indices, - batch_size, - conv_flat_size, - ssm_flat_size, - conv_stride_buffer, - ssm_stride_buffer, - CONV_BLOCK_SIZE=CONV_BLOCK_SIZE, - SSM_BLOCK_SIZE=SSM_BLOCK_SIZE, - num_warps=num_warps, - num_stages=num_stages, - ) diff --git a/lightllm/models/qwen3next_mtp/__init__.py b/lightllm/models/qwen3next_mtp/__init__.py deleted file mode 100644 index 779237817d..0000000000 --- a/lightllm/models/qwen3next_mtp/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel - -__all__ = ["Qwen3NextMTPModel"] diff --git a/lightllm/models/qwen3next_mtp/layer_infer/__init__.py b/lightllm/models/qwen3next_mtp/layer_infer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py deleted file mode 100644 index ef3fe38153..0000000000 --- a/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch - -from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer - - -class Qwen3NextMTPPreLayerInfer(LlamaPreLayerInfer): - """ - Qwen3Next MTP Pre-Layer Inference. - Similar to DeepSeek MTP but with different weight structure. - - MTP forward flow: - 1. Get embedding from input_ids - 2. Get hidden state from main model (passed via infer_state) - 3. Normalize embedding with pre_fc_norm_embedding - 4. Normalize hidden with pre_fc_norm_hidden - 5. Concat normalized embedding and hidden - 6. 
Project through fc to get hidden_dim output - """ - - def __init__(self, network_config): - super().__init__(network_config) - self.eps_ = network_config["rms_norm_eps"] - self.hidden_size = network_config["hidden_size"] - return - - def _mtp_forward( - self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight - ): - tgt_embdings = infer_state.mtp_draft_input_hiddens - assert input_embdings.shape[0] == tgt_embdings.shape[0] - - # Normalize embedding - input_embdings_normed = layer_weight.pre_fc_norm_embedding_weight_(input=input_embdings, eps=self.eps_) - - # Normalize hidden state - tgt_embdings_normed = layer_weight.pre_fc_norm_hidden_weight_(input=tgt_embdings, eps=self.eps_) - - # Concat normalized embedding and hidden - cat_embdings = torch.cat((input_embdings_normed, tgt_embdings_normed), dim=-1) - - # Project to hidden_size - ans_logics = self.alloc_tensor( - (cat_embdings.shape[0], layer_weight.fc_weight_.shape[1]), dtype=cat_embdings.dtype - ) - torch.mm(cat_embdings, layer_weight.fc_weight_, out=ans_logics) - - return ans_logics - - def context_forward( - self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight - ): - input_embdings = super().context_forward(input_ids, infer_state, layer_weight) - return self._mtp_forward(input_embdings, infer_state, layer_weight) - - def token_forward( - self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight - ): - input_embdings = super().token_forward(input_ids, infer_state, layer_weight) - return self._mtp_forward(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py deleted file mode 100644 index 03630c17c1..0000000000 --- a/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,30 +0,0 @@ -from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import Qwen3NextFullAttentionBaseLayerInfer -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class Qwen3NextMTPTransformerLayerInfer(Qwen3NextFullAttentionBaseLayerInfer): - """ - Qwen3Next MTP Transformer Layer Inference. - MTP layers use full attention (not linear attention) with MoE FFN and shared expert. - Inherits shared methods from Qwen3NextFullAttentionBaseLayerInfer. 
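-
-    FFN binding follows the MOE_MODE environment variable ("TP" by default,
-    "EP" for expert parallelism); see _bind_ffn below.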
- """ - - def __init__(self, layer_num, network_config): - super().__init__(layer_num, network_config) - self.tp_k_head_num_ = max(self.tp_k_head_num_, 1) - self.tp_v_head_num_ = max(self.tp_v_head_num_, 1) - return - - def _bind_ffn(self): - """MTP always uses shared expert + MoE""" - from functools import partial - import os - - moe_mode = os.environ.get("MOE_MODE", "TP") - if moe_mode == "EP": - self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_ep, self) - else: - self._ffn = partial(Qwen3NextFullAttentionBaseLayerInfer._ffn_with_shared_expert_tp, self) - return diff --git a/lightllm/models/qwen3next_mtp/layer_weights/__init__.py b/lightllm/models/qwen3next_mtp/layer_weights/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py deleted file mode 100644 index 8a74ef8567..0000000000 --- a/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py +++ /dev/null @@ -1,47 +0,0 @@ -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NoTpGEMMANormWeight - - -class Qwen3NextMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config): - super().__init__(data_type, network_config) - self.wte_weight_ = None - self.lm_head_weight_ = None - - hidden_size = network_config["hidden_size"] - # Use Gemma-style normalization for all MTP norm layers - self.final_norm_weight_ = NoTpGEMMANormWeight( - dim=hidden_size, - weight_name="mtp.norm.weight", - data_type=self.data_type_, - ) - self.pre_fc_norm_embedding_weight_ = NoTpGEMMANormWeight( - dim=hidden_size, - weight_name="mtp.pre_fc_norm_embedding.weight", - data_type=self.data_type_, - ) - self.pre_fc_norm_hidden_weight_ = NoTpGEMMANormWeight( - dim=hidden_size, - weight_name="mtp.pre_fc_norm_hidden.weight", - data_type=self.data_type_, - ) - return - - def load_hf_weights(self, weights): - if "mtp.fc.weight" in weights: - self.fc_weight_ = self._cuda(weights["mtp.fc.weight"]).t() - - # Load weights for norm weight objects - self.final_norm_weight_.load_hf_weights(weights) - self.pre_fc_norm_embedding_weight_.load_hf_weights(weights) - self.pre_fc_norm_hidden_weight_.load_hf_weights(weights) - - return - - def verify_load(self): - # Verify all norm weights loaded correctly - return ( - self.final_norm_weight_.verify_load() - and self.pre_fc_norm_embedding_weight_.verify_load() - and self.pre_fc_norm_hidden_weight_.verify_load() - ) diff --git a/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py deleted file mode 100644 index d52da5647d..0000000000 --- a/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py +++ /dev/null @@ -1,141 +0,0 @@ -import os -import torch -import math -import numpy as np -from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight -from lightllm.utils.envs_utils import enable_env_vars -from lightllm.common.basemodel.layer_weights.meta_weights import ( - ROWMMWeight, - COLMMWeight, - RMSNormWeight, - QKRMSNORMWeight, - KVROWNMMWeight, -) -from functools import partial - - -class Qwen3NextMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): 
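-    """Weights for the Qwen3Next MTP transformer layer: full attention with
-    a gated output projection plus MoE and a shared expert. All weight
-    names are prefixed with "mtp." (see _init_weight_names below)."""
-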
- def __init__(self, layer_num, data_type, network_config, quant_cfg=None): - super().__init__(layer_num, data_type, network_config, quant_cfg) - return - - def _init_weight_names(self): - self._q_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.q_proj.weight" - self._q_norm_name = f"mtp.layers.{self.layer_num_}.self_attn.q_norm.weight" - self._q_bias_name = None - self._k_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.k_proj.weight" - self._k_norm_name = f"mtp.layers.{self.layer_num_}.self_attn.k_norm.weight" - self._k_bias_name = None - self._v_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.v_proj.weight" - self._v_bias_name = None - self._kv_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.kv_proj.weight" - self._kv_bias_name = None - self._o_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.o_proj.weight" - self._o_bias_name = None - self._att_norm_weight_name = f"mtp.layers.{self.layer_num_}.input_layernorm.weight" - self._att_norm_bias_name = None - self._ffn_norm_weight_name = f"mtp.layers.{self.layer_num_}.post_attention_layernorm.weight" - self._ffn_norm_bias_name = None - - def _init_qkv(self): - # Override parent's QKVROWNMMWeight which requires kv_head_num % tp == 0. - # Qwen3-Next has few KV heads; KVROWNMMWeight handles repeating. - in_dim = self.n_embed - q_out_dim = self.q_head_num_ * self.head_dim - self.q_proj = ROWMMWeight( - in_dim=in_dim, - out_dims=[q_out_dim], - weight_names=self._q_weight_name, - data_type=self.data_type_, - bias_names=self._q_bias_name, - quant_method=self.get_quant_method("q_proj"), - ) - self.kv_proj = KVROWNMMWeight( - in_dim=in_dim, - kv_head_num=self.k_head_num_, - head_dim=self.head_dim, - weight_names=[self._k_weight_name, self._v_weight_name], - data_type=self.data_type_, - bias_names=[self._k_bias_name, self._v_bias_name], - quant_method=self.get_quant_method("kv_proj"), - ) - - def _init_weight(self): - self._init_moe() - self._init_shared_expert_weight() - - hidden_size = self.network_config_["hidden_size"] - self.att_norm_weight_ = RMSNormWeight( - dim=hidden_size, - weight_name=self._att_norm_weight_name, - data_type=self.data_type_, - ) - self.ffn_norm_weight_ = RMSNormWeight( - dim=hidden_size, - weight_name=self._ffn_norm_weight_name, - data_type=self.data_type_, - ) - - self._init_qkv() - self._init_o() - self.q_norm_weight_ = QKRMSNORMWeight( - dim=self.head_dim, weight_name=self._q_norm_name, data_type=self.data_type_ - ) - self.k_norm_weight_ = QKRMSNORMWeight( - dim=self.head_dim, weight_name=self._k_norm_name, data_type=self.data_type_ - ) - self._o_gate_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" - q_out_dim = self.q_head_num_ * self.head_dim - self.o_gate_proj = ROWMMWeight( - in_dim=self.n_embed, - out_dims=[q_out_dim], - weight_names=self._o_gate_weight_name, - data_type=self.data_type_, - bias_names=None, - quant_method=self.get_quant_method("o_gate_proj"), - ) - return - - def load_hf_weights(self, weights): - self._split_q_with_gate(weights) - super().load_hf_weights(weights) - - def _init_shared_expert_weight(self): - prefix = f"mtp.layers.{self.layer_num_}.mlp.shared_expert" - hidden_size = self.network_config_["hidden_size"] - shared_inter = self.network_config_["shared_expert_intermediate_size"] - self.shared_expert_gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[shared_inter, shared_inter], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - 
quant_method=self.get_quant_method("shared_expert_gate_up_proj"), - ) - self.shared_expert_down_proj = COLMMWeight( - in_dim=shared_inter, - out_dims=[hidden_size], - weight_names=f"{prefix}.down_proj.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_expert_down_proj"), - ) - self.shared_expert_gate = ROWMMWeight( - in_dim=hidden_size, - out_dims=[1], - weight_names=f"mtp.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", - data_type=self.data_type_, - bias_names=None, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) - - def _split_q_with_gate(self, weights): - if self._q_weight_name in weights: - weight = weights[self._q_weight_name] - num_heads = self.q_head_num_ - weight = weight.view(num_heads * 2, self.head_dim, -1) - _q_proj = weight[0::2].reshape(-1, weight.shape[-1]) - _gate_proj = weight[1::2].reshape(-1, weight.shape[-1]) - weights[self._q_weight_name] = _q_proj - weights[self._o_gate_weight_name] = _gate_proj diff --git a/lightllm/models/qwen3next_mtp/model.py b/lightllm/models/qwen3next_mtp/model.py deleted file mode 100644 index 92e4918bea..0000000000 --- a/lightllm/models/qwen3next_mtp/model.py +++ /dev/null @@ -1,101 +0,0 @@ -from lightllm.models.qwen3next.model import Qwen3NextTpPartModel -from lightllm.models.qwen3next_mtp.layer_infer.pre_layer_infer import Qwen3NextMTPPreLayerInfer -from lightllm.models.qwen3next_mtp.layer_infer.transformer_layer_infer import Qwen3NextMTPTransformerLayerInfer -from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight -from lightllm.models.qwen3next_mtp.layer_weights.transformer_layer_weight import Qwen3NextMTPTransformerLayerWeight -from lightllm.common.basemodel import TpPartBaseModel -from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights -from lightllm.models.registry import ModelRegistry - - -@ModelRegistry("qwen3next_mtp") -class Qwen3NextMTPModel(Qwen3NextTpPartModel): - - pre_and_post_weight_class = Qwen3NextMTPPreAndPostLayerWeight - pre_layer_infer_class = Qwen3NextMTPPreLayerInfer - transformer_weight_class = Qwen3NextMTPTransformerLayerWeight - transformer_layer_infer_class = Qwen3NextMTPTransformerLayerInfer - - def __init__(self, kvargs: dict): - self.mtp_n_layers = 1 - self._pre_init(kvargs) - super().__init__(kvargs) - return - - def _pre_init(self, kvargs: dict): - """Extract main model and memory layer start from kwargs.""" - self.main_model: TpPartBaseModel = kvargs.pop("main_model") - self.mem_layer_start = kvargs.pop("mem_layer_start") - return - - def autotune_layers(self): - return 1 - - def _init_some_value(self): - self.layers_num = self.mtp_n_layers - - def _init_config(self): - super()._init_config() - self.config["n_layers"] = self.mtp_n_layers - self.config["num_hidden_layers"] = self.mtp_n_layers - return - - def _init_custom(self): - """Initialize custom components, sharing cos/sin cache with main model.""" - self._cos_cached = self.main_model._cos_cached - self._sin_cached = self.main_model._sin_cached - return - - def _init_req_manager(self): - """Share request manager with main model.""" - self.req_manager = self.main_model.req_manager - return - - def _init_mem_manager(self): - """Share memory manager with main model.""" - self.mem_manager = self.main_model.mem_manager - return - - def _check_mem_size(self): - """Skip mem size check for MTP models since they share memory with main model.""" - self.max_total_token_num = self.mem_manager.size - return - - def 
_init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) - self.trans_layers_weight = [ - self.transformer_weight_class( - i, - self.data_type, - network_config=self.config, - quant_cfg=self.quant_cfg, - ) - for i in range(self.mtp_n_layers) - ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ - self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ - return - - def _init_infer_layer(self): - self.pre_infer = self.pre_layer_infer_class(network_config=self.config) - self.post_infer = self.post_layer_infer_class(network_config=self.config) - self.layers_infer = [ - self.transformer_layer_infer_class( - i * self.config["full_attention_interval"] - 1, # Ensure full attention layer - network_config=self.config, - ) - for i in range(self.mtp_n_layers) - ] - # Ensure full attention layer - for i, layer in enumerate(self.layers_infer): - layer.layer_num_ = i + self.mem_layer_start - return diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 1f7a31351d..a18156324e 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -41,7 +41,6 @@ from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel from lightllm.models.mistral_mtp.model import MistralMTPModel -from lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token @@ -352,8 +351,6 @@ def init_mtp_draft_model(self, main_kvargs: dict): elif model_type == "mistral": assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) - elif model_type == "qwen3_next": - self.draft_models.append(Qwen3NextMTPModel(mtp_model_kvargs)) elif mtp_model_cfg["model_type"] == "glm4_moe_lite": assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) From ec499ce8ac2d1eb82747749bbf10c0edd231c873 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 18 Mar 2026 12:51:32 +0000 Subject: [PATCH 122/180] openai api simplify --- lightllm/server/api_cli.py | 2 -- lightllm/server/api_start.py | 3 +-- lightllm/server/build_prompt.py | 34 +++------------------------------ 3 files changed, 4 insertions(+), 35 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 18eb16d9ac..8ff03f3e29 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -584,8 +584,6 @@ def make_argument_parser() -> argparse.ArgumentParser: "eagle_with_att", "vanilla_no_att", "eagle_no_att", - "qwen3next_vanilla", - "qwen3next_eagle", None, ], default=None, diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 69cadfbb4f..77355f0d06 100644 --- 
a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -162,8 +162,7 @@ def normal_or_p_d_start(args): # mtp params check if args.mtp_mode is not None: - if args.mtp_draft_model_dir is None: - args.mtp_draft_model_dir = [args.model_dir] * args.mtp_step + assert args.mtp_draft_model_dir is not None assert args.mtp_step > 0 else: assert args.mtp_draft_model_dir is None diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index c91e8a2e09..a38008af6f 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -44,28 +44,9 @@ def init_tokenizer(args): async def build_prompt(request, tools) -> str: global tokenizer - import json - # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] - # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility - # Qwen's chat template expects arguments to be a dict (uses |items filter) - # but OpenAI format sends arguments as a JSON string - for msg in messages: - tool_calls = msg.get("tool_calls") - if tool_calls and isinstance(tool_calls, list): - for tool_call in tool_calls: - func = tool_call.get("function") - if func and isinstance(func, dict): - args = func.get("arguments") - if isinstance(args, str) and args: - try: - func["arguments"] = json.loads(args) - except (json.JSONDecodeError, TypeError): - # Keep original string if not valid JSON - pass - kwargs = {"conversation": messages} if request.character_settings: kwargs["character_settings"] = request.character_settings @@ -77,16 +58,7 @@ async def build_prompt(request, tools) -> str: try: input_str = tokenizer.apply_chat_template(**kwargs, tokenize=False, add_generation_prompt=True, tools=tools) - except: - # This except branch will be triggered when the chosen model - # has a different tools input format that is not compatiable - # with openAI's apply_chat_template tool_call format, like Mistral. 
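
# Editor's note: the fallback deleted here papered over templates (e.g. Mistral's)
# that expect each tool nested under a "function" key. A self-contained sketch of
# that normalization, assuming `tools` is a list of plain dicts in OpenAI format:
def _wrap_tools_for_template(tools):
    if tools is None:
        return None
    return [t if "function" in t else {"function": t} for t in tools]


# _wrap_tools_for_template([{"name": "get_weather", "parameters": {}}])
# -> [{"function": {"name": "get_weather", "parameters": {}}}]
# After this patch the server logs and re-raises instead, so template
# incompatibilities surface as explicit errors rather than silently
# re-shaped tool specs.
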
- if tools is not None: - tools = [t if "function" in t else {"function": t} for t in tools] - input_str = tokenizer.apply_chat_template( - **kwargs, - tokenize=True, - add_generation_prompt=True, - tools=tools, - ) + except BaseException as e: + logger.error(f"Failed to build prompt: {e}") + raise e return input_str From 3c8597d6eba1537f91bcfb8da9f949057441c6ab Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 18 Mar 2026 14:28:10 +0000 Subject: [PATCH 123/180] simplify mem manager --- lightllm/common/allocator_utils.py | 98 ---------------- .../kv_cache_mem_manager/mem_manager.py | 108 +++++++++++++++--- .../mamba_cache_mem_manager/cache_manager.py | 84 +++++++++++++- .../layer_infer/transformer_layer_infer.py | 86 ++++++-------- 4 files changed, 205 insertions(+), 171 deletions(-) delete mode 100644 lightllm/common/allocator_utils.py diff --git a/lightllm/common/allocator_utils.py b/lightllm/common/allocator_utils.py deleted file mode 100644 index 803ed0a715..0000000000 --- a/lightllm/common/allocator_utils.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import List, Union - -import torch - -from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class TokenAllocator: - def __init__(self, size, shared_can_use_token_num_name: str): - self.size = size - - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._mem_state_return = torch.arange( - 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._return_start = 0 - self.mark_start = 0 - self.mark_end = self.size - - self.can_use_mem_size = self.size - - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - self.shared_can_use_token_num = SharedInt(shared_can_use_token_num_name) - - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - self.HOLD_TOKEN_MEMINDEX = self.size - - def alloc(self, need_size) -> torch.Tensor: - if need_size > self.mark_end - self.mark_start: - logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") - assert False, "error alloc state" - - start = self.mark_start - end = self.mark_start + need_size - self.mark_start += need_size - - self.can_use_mem_size -= need_size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - # 利用缓冲区返回,避免异步情况下的内存竞争 - if self._return_start + need_size > self._mem_state_return.shape[0]: - self._return_start = 0 - ans = self._mem_state_return[self._return_start : self._return_start + need_size] - ans.copy_(self.mem_state[start:end]) - self._return_start += need_size - return ans - - def free(self, free_index: Union[torch.Tensor, List[int]]): - """_summary_ - - Args: - free_index (torch.Tensor): _description_ - """ - end = self.mark_start - start = self.mark_start - len(free_index) - assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" - - if isinstance(free_index, list): - free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device) - self.mem_state[start:end] = free_index_tensor - else: - # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 - self.mem_state[start:end] = free_index - - self.mark_start -= len(free_index) - - self.can_use_mem_size += len(free_index) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - if self.can_use_mem_size == len(self.mem_state): - logger.debug(f"freed all gpu mem size 
{self.can_use_mem_size}") - return - - def free_all(self): - self.can_use_mem_size = len(self.mem_state) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) - self.mark_start = 0 - self.mark_end = len(self.mem_state) - - def resize_mem(self, new_size): - """ - just for test code - """ - self.size = new_size - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self.mark_start = 0 - self.mark_end = self.size - self.can_use_mem_size = self.size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - return diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 8d6fb48c28..1203cbdec7 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -18,17 +18,14 @@ from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.shm_utils import create_or_link_shm -from lightllm.common.allocator_utils import TokenAllocator from multiprocessing.reduction import ForkingPickler from filelock import FileLock logger = init_logger(__name__) -KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_kv_cache_token_can_use_num" - -class MemoryManager(TokenAllocator): +class MemoryManager: def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): self.size = size self.head_num = head_num @@ -39,8 +36,27 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False # profile the max total token num if the size is None self.profile_size(mem_fraction) - super().__init__(self.size, f"{KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._mem_state_return = torch.arange( + 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._return_start = 0 + self.mark_start = 0 + self.mark_end = self.size + + self.can_use_mem_size = self.size + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + from lightllm.utils.envs_utils import get_unique_server_name + + rank_in_node = get_current_rank_in_node() + self.shared_can_use_token_num = SharedInt( + f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + ) + + self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, @@ -48,6 +64,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False head_dim, layer_num, ) + self.HOLD_TOKEN_MEMINDEX = self.size def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): """ @@ -324,13 +341,59 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to def _free_buffers(self): self.kv_buffer = None - def get_index_kv_buffer(self, index): - return {"kv_buffer": self.kv_buffer[:, index]} + def alloc(self, need_size) -> torch.Tensor: + if need_size > self.mark_end - self.mark_start: + logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") + assert False, "error alloc state" - def load_index_kv_buffer(self, index, load_tensor_dict): - self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) + start = self.mark_start + 
end = self.mark_start + need_size + self.mark_start += need_size + + self.can_use_mem_size -= need_size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + # 利用缓冲区返回,避免异步情况下的内存竞争 + if self._return_start + need_size > self._mem_state_return.shape[0]: + self._return_start = 0 + ans = self._mem_state_return[self._return_start : self._return_start + need_size] + ans.copy_(self.mem_state[start:end]) + self._return_start += need_size + return ans + + def free(self, free_index: Union[torch.Tensor, List[int]]): + """_summary_ + + Args: + free_index (torch.Tensor): _description_ + """ + + end = self.mark_start + start = self.mark_start - len(free_index) + assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" + + if isinstance(free_index, list): + self.mem_state.numpy()[start:end] = free_index + else: + # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 + self.mem_state[start:end] = free_index + + self.mark_start -= len(free_index) + + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + if self.can_use_mem_size == len(self.mem_state): + logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") + return + + def free_all(self): + self.can_use_mem_size = len(self.mem_state) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) + self.mark_start = 0 + self.mark_end = len(self.mem_state) - # 重写resize_mem方法,添加_free_buffers和_init_buffers调用 def resize_mem(self, new_size): """ just for test code @@ -341,13 +404,24 @@ def resize_mem(self, new_size): head_dim = self.head_dim layer_num = self.layer_num - # 调用父类的resize_mem - super().resize_mem(new_size) - + self.size = new_size + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_start = 0 + self.mark_end = self.size + self.can_use_mem_size = self.size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._free_buffers() self._init_buffers(size, dtype, head_num, head_dim, layer_num) return + def get_index_kv_buffer(self, index): + return {"kv_buffer": self.kv_buffer[:, index]} + + def load_index_kv_buffer(self, index, load_tensor_dict): + self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) + def copy_kv_from_other_dp_ranks( self, mem_managers: List["MemoryManager"], @@ -439,12 +513,12 @@ def __init__(self) -> None: self.dp_world_size = self.global_world_size // args.dp # 兼容多机 dp size=1 纯 tp 模式的情况 self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 - self.shared_tp_can_use_token_nums = [ - SharedInt(f"{KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME}_{rank_in_node}") + self.shared_tp_infos = [ + SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}") for rank_in_node in range(0, self.node_world_size, self.dp_world_size) ] def get_unrefed_token_num(self, dp_rank_in_node: int): if self.is_multinode_tp: - return self.shared_tp_can_use_token_nums[0].get_value() - return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value() + return self.shared_tp_infos[0].get_value() + return self.shared_tp_infos[dp_rank_in_node].get_value() diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py index fe5ac093e0..9d2d372e17 100644 --- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -5,7 +5,6 @@ from lightllm.utils.dist_utils import 
get_current_rank_in_node from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args -from lightllm.common.allocator_utils import TokenAllocator from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import copy_mamba_buffer, fork_mamba_buffer from lightllm.utils.log_utils import init_logger from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt @@ -28,7 +27,7 @@ def get_cell_size(self): return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype) -class MambaCacheManager(TokenAllocator): +class MambaCacheManager: def __init__( self, size: int, @@ -38,7 +37,23 @@ def __init__( ssm_state_dtype: torch.dtype, ssm_state_shape: Tuple[int, ...], ): - super().__init__(size, f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") + # init the mem state + self.size = size + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._mem_state_return = torch.arange( + 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._return_start = 0 + self.mark_start = 0 + self.mark_end = self.size + self.can_use_mem_size = self.size + self.shared_can_use_token_num = SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.HOLD_TOKEN_MEMINDEX = self.size + + # init the layer cache self.conv_state_cache = LayerCache(size, conv_state_dtype, conv_state_shape, layer_num) self.ssm_state_cache = LayerCache(size, ssm_state_dtype, ssm_state_shape, layer_num) self.HOLD_BUFFER_INDEX = size @@ -83,6 +98,26 @@ def fork_ssm_buffers(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: t self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes ) + def alloc(self, need_size) -> torch.Tensor: + if need_size > self.mark_end - self.mark_start: + logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") + assert False, "error alloc state" + + start = self.mark_start + end = self.mark_start + need_size + self.mark_start += need_size + + self.can_use_mem_size -= need_size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + # 利用缓冲区返回,避免异步情况下的内存竞争 + if self._return_start + need_size > self._mem_state_return.shape[0]: + self._return_start = 0 + ans = self._mem_state_return[self._return_start : self._return_start + need_size] + ans.copy_(self.mem_state[start:end]) + self._return_start += need_size + return ans + def free(self, free_index: Union[torch.Tensor, List[int]]): """ Free the allocated cache buffers and clear them. @@ -101,14 +136,51 @@ def free(self, free_index: Union[torch.Tensor, List[int]]): self.conv_state_cache.buffer[:, free_index_tensor, ...] = 0 self.ssm_state_cache.buffer[:, free_index_tensor, ...] 
= 0 - # Call parent's free method to update allocator state - super().free(free_index) + # update the mem state + end = self.mark_start + start = self.mark_start - len(free_index) + assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" + + if isinstance(free_index, list): + free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device) + self.mem_state[start:end] = free_index_tensor + else: + # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 + self.mem_state[start:end] = free_index + + self.mark_start -= len(free_index) + + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + if self.can_use_mem_size == len(self.mem_state): + logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") + return def free_all(self): self.conv_state_cache.buffer.fill_(0) self.ssm_state_cache.buffer.fill_(0) - super().free_all() + self.can_use_mem_size = len(self.mem_state) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) + self.mark_start = 0 + self.mark_end = len(self.mem_state) + + return + + def resize_mem(self, new_size): + """ + just for test code + """ + self.size = new_size + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_start = 0 + self.mark_end = self.size + self.can_use_mem_size = self.size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) return diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 6e2f8d7c9c..69d48b30f6 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -47,13 +47,46 @@ def __init__(self, layer_num, network_config): self._init_linear_layer_metadata(layer_num, network_config) return + def _init_linear_layer_metadata(self, layer_num, network_config): + + # Linear attention specific dimensions + self.num_v_heads = network_config["linear_num_value_heads"] + self.num_k_heads = network_config["linear_num_key_heads"] + self.head_k_dim = network_config["linear_key_head_dim"] + self.head_v_dim = network_config["linear_value_head_dim"] + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_kernel_dim = network_config["linear_conv_kernel_dim"] + self.activation = network_config["hidden_act"] + + # Tensor parallelism dimensions + self.tp_qkvz_dim = (self.key_dim * 2 + self.value_dim * 2) // self.tp_world_size_ + self.tp_ba_dim = (self.num_v_heads * 2) // self.tp_world_size_ + self.tp_num_k_heads = self.num_k_heads // self.tp_world_size_ + self.tp_num_v_heads = self.num_v_heads // self.tp_world_size_ + self.tp_key_dim = self.key_dim // self.tp_world_size_ + self.tp_value_dim = self.value_dim // self.tp_world_size_ + + assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads" + self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads + + # SSM state dtype optimization + ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + start_args = get_env_start_args() + self.ssm_state_dtype = ssm_dtype_dict.get(start_args.mamba_ssm_data_type, torch.bfloat16) + + # Pre-compute whether dtype conversion is needed + # GDN kernel output dtype is self.data_type + # Conversion needed only if SSM 
state uses different dtype + self.needs_ssm_dtype_conversion = get_llm_data_type() != self.ssm_state_dtype + return + def _bind_func(self): super()._bind_func() self._bind_ffn() return def _bind_ffn(self): - """Bind FFN implementation based on MoE configuration.""" if self.is_moe: moe_mode = os.environ.get("MOE_MODE", "TP") if moe_mode == "EP": @@ -76,7 +109,6 @@ def _compute_shared_expert( def _moe_ffn( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - """MoE FFN with tensor parallelism.""" shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) @@ -99,7 +131,6 @@ def _moe_ffn( def _moe_ffn_edp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - """MoE FFN with expert parallelism.""" shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input token_num, hidden_dim = hidden_states.shape @@ -124,9 +155,6 @@ def _get_qkv( infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - QKV projection with output gating, Q/K normalization, and partial rotary embedding. - """ input = input.view(-1, self.embed_dim_) qkv_out = layer_weight.qkv_proj.mm(input) q, cache_kv = qkv_out.split( @@ -164,40 +192,6 @@ def _get_o( o_tensor = layer_weight.o_proj.mm(input) return o_tensor - def _init_linear_layer_metadata(self, layer_num, network_config): - - # Linear attention specific dimensions - self.num_v_heads = network_config["linear_num_value_heads"] - self.num_k_heads = network_config["linear_num_key_heads"] - self.head_k_dim = network_config["linear_key_head_dim"] - self.head_v_dim = network_config["linear_value_head_dim"] - self.key_dim = self.head_k_dim * self.num_k_heads - self.value_dim = self.head_v_dim * self.num_v_heads - self.conv_kernel_dim = network_config["linear_conv_kernel_dim"] - self.activation = network_config["hidden_act"] - - # Tensor parallelism dimensions - self.tp_qkvz_dim = (self.key_dim * 2 + self.value_dim * 2) // self.tp_world_size_ - self.tp_ba_dim = (self.num_v_heads * 2) // self.tp_world_size_ - self.tp_num_k_heads = self.num_k_heads // self.tp_world_size_ - self.tp_num_v_heads = self.num_v_heads // self.tp_world_size_ - self.tp_key_dim = self.key_dim // self.tp_world_size_ - self.tp_value_dim = self.value_dim // self.tp_world_size_ - - assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads" - self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads - - # SSM state dtype optimization - ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} - start_args = get_env_start_args() - self.ssm_state_dtype = ssm_dtype_dict.get(start_args.mamba_ssm_data_type, torch.bfloat16) - - # Pre-compute whether dtype conversion is needed - # GDN kernel output dtype is self.data_type - # Conversion needed only if SSM state uses different dtype - self.needs_ssm_dtype_conversion = get_llm_data_type() != self.ssm_state_dtype - return - # ==================== GDN Helper Methods ==================== def context_attention_forward( @@ -236,15 +230,12 @@ def gdn_forward( ): assert isinstance(infer_state.mem_manager, Qwen3NextHybridMemManager) - # Common preprocessing input = input.view(-1, self.embed_dim_) conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_) mixed_qkvzba = layer_weight.linear_in_proj.mm(input) - # mixed_qkv is now returned 
pre-concatenated (no torch.cat needed) mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill) - # Dispatch to appropriate kernel if is_prefill: # Prefill: compute g/beta upfront (chunk kernel doesn't support fused gating) g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) @@ -255,24 +246,20 @@ def gdn_forward( # Decode (non-MTP): fuse gating into recurrent kernel to save 2 kernel launches core_attn_out = self._gdn_decode_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight) - # Common postprocessing - num_tokens = z.shape[0] # batch (decode) or total_tokens (prefill/MTP) + num_tokens = z.shape[0] core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device) gated_rmsnorm_forward( core_attn_out, layer_weight.linear_norm.weight, - None, # RMSNormWeight has no bias + None, self.eps_, z, out=norm_out, ) - # Merge head and value dims in a single view: (num_tokens * HV, V) → (num_tokens, HV * V) core_attn_out = norm_out.view(num_tokens, -1) - output = layer_weight.linear_out_proj.mm(core_attn_out) - # Note: all_reduce is handled by context_forward/token_forward callers return output def _split_qkvzba(self, mixed_qkvzba, is_decode=False): @@ -352,7 +339,6 @@ def _gdn_prefill_kernel( head_first=False, use_qk_l2norm_in_kernel=True, ) - # Use pre-computed dtype conversion flag to avoid runtime check if self.needs_ssm_dtype_conversion: ssm_states[infer_state.b_buffer_idx] = last_recurrent_state.to(self.ssm_state_dtype, copy=False) else: From 20edcc1a176d50e202ce8bfb69cb66a7d04e7052 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 19 Mar 2026 05:56:52 +0000 Subject: [PATCH 124/180] slime code --- lightllm/models/qwen3_5/infer_struct.py | 99 +------------------ lightllm/models/qwen3next/infer_struct.py | 50 +--------- .../layer_infer/transformer_layer_infer.py | 4 - lightllm/utils/config_utils.py | 2 + 4 files changed, 7 insertions(+), 148 deletions(-) diff --git a/lightllm/models/qwen3_5/infer_struct.py b/lightllm/models/qwen3_5/infer_struct.py index 9ce407cacf..d837c4d291 100644 --- a/lightllm/models/qwen3_5/infer_struct.py +++ b/lightllm/models/qwen3_5/infer_struct.py @@ -1,11 +1,3 @@ -""" -Qwen3.5 Multimodal Inference State - -This module provides inference state for Qwen3.5 multimodal model that combines: -- Qwen3Next features (output gating, MTP-aware batching, hybrid attention buffer management) -- Qwen3VL multimodal support (mrope position encoding for images/videos) -""" - import torch from typing import List @@ -14,97 +6,12 @@ class Qwen35InferStateInfo(Qwen2VLInferStateInfo): - """ - Inference state for Qwen3.5 multimodal model with: - - gate_value attribute for output gating in full attention layers - - MTP-aware batching for multi-token prediction - - Custom buffer management for hybrid attention (full + linear) - - mrope position encoding support for multimodal inputs - """ - def __init__(self): super().__init__() - # For output gating in full attention layers (from Qwen3Next) self.gate_value = None - # MTP-aware attributes (from Qwen3Next) - self.b_att_seq_len = None - self.att_batch_size = None - self.real_req_idx = None - self.mtp_buffer_idx_list = None - self.b_buffer_idx = None - - def _compute_mrope_delta(self, images: List) -> int: - """Compute the position delta for mrope based on image tokens. 
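
# Editor's note: a toy version of the delta bookkeeping this deleted helper
# performs, assuming each image dict carries a 4-element grid_thwd whose last
# entry is the per-image position delta (as the docstring below states):
def compute_mrope_delta(images):
    return sum(image["grid_thwd"][3] for image in images)


# Two images with deltas -3 and -7 shift that request's decode-phase
# position_ids by -10 in total.
assert compute_mrope_delta([{"grid_thwd": [1, 2, 2, -3]}, {"grid_thwd": [1, 4, 4, -7]}]) == -10
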
- - The position delta is the sum of all image position deltas (grid_thwd[3]) - which accounts for the extra position IDs consumed by multimodal content. - """ - position_delta = 0 - for image in images: - position_delta += image["grid_thwd"][3] - return position_delta def init_some_extra_state(self, model): - """Initialize Qwen3.5-specific state including mrope and MTP support""" - # First, initialize mrope position encoding using parent class - # which now has the corrected delta computation - rope_scaling = model.config.get("rope_scaling", {}) - self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) - - # Call the grandparent's (LlamaInferStateInfo) init_some_extra_state first - # to set up basic state - from lightllm.common.basemodel.infer_struct import InferStateInfo - - InferStateInfo.init_some_extra_state(self, model) - - # Now handle mrope position encoding with corrected delta computation - if self.is_prefill: - self.position_ids = self.get_mrope_position(self.multimodal_params) - else: - # Decode phase: compute correct mrope delta - b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] - for batch_idx, p in enumerate(self.multimodal_params): - b_position_delta[batch_idx] = self._compute_mrope_delta(p.get("images", [])) - - position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) - self.position_ids = position_ids.unsqueeze(0).expand(3, -1) - - self.position_ids = self.position_ids.contiguous() - self.position_cos = model._cos_cached[self.position_ids] - self.position_sin = model._sin_cached[self.position_ids] - - # Now handle MTP-aware batching (from Qwen3Next) - args_mtp_step = get_env_start_args().mtp_step - mtp_size = args_mtp_step + 1 - - if self.is_prefill: - # Prefill: Standard initialization - self.b_att_seq_len = self.b_seq_len - self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() - else: - # Decode: MTP-aware handling - # In MTP mode, each request has (mtp_step + 1) tokens - # att_batch_size is the number of unique requests - self.att_batch_size = self.batch_size // mtp_size - - # Use only the sequence lengths for the last token of each MTP group - if args_mtp_step > 0: - self.b_att_seq_len = self.b_seq_len[args_mtp_step::mtp_size].contiguous() - self.real_req_idx = self.b_req_idx[args_mtp_step::mtp_size] - else: - self.b_att_seq_len = self.b_seq_len - self.real_req_idx = self.b_req_idx - - # Buffer indices for Mamba cache (conv and SSM states) - self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.real_req_idx, :].flatten().contiguous() - - # Create per-step buffer indices for MTP - if args_mtp_step > 0: - buffer_idx_list = [] - for step_id in range(mtp_size): - buffer_idx_list.append(self.b_buffer_idx[step_id::mtp_size].tolist()) - self.mtp_buffer_idx_list = torch.tensor( - buffer_idx_list, dtype=torch.int32, device=self.b_buffer_idx.device - ) - + super().init_some_extra_state(model) + self.b_att_seq_len = self.b_seq_len + self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() return diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py index 2883534a93..cd7c8d908d 100644 --- a/lightllm/models/qwen3next/infer_struct.py +++ b/lightllm/models/qwen3next/infer_struct.py @@ -4,59 +4,13 @@ class Qwen3NextInferStateInfo(LlamaInferStateInfo): - """ - Inference state for Qwen3Next with: - - gate_value attribute for output gating in full attention layers - - MTP-aware batching 
for multi-token prediction - - Custom buffer management for hybrid attention (full + linear) - """ - def __init__(self): super().__init__() - # For output gating in full attention layers self.gate_value = None - # MTP-aware attributes - self.b_att_seq_len = None - self.att_batch_size = None - self.real_req_idx = None - self.mtp_buffer_idx_list = None - self.b_buffer_idx = None def init_some_extra_state(self, model): - """Initialize Qwen3Next-specific state""" super().init_some_extra_state(model) - - args_mtp_step = get_env_start_args().mtp_step - mtp_size = args_mtp_step + 1 - - if self.is_prefill: - # Prefill: Standard initialization - self.b_att_seq_len = self.b_seq_len - self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() - else: - # Decode: MTP-aware handling - # In MTP mode, each request has (mtp_step + 1) tokens - # att_batch_size is the number of unique requests - self.att_batch_size = self.batch_size // mtp_size - - # Use only the sequence lengths for the last token of each MTP group - if args_mtp_step > 0: - self.b_att_seq_len = self.b_seq_len[args_mtp_step::mtp_size].contiguous() - self.real_req_idx = self.b_req_idx[args_mtp_step::mtp_size] - else: - self.b_att_seq_len = self.b_seq_len - self.real_req_idx = self.b_req_idx - - # Buffer indices for Mamba cache (conv and SSM states) - self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.real_req_idx, :].flatten().contiguous() - - # Create per-step buffer indices for MTP - if args_mtp_step > 0: - buffer_idx_list = [] - for step_id in range(mtp_size): - buffer_idx_list.append(self.b_buffer_idx[step_id::mtp_size].tolist()) - self.mtp_buffer_idx_list = torch.tensor( - buffer_idx_list, dtype=torch.int32, device=self.b_buffer_idx.device - ) + self.b_att_seq_len = self.b_seq_len + self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() return diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 69d48b30f6..ec07b38c5a 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -237,13 +237,11 @@ def gdn_forward( mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill) if is_prefill: - # Prefill: compute g/beta upfront (chunk kernel doesn't support fused gating) g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) core_attn_out = self._gdn_prefill_kernel( mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight ) else: - # Decode (non-MTP): fuse gating into recurrent kernel to save 2 kernel launches core_attn_out = self._gdn_decode_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight) num_tokens = z.shape[0] @@ -355,8 +353,6 @@ def _gdn_decode_kernel( infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ): - """Decode kernel for GDN forward pass (single-token, non-MTP mode). 
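
# Editor's note: the prefill/decode split described in this docstring, reduced
# to its dispatch skeleton. fused_gdn_gating's real math is replaced by
# placeholder expressions; only the control flow mirrors gdn_forward above.
import torch


def gdn_dispatch(a: torch.Tensor, b: torch.Tensor, is_prefill: bool):
    if is_prefill:
        # chunk kernel cannot fuse gating, so g/beta are materialized first
        g, beta = torch.sigmoid(a), torch.sigmoid(b)  # placeholders, not the real gating
        return ("chunk_kernel", g, beta)
    # decode path: raw a/b go straight to the recurrent kernel, which fuses
    # the gating internally and saves two kernel launches
    return ("recurrent_kernel", a, b)
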
- Uses fused gating: g/beta computed inline in the recurrent kernel.""" mixed_qkv = causal_conv1d_update( mixed_qkv, conv_states, diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index 7d7397beaf..a4fbc594bc 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -202,6 +202,8 @@ def has_vision_module(model_path: str) -> bool: ): # Qwen3OmniMoeVisionTransformerPretrainedModel return True + elif model_type in ["qwen3_5", "qwen3_5_moe"]: + return True else: raise Exception("unknown vision model type") except: From eed986378e061147370cf7cafd1070c4a82a25f4 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 19 Mar 2026 06:15:57 +0000 Subject: [PATCH 125/180] remove mtp of base_backend --- lightllm/models/qwen3_vl/qwen3_visual.py | 7 ----- .../mode_backend/chunked_prefill/impl.py | 17 ----------- .../mode_backend/dp_backend/impl.py | 30 ------------------- 3 files changed, 54 deletions(-) diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index f636715033..bed8898115 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -381,15 +381,8 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - orig_size = image_data.size pixel_values, image_grid_thw = self.processor.preprocess(image_data) - # Debug logging for image processing - logger.debug( - f"[VISUAL_DEBUG] Image {i}: orig_size={orig_size}, " - f"pixel_values.shape={pixel_values.shape}, grid_thw={image_grid_thw}" - ) - img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 85d1e01b9c..2039a28d32 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -261,23 +261,6 @@ def decode_mtp( gpu_tensor=mtp_accept_len, ) - # Copy accepted buffer states back to buffer[0] for MTP - # Only copy when accept_len > 1 (accept_len == 1 means buffer[0] is already correct) - mask = mtp_accept_len > 1 - if mask.sum() > 0: - actual_req_idxes = model_input.b_req_idx[b_req_mtp_start_loc[mask]] - # Source: the accepted buffer (at index accept_len - 1) - src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ - actual_req_idxes, mtp_accept_len[mask] - 1 - ] - # Destination: buffer[0] for each request - dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] - # P2P copy both conv_states and ssm_states - if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_state_buffers"): - g_infer_context.req_manager.buffer_mem_manager.copy_state_buffers( - src_buffer_indexes, dst_buffer_indexes - ) - verify_event = torch.cuda.Event() verify_event.record() diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 5d0b6c701d..26749e2069 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -453,21 +453,6 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): key="mtp_accept_len", gpu_tensor=mtp_accept_len, ) - - # Copy accepted buffer 
states back to buffer[0] for MTP - # Only copy when accept_len > 1 - mask = mtp_accept_len > 1 - if mask.sum() > 0: - actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]] - src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ - actual_req_idxes, mtp_accept_len[mask] - 1 - ] - dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] - if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_state_buffers"): - g_infer_context.req_manager.buffer_mem_manager.copy_state_buffers( - src_buffer_indexes, dst_buffer_indexes - ) - verify_event = torch.cuda.Event() verify_event.record() @@ -780,21 +765,6 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf gpu_tensor=mtp_accept_len, ) all_next_token_ids.append(next_token_ids) - - # Copy accepted buffer states back to buffer[0] for MTP - # Only copy when accept_len > 1 - mask = mtp_accept_len > 1 - if mask.sum() > 0: - actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]] - src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ - actual_req_idxes, mtp_accept_len[mask] - 1 - ] - dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] - if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_state_buffers"): - g_infer_context.req_manager.buffer_mem_manager.copy_state_buffers( - src_buffer_indexes, dst_buffer_indexes - ) - verify_event = torch.cuda.Event() verify_event.record() From 90df4f1fd4dea45a0b512e532adf222346659be9 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 19 Mar 2026 06:26:05 +0000 Subject: [PATCH 126/180] slime mode_backend --- .../basemodel/triton_kernel/norm/qk_norm.py | 2 +- lightllm/models/qwen3next/model.py | 2 -- .../model_infer/mode_backend/base_backend.py | 23 ++++--------------- .../mode_backend/chunked_prefill/impl.py | 3 --- .../mode_backend/dp_backend/impl.py | 2 ++ 5 files changed, 7 insertions(+), 25 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py index 9031582791..e152a8dd83 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py @@ -78,10 +78,10 @@ def _qk_rms_norm_fused_kernel( WK_ptr, stride_k_row, stride_k_col, + eps, # Dimensions num_heads_q: tl.constexpr, # Q 的头数 (用于判断边界) head_dim: tl.constexpr, - eps: tl.constexpr, BLOCK_SIZE: tl.constexpr, FP32_MULTIPLY: tl.constexpr, ): diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 50461bd770..b00f57f3ec 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -35,8 +35,6 @@ class Qwen3NextTpPartModel(Qwen3MOEModel): # infer state class infer_state_class = Qwen3NextInferStateInfo - use_buffer_manager = True # Indicates model needs per-request buffer management for linear attention states - def get_radix_class(self): return HybridRadixCache diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index a18156324e..08932e4e41 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -172,8 +172,6 @@ def init_model(self, kvargs): self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) - self.use_buffer_manager = getattr(self.model, "use_buffer_manager", False) - 
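
# Editor's note: the init_mtp_draft_model hunk below keeps only four mtp_mode
# values. The rule it implements — one draft model per speculative step for
# vanilla modes, a single shared draft model for eagle modes — as a standalone
# function:
def num_mtp_modules(mtp_mode: str, mtp_step: int) -> int:
    if mtp_mode in ("vanilla_with_att", "vanilla_no_att"):
        return mtp_step
    if mtp_mode in ("eagle_with_att", "eagle_no_att"):
        return 1
    raise ValueError(f"error mtp mode {mtp_mode}")


# num_mtp_modules("vanilla_no_att", 3) -> 3
# num_mtp_modules("eagle_no_att", 3)  -> 1
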
radix_cache_class = self.model.get_radix_class() self.radix_cache = ( radix_cache_class( @@ -290,33 +288,21 @@ def decode(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): raise NotImplementedError() def init_mtp_draft_model(self, main_kvargs: dict): - # Support deepseekv3 and qwen3_next MTP modes self.mtp_step = self.args.mtp_step self.draft_models = [] os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" - if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att", "qwen3next_vanilla"]: + if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: num_mtp_modules = self.args.mtp_step - elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att", "qwen3next_eagle"]: + elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: num_mtp_modules = 1 else: assert False, f"error mtp mode {self.args.mtp_mode}" for i in range(num_mtp_modules): - # Get MTP model config first to calculate mem_layer_start mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) - - # Calculate mem_layer_start: main model layers + previous MTP model layers - # For models with integrated MTP (like qwen3_next), each MTP module has 1 layer - # For models with separate MTP configs, use the config's num_hidden_layers model_type = mtp_model_cfg.get("model_type", "") - if model_type == "qwen3_next": - # Qwen3Next has integrated MTP with 1 layer per module - mtp_layers_per_module = 1 - else: - mtp_layers_per_module = mtp_model_cfg["num_hidden_layers"] - mem_layer_start = self.model.config["num_hidden_layers"] + i * mtp_layers_per_module mtp_model_kvargs = { "weight_dir": self.args.mtp_draft_model_dir[i], "max_total_token_num": self.model.mem_manager.size, @@ -329,7 +315,7 @@ def init_mtp_draft_model(self, main_kvargs: dict): "data_type": main_kvargs.get("data_type", "float16"), "graph_max_batch_size": main_kvargs.get("graph_max_batch_size", 16), "graph_max_len_in_batch": main_kvargs.get("graph_max_len_in_batch", 8196), - "disable_cudagraph": True, # Disable CUDA graphs for MTP draft models + "disable_cudagraph": main_kvargs.get("disable_cudagraph", False), "mem_fraction": main_kvargs["mem_fraction"], "batch_max_tokens": main_kvargs.get("batch_max_tokens", None), "quant_type": main_kvargs.get("quant_type", None), @@ -337,13 +323,12 @@ def init_mtp_draft_model(self, main_kvargs: dict): "run_mode": "normal", "main_model": self.model, "mtp_previous_draft_models": self.draft_models.copy(), - "mem_layer_start": mem_layer_start, - "mtp_index": i, } # Select MTP model class based on model type model_type = mtp_model_cfg.get("model_type", "") if model_type == "deepseek_v3": + assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) elif model_type == "qwen3_moe": assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 2039a28d32..a8a5224ebc 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -24,7 +24,6 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args -from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache from .control_state import ControlState logger = init_logger(__name__) @@ 
-137,7 +136,6 @@ def prefill_normal( extra_post_req_handle_func=self.extra_post_req_handle_func, nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) - # 第四阶段 event_pack.notify_pre_post_handle() return @@ -260,7 +258,6 @@ def decode_mtp( key="mtp_accept_len", gpu_tensor=mtp_accept_len, ) - verify_event = torch.cuda.Event() verify_event.record() diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 26749e2069..bb0e848e76 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -453,6 +453,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): key="mtp_accept_len", gpu_tensor=mtp_accept_len, ) + verify_event = torch.cuda.Event() verify_event.record() @@ -765,6 +766,7 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf gpu_tensor=mtp_accept_len, ) all_next_token_ids.append(next_token_ids) + verify_event = torch.cuda.Event() verify_event.record() From 91edf3bccb77354c5ea7d1825264e132f12b7278 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 19 Mar 2026 10:17:24 +0000 Subject: [PATCH 127/180] fix invalid memory of release_memory --- lightllm/common/allocator_utils.py | 98 --------------- .../kv_cache_mem_manager/mem_manager.py | 112 ++++++++++++++---- .../mamba_cache_mem_manager/cache_manager.py | 2 +- lightllm/models/qwen3next/mem_manager.py | 2 - lightllm/server/httpserver/manager.py | 1 - lightllm/server/router/manager.py | 2 - .../model_infer/mode_backend/base_backend.py | 6 +- .../server/router/model_infer/model_rpc.py | 2 - lightllm/server/visualserver/manager.py | 2 - 9 files changed, 97 insertions(+), 130 deletions(-) delete mode 100644 lightllm/common/allocator_utils.py diff --git a/lightllm/common/allocator_utils.py b/lightllm/common/allocator_utils.py deleted file mode 100644 index 803ed0a715..0000000000 --- a/lightllm/common/allocator_utils.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import List, Union - -import torch - -from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class TokenAllocator: - def __init__(self, size, shared_can_use_token_num_name: str): - self.size = size - - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._mem_state_return = torch.arange( - 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._return_start = 0 - self.mark_start = 0 - self.mark_end = self.size - - self.can_use_mem_size = self.size - - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - self.shared_can_use_token_num = SharedInt(shared_can_use_token_num_name) - - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - self.HOLD_TOKEN_MEMINDEX = self.size - - def alloc(self, need_size) -> torch.Tensor: - if need_size > self.mark_end - self.mark_start: - logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") - assert False, "error alloc state" - - start = self.mark_start - end = self.mark_start + need_size - self.mark_start += need_size - - self.can_use_mem_size -= need_size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - # 利用缓冲区返回,避免异步情况下的内存竞争 - if self._return_start + need_size > 
self._mem_state_return.shape[0]: - self._return_start = 0 - ans = self._mem_state_return[self._return_start : self._return_start + need_size] - ans.copy_(self.mem_state[start:end]) - self._return_start += need_size - return ans - - def free(self, free_index: Union[torch.Tensor, List[int]]): - """_summary_ - - Args: - free_index (torch.Tensor): _description_ - """ - end = self.mark_start - start = self.mark_start - len(free_index) - assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" - - if isinstance(free_index, list): - free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device) - self.mem_state[start:end] = free_index_tensor - else: - # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 - self.mem_state[start:end] = free_index - - self.mark_start -= len(free_index) - - self.can_use_mem_size += len(free_index) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - if self.can_use_mem_size == len(self.mem_state): - logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") - return - - def free_all(self): - self.can_use_mem_size = len(self.mem_state) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) - self.mark_start = 0 - self.mark_end = len(self.mem_state) - - def resize_mem(self, new_size): - """ - just for test code - """ - self.size = new_size - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self.mark_start = 0 - self.mark_end = self.size - self.can_use_mem_size = self.size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - return diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index e671eac01d..1203cbdec7 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -18,7 +18,6 @@ from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.shm_utils import create_or_link_shm -from lightllm.common.allocator_utils import TokenAllocator from multiprocessing.reduction import ForkingPickler from filelock import FileLock @@ -26,11 +25,7 @@ logger = init_logger(__name__) -def _get_kvcache_shm_name(): - return f"{get_unique_server_name()}_kv_cache_token_can_use_num" - - -class MemoryManager(TokenAllocator): +class MemoryManager: def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): self.size = size self.head_num = head_num @@ -41,15 +36,35 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False # profile the max total token num if the size is None self.profile_size(mem_fraction) - super().__init__(self.size, f"{_get_kvcache_shm_name()}_{get_current_rank_in_node()}") + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._mem_state_return = torch.arange( + 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._return_start = 0 + self.mark_start = 0 + self.mark_end = self.size + + self.can_use_mem_size = self.size + + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + from lightllm.utils.envs_utils import get_unique_server_name + + rank_in_node = get_current_rank_in_node() + self.shared_can_use_token_num = SharedInt( + 
f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + ) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, head_num, - self.head_dim, + head_dim, layer_num, ) + self.HOLD_TOKEN_MEMINDEX = self.size def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): """ @@ -326,13 +341,59 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to def _free_buffers(self): self.kv_buffer = None - def get_index_kv_buffer(self, index): - return {"kv_buffer": self.kv_buffer[:, index]} + def alloc(self, need_size) -> torch.Tensor: + if need_size > self.mark_end - self.mark_start: + logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") + assert False, "error alloc state" - def load_index_kv_buffer(self, index, load_tensor_dict): - self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) + start = self.mark_start + end = self.mark_start + need_size + self.mark_start += need_size + + self.can_use_mem_size -= need_size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + # 利用缓冲区返回,避免异步情况下的内存竞争 + if self._return_start + need_size > self._mem_state_return.shape[0]: + self._return_start = 0 + ans = self._mem_state_return[self._return_start : self._return_start + need_size] + ans.copy_(self.mem_state[start:end]) + self._return_start += need_size + return ans + + def free(self, free_index: Union[torch.Tensor, List[int]]): + """_summary_ + + Args: + free_index (torch.Tensor): _description_ + """ + + end = self.mark_start + start = self.mark_start - len(free_index) + assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" + + if isinstance(free_index, list): + self.mem_state.numpy()[start:end] = free_index + else: + # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 + self.mem_state[start:end] = free_index + + self.mark_start -= len(free_index) + + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + if self.can_use_mem_size == len(self.mem_state): + logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") + return + + def free_all(self): + self.can_use_mem_size = len(self.mem_state) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) + self.mark_start = 0 + self.mark_end = len(self.mem_state) - # 重写resize_mem方法,添加_free_buffers和_init_buffers调用 def resize_mem(self, new_size): """ just for test code @@ -343,13 +404,24 @@ def resize_mem(self, new_size): head_dim = self.head_dim layer_num = self.layer_num - # 调用父类的resize_mem - super().resize_mem(new_size) - + self.size = new_size + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_start = 0 + self.mark_end = self.size + self.can_use_mem_size = self.size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._free_buffers() self._init_buffers(size, dtype, head_num, head_dim, layer_num) return + def get_index_kv_buffer(self, index): + return {"kv_buffer": self.kv_buffer[:, index]} + + def load_index_kv_buffer(self, index, load_tensor_dict): + self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) + def copy_kv_from_other_dp_ranks( self, mem_managers: List["MemoryManager"], @@ -441,12 +513,12 @@ def __init__(self) -> None: self.dp_world_size = self.global_world_size // args.dp # 兼容多机 dp size=1 纯 tp 模式的情况 self.is_multinode_tp = 
args.dp == 1 and args.nnodes > 1 - self.shared_tp_can_use_token_nums = [ - SharedInt(f"{_get_kvcache_shm_name()}_{rank_in_node}") + self.shared_tp_infos = [ + SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}") for rank_in_node in range(0, self.node_world_size, self.dp_world_size) ] def get_unrefed_token_num(self, dp_rank_in_node: int): if self.is_multinode_tp: - return self.shared_tp_can_use_token_nums[0].get_value() - return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value() + return self.shared_tp_infos[0].get_value() + return self.shared_tp_infos[dp_rank_in_node].get_value() diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py index f5f0727a6b..dbeffc7ad4 100644 --- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -159,7 +159,7 @@ def free(self, free_index: Union[torch.Tensor, List[int]]): return def free_all(self): - self.conv_state_cache.buffer.fill_(0) + self.conv_state_cache.buffer[:] = 0 self.ssm_state_cache.buffer.fill_(0) self.can_use_mem_size = len(self.mem_state) self.shared_can_use_token_num.set_value(self.can_use_mem_size) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py index 5c8d486edf..1e2fe2fa3f 100644 --- a/lightllm/models/qwen3next/mem_manager.py +++ b/lightllm/models/qwen3next/mem_manager.py @@ -136,8 +136,6 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.kv_buffer[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty( (size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda" ) - for _ in range(self.mtp_layer_num): - self.kv_buffer.append(torch.empty((size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda")) def free_all(self): super().free_all() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 3dbdb274be..99a7d86ef4 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -617,7 +617,6 @@ async def transfer_to_next_module( if self.pd_mode.is_P_or_NORMAL(): if not self.args.disable_vision: self.send_to_visual.send_pyobj(req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL) - print(f"send_to_visual: {req_to_next_module}") return if not self.args.disable_audio: diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 094c416a97..f6ac2c54aa 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -552,7 +552,6 @@ async def _recv_new_reqs_and_schedule(self): # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(self.recv_max_count): recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - print(f"router recv req: {recv_req}") if isinstance(recv_req, GenerateReqIndex): self._add_req(recv_req) elif isinstance(recv_req, GeneralHttpToModelRpcReq): @@ -601,7 +600,6 @@ def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): return reqs def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> None: - print(f"router forward to model: {req}") for model_rpc_client in self.model_rpc_clients: ret = model_rpc_client.forward_to_model(req) if not ret.success: diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index ae0231d097..9d13d818c1 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py 
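A note on the allocator that patch 127 folds back into `MemoryManager` above: `alloc()` never hands out a view of the bump-allocated pool itself, it copies the indices into a rotating staging buffer sized at three times the pool (`_mem_state_return`). A caller that is still consuming a previously returned slice, for example through a not-yet-finished async copy, therefore cannot have it overwritten by the next `alloc()`. A minimal self-contained sketch of the pattern (simplified names, no pinned memory and no shared-memory counter, so not the lightllm class itself):

import torch


class RingReturnAllocator:
    """Bump allocator whose results are copied into a rotating staging buffer."""

    def __init__(self, size: int):
        self.size = size
        # the pool holds the token indices; [0, mark_start) is currently allocated
        self.pool = torch.arange(size, dtype=torch.int32)
        # 3x staging area: roughly three pool-sized allocations can be in flight
        # before a previously returned slice gets reused
        self.staging = torch.empty(size * 3, dtype=torch.int32)
        self.ret_start = 0
        self.mark_start = 0

    def alloc(self, need_size: int) -> torch.Tensor:
        assert need_size <= self.size - self.mark_start, "no enough cache"
        start, self.mark_start = self.mark_start, self.mark_start + need_size
        # wrap the staging cursor instead of writing past the end
        if self.ret_start + need_size > self.staging.shape[0]:
            self.ret_start = 0
        ans = self.staging[self.ret_start : self.ret_start + need_size]
        ans.copy_(self.pool[start : start + need_size])
        self.ret_start += need_size
        return ans

    def free(self, free_index: torch.Tensor) -> None:
        # LIFO: freed indices are written back just below the allocation mark
        n = len(free_index)
        assert self.mark_start - n >= 0, "error free state"
        self.pool[self.mark_start - n : self.mark_start] = free_index
        self.mark_start -= n

The real class additionally pins both CPU tensors (`pin_memory=True`) and mirrors `can_use_mem_size` into a `SharedInt`, which is what `ReadOnlyStaticsMemoryManager` reads on the router side for scheduling estimates.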
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -374,10 +374,10 @@ def flush_cache(self, request: FlushCacheReq): def release_memory_occupation(self, tags: List[MemoryTag]): try: - self.model.release_memory_occupation(tags) - self.flush_cache(request=None) self.model.req_manager.free_all() self.model.mem_manager.free_all() + self.model.release_memory_occupation(tags) + self.flush_cache(request=None) return True, "Succeeded to release memory occupation." except Exception as e: self.logger.error(f"release memory occupation failed: {str(e)}") @@ -386,6 +386,8 @@ def release_memory_occupation(self, tags: List[MemoryTag]): def resume_memory_occupation(self, tags: List[MemoryTag]): try: self.model.resume_memory_occupation(tags) + self.model.req_manager.free_all() + self.model.mem_manager.free_all() return True, "Succeeded to resume memory occupation." except Exception as e: self.logger.error(f"resume memory occupation failed: {str(e)}") diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index f92e3c57fb..fb19ef93e9 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -149,7 +149,6 @@ def resume_memory_occupation(self, tags: List[MemoryTag]): def exposed_forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: try: req = obtain(req) - print(f"forward to model backend: {req.func_name}") if self.backend is None or not hasattr(self.backend, req.func_name): raise ValueError(f"Backend does not support function {req.func_name}") success, ret = getattr(self.backend, req.func_name)(req.func_args) @@ -195,7 +194,6 @@ async def get_max_total_token_num(self): return obtain(await ans) def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: - print(f"forward to model client: {req.func_name}") ans = self.conn.root.forward_to_model(req) return obtain(ans) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 188f0e23a2..bedc594de1 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -195,9 +195,7 @@ async def loop_for_netio_req(self): ) self.waiting_reqs.append(recv_req) else: - print(f"visual recv req: {recv_req}") self.send_to_next_module.send_pyobj(recv_req, protocol=pickle.HIGHEST_PROTOCOL) - print(f"visual send req: {recv_req}") self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256)) except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 From 711667aa0f5aa07fcb7dd65f946c850215637efc Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 19 Mar 2026 10:29:23 +0000 Subject: [PATCH 128/180] flush_cache for hybrid cache --- lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py index 08f6ba3fff..179d799fb1 100644 --- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -171,3 +171,7 @@ def evict(self, need_remove_tokens, evict_buffer_callback, evict_callback): self.evict_tree_set.add(parent_node) return + + def flush_cache(self): + super().flush_cache() + self.evict_buffer_set.clear() From b181c0a3e454555c6a44b7cdbb0e01f7bda26a6b Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 19 Mar 2026 10:55:16 +0000 
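On the ordering change in `release_memory_occupation` (patch 127 above): `req_manager.free_all()` and `mem_manager.free_all()` write into allocator state that lives alongside the buffers `model.release_memory_occupation(tags)` is about to unmap, which appears to be the "invalid memory" the commit title refers to, so the bookkeeping reset has to run first, while everything is still mapped. The contract, sketched with the method names from the diff (illustrative ordering only, not a drop-in helper):

def release(model, tags):
    # 1) reset allocator bookkeeping while the backing buffers are still mapped;
    #    running free_all() after the pages are released would touch freed memory
    model.req_manager.free_all()
    model.mem_manager.free_all()
    # 2) only now hand the tagged regions (weights / kv cache / graph pools)
    #    back to the memory saver
    model.release_memory_occupation(tags)


def resume(model, tags):
    # mirror image: remap first, then rebuild the index pool, since the
    # remapped pages come back with undefined contents (patches 131 and 133
    # later in this series add exactly this free_all() plus a synchronize)
    model.resume_memory_occupation(tags)
    model.mem_manager.free_all()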
Subject: [PATCH 129/180] fix rpyc --- lightllm/utils/rpyc_fix_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/utils/rpyc_fix_utils.py b/lightllm/utils/rpyc_fix_utils.py index 1a3fd9affd..101b3938d8 100644 --- a/lightllm/utils/rpyc_fix_utils.py +++ b/lightllm/utils/rpyc_fix_utils.py @@ -103,7 +103,7 @@ def fix_accept(self): while self.active: try: sock, addrinfo = self.listener.accept() - if str(sock.family) != "AddressFamily.AF_UNIX": + if sock.family != socket.AF_UNIX: logger.info("set nodelay mode") sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) From c6a6dda8e599c0f2912c337511be368060c7ef31 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 20 Mar 2026 03:20:39 +0000 Subject: [PATCH 130/180] fix: node is None --- lightllm/server/router/model_infer/infer_batch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 731fbea405..4766ffbb6c 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -187,7 +187,8 @@ def free_a_req_mem_for_mamba(self, free_token_index: List, req: "InferReq") -> b self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None - if node.buffer_idx is None: + # 请求可能在排队时就被终止,导致node可能为None + if node is not None and node.buffer_idx is None: req_to_buffer_index = self.req_manager.req_to_buffer_index buffer_idx = req_to_buffer_index[req.req_idx, 0].item() self.radix_cache.add_buffer_idx_to_node(node, buffer_idx) From ee3a7d582f971d525eea077ebab25194ef3a7852 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 20 Mar 2026 05:55:25 +0000 Subject: [PATCH 131/180] fix resume invalid memory --- lightllm/common/basemodel/basemodel.py | 2 ++ lightllm/server/router/model_infer/mode_backend/base_backend.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 3f084fffd7..a05861f110 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1101,6 +1101,7 @@ def resume_kv_cache(self): torch.cuda.empty_cache() gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + self.mem_manager.free_all() def resume_graph(self): torch.cuda.empty_cache() @@ -1113,3 +1114,4 @@ def resume_all(self): self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) + self.mem_manager.free_all() diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 9d13d818c1..7e0c793c34 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -386,8 +386,6 @@ def release_memory_occupation(self, tags: List[MemoryTag]): def resume_memory_occupation(self, tags: List[MemoryTag]): try: self.model.resume_memory_occupation(tags) - self.model.req_manager.free_all() - self.model.mem_manager.free_all() return True, "Succeeded to resume memory occupation." 
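A note on the one-line rpyc fix in patch 129 above: `str(sock.family)` is representation-dependent. Through CPython 3.10 it rendered as "AddressFamily.AF_UNIX", but since 3.11 `str()` on an `IntEnum` member returns the numeric value (for example "1"), so the old string comparison silently stopped matching and TCP options were applied to unix-domain sockets, which is presumably what this commit ran into. Comparing against the enum member itself is stable across versions. In isolation (assuming a POSIX platform where `AF_UNIX` exists):

import socket


def tune_accepted_socket(sock: socket.socket) -> None:
    # enum-to-enum comparison, independent of how str() formats the member
    if sock.family != socket.AF_UNIX:
        # TCP_NODELAY only makes sense for INET/INET6 stream sockets
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)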
except Exception as e: self.logger.error(f"resume memory occupation failed: {str(e)}") From 32d795d70213110f7a251c65d92d5048b70f7522 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 20 Mar 2026 05:56:49 +0000 Subject: [PATCH 132/180] fix reqs queue --- lightllm/server/router/model_infer/infer_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 1e487b968c..91551b1053 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -203,7 +203,7 @@ def free_a_req_mem_for_mamba(self, free_token_index: List, req: "InferReq") -> b self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None - if node.buffer_idx is None: + if node is not None and node.buffer_idx is None: req_to_buffer_index = self.req_manager.req_to_buffer_index buffer_idx = req_to_buffer_index[req.req_idx, 0].item() self.radix_cache.add_buffer_idx_to_node(node, buffer_idx) From a0937a9f8131cb534f094756c257d6d4340eb974 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 20 Mar 2026 06:08:38 +0000 Subject: [PATCH 133/180] fix --- lightllm/common/basemodel/basemodel.py | 1 + lightllm/common/mamba_cache_mem_manager/cache_manager.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index a05861f110..049e83bfb8 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1101,6 +1101,7 @@ def resume_kv_cache(self): torch.cuda.empty_cache() gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + torch.cuda.synchronize() self.mem_manager.free_all() def resume_graph(self): diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py index dbeffc7ad4..f5f0727a6b 100644 --- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -159,7 +159,7 @@ def free(self, free_index: Union[torch.Tensor, List[int]]): return def free_all(self): - self.conv_state_cache.buffer[:] = 0 + self.conv_state_cache.buffer.fill_(0) self.ssm_state_cache.buffer.fill_(0) self.can_use_mem_size = len(self.mem_state) self.shared_can_use_token_num.set_value(self.can_use_mem_size) From 1de0e53e5c7589cfab2db6846ccf345d07aa94ae Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 20 Mar 2026 07:11:59 +0000 Subject: [PATCH 134/180] fix --- lightllm/common/basemodel/basemodel.py | 3 ++- lightllm/models/qwen3next/mem_manager.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 049e83bfb8..466cffa9de 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1047,6 +1047,7 @@ def _gen_special_model_input(self, token_num: int): return special_model_input def release_memory_occupation(self, tags: Optional[List[MemoryTag]]): + torch.cuda.synchronize() if tags is None: self.release_all() return @@ -1101,8 +1102,8 @@ def resume_kv_cache(self): torch.cuda.empty_cache() gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) - torch.cuda.synchronize() self.mem_manager.free_all() + torch.cuda.synchronize() def resume_graph(self): torch.cuda.empty_cache() diff --git 
a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py index 1e2fe2fa3f..35e5c527ad 100644 --- a/lightllm/models/qwen3next/mem_manager.py +++ b/lightllm/models/qwen3next/mem_manager.py @@ -27,6 +27,7 @@ def calculate_mamba_cache_size( import torch.distributed as dist use_ratio = max_total_token_num is None and start_args.mamba_cache_size is None + print(f"mem_fraction ", mem_fraction, flush=True) world_size = dist.get_world_size() total_memory = get_total_gpu_memory() From 2dbd2f7ba2c4b2065f09023f476a643bb9c898c7 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 20 Mar 2026 09:18:22 +0000 Subject: [PATCH 135/180] pop weight after load --- lightllm/common/basemodel/basemodel.py | 2 ++ .../layer_weights/meta_weights/embedding_weight.py | 3 +++ .../meta_weights/fused_moe/ep_redundancy.py | 4 +++- .../meta_weights/fused_moe/fused_moe_weight.py | 7 +++++++ .../layer_weights/meta_weights/mm_weight/mm_weight.py | 2 ++ .../layer_weights/meta_weights/norm_weight.py | 10 ++++++++++ .../layer_weights/meta_weights/parameter_weight.py | 4 ++++ 7 files changed, 31 insertions(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 466cffa9de..0c70f0a2c4 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -196,6 +196,8 @@ def load_weights(self, weight_dict: dict): transformer_layer_list=self.trans_layers_weight, weight_dict=weight_dict, ) + if weight_dict is not None: + print(f"weight_dict keys: {weight_dict.keys()}") def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index d94a4c709b..9fef2b9084 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -36,6 +36,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) self.weight.load_ok = True + del weights[self.weight_name] def verify_load(self): return self.weight.load_ok @@ -115,6 +116,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) self.weight.load_ok = True + del weights[self.weight_name] def verify_load(self): return self.weight.load_ok @@ -173,6 +175,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): ), f"max_position_embeddings: {loaded_max_position_embeddings} != expected: {self.max_position_embeddings}" self.weight.copy_(t_weight.to(self.data_type_)) self.weight.load_ok = True + del weights[self.weight_name] def verify_load(self): return self.weight.load_ok diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py index 749400c8d8..98aa4b71bb 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py @@ -69,11 +69,13 @@ 
def load_hf_weights(self, weights): w3_weight = f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w3_weight_name}.weight" if w1_weight in weights: self.experts_gate_projs[i] = weights[w1_weight] + del weights[w1_weight] if w3_weight in weights: self.experts_up_projs[i] = weights[w3_weight] + del weights[w3_weight] if w2_weight in weights: self.w2_list[i] = weights[w2_weight] - + del weights[w2_weight] self._load_weight_scale(weights) self._fuse() diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 6a9d0f05f9..7bb31a6454 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -268,6 +268,7 @@ def load_hf_weights(self, weights): # Load bias if self.e_score_correction_bias_name in weights: self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) + del weights[self.e_score_correction_bias_name] self._load_weight(self.expert_idx_to_local_idx, weights) if self.redundancy_expert_num > 0: self._load_weight(self.redundancy_expert_idx_to_local_idx, weights) @@ -345,10 +346,13 @@ def _load_expert( col_slice_func = self.col_slicer._slice_weight if w1_weight in weights: self.quant_method.load_weight(row_slice_func(weights[w1_weight]), self.w1_list[local_expert_idx]) + del weights[w1_weight] if w3_weight in weights: self.quant_method.load_weight(row_slice_func(weights[w3_weight]), self.w3_list[local_expert_idx]) + del weights[w3_weight] if w2_weight in weights: self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx]) + del weights[w2_weight] def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}" @@ -358,10 +362,13 @@ def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): col_slice_func = self.col_slicer._slice_weight if w1_merge_weight in weights: self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1) + del weights[w1_merge_weight] if w2_merge_weight in weights: self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2) + del weights[w2_merge_weight] if w3_merge_weight in weights: self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3) + del weights[w3_merge_weight] def _load_expert_scale( self, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 5021699143..fd4b395811 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -126,6 +126,7 @@ def _load_weight( slicer = self._get_param_slicer(sub_child_index) weight = slicer._slice_weight(weights[param_name]) self.quant_method.load_weight(weight, self.mm_param_list[sub_child_index]) + del weights[param_name] return def _load_bias( @@ -136,6 +137,7 @@ def _load_bias( bias = slicer._slice_bias(weights[param_name]) self.bias_list[sub_child_index].copy_(bias) self.bias_list[sub_child_index].load_ok = True + del weights[param_name] return def _load_weight_scale( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 89a3d24119..f13171d712 100644 
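The `del weights[...]` lines that patch 135 sprinkles across every weight class implement a single idea: once a checkpoint tensor has been copied into its preallocated device buffer, drop the host-side reference immediately, so the incoming `weight_dict` shrinks as loading proceeds instead of keeping the whole checkpoint alive until the end (this matters most on the update-weights path, where the dict arrives from another process). The pattern in isolation, as a hedged sketch rather than the lightllm API:

import torch
from typing import Dict


def load_and_release(params: Dict[str, torch.Tensor], weights: Dict[str, torch.Tensor]) -> None:
    """Copy checkpoint tensors into preallocated buffers, releasing each
    source tensor as soon as it has been consumed."""
    for name in list(weights.keys()):  # snapshot the keys, the dict is mutated below
        if name in params:
            params[name].copy_(weights[name])  # copy into the live buffer
            del weights[name]  # the host-side reference is dropped right here
    # anything left over was never consumed, which is a cheap load-time check
    assert not weights, f"unconsumed checkpoint keys: {sorted(weights)[:5]}"

One detail each hunk has to get right is that every branch deletes exactly the key it just read: the query branch deletes the q key, the key branch the k key, and so on.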
--- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
@@ -24,6 +24,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
         if self.weight_name in weights:
             self.weight.copy_(weights[self.weight_name])
             self.weight.load_ok = True
+            del weights[self.weight_name]
 
     def verify_load(self):
         return self.weight.load_ok
@@ -77,6 +78,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
             self.weight.copy_(weights[self.weight_name])
             self.weight += 1
             self.weight.load_ok = True
+            del weights[self.weight_name]
 
 
 class LayerNormWeight(BaseWeightTpl, PlatformAwareOp):
@@ -98,9 +100,11 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
         if self.weight_name in weights:
             self.weight.copy_(weights[self.weight_name])
             self.weight.load_ok = True
+            del weights[self.weight_name]
         if self.bias_name in weights:
             self.bias.copy_(weights[self.bias_name])
             self.bias.load_ok = True
+            del weights[self.bias_name]
 
     def verify_load(self):
         return self.weight.load_ok and self.bias.load_ok
@@ -191,6 +195,7 @@ def load_hf_weights(self, weights):
             # the padding part is zero
             self.weight[end - start :].zero_()
             self.weight.load_ok = True
+            del weights[self.weight_name]
 
 
 class NoTpGEMMANormWeight(RMSNormWeight):
@@ -201,6 +206,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
         if self.weight_name in weights:
             self.weight.copy_(weights[self.weight_name])
             self.weight += 1
+            del weights[self.weight_name]
 
 
 class QKRMSNORMWeight(BaseWeightTpl, PlatformAwareOp):
@@ -222,9 +228,11 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
         if self.q_weight_name in weights:
             self.q_weight.copy_(weights[self.q_weight_name])
             self.q_weight.load_ok = True
+            del weights[self.q_weight_name]
         if self.k_weight_name in weights:
             self.k_weight.copy_(weights[self.k_weight_name])
             self.k_weight.load_ok = True
+            del weights[self.k_weight_name]
 
     def verify_load(self):
         return self.q_weight.load_ok and self.k_weight.load_ok
@@ -292,10 +300,12 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
             self.q_weight.copy_(weights[self.q_weight_name])
             self.q_weight += 1
             self.q_weight.load_ok = True
+            del weights[self.q_weight_name]
         if self.k_weight_name in weights:
             self.k_weight.copy_(weights[self.k_weight_name])
             self.k_weight += 1
             self.k_weight.load_ok = True
+            del weights[self.k_weight_name]
 
     def _triton_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple:
         assert q.ndim == 2 and self.q_weight.ndim == 1
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py
index 0afb0ecab2..284cb0d39c 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py
@@ -36,10 +36,12 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
             t_weight = weights[self.weight_name]
             self.weight.copy_(t_weight.to(self.data_type_))
             self.weight.load_ok = True
+            del weights[self.weight_name]
         if self.bias_name is not None and self.bias_name in weights:
             t_bias = weights[self.bias_name]
             self.bias.copy_(t_bias.to(self.data_type_))
             self.bias.load_ok = True
+            del weights[self.bias_name]
 
     def verify_load(self) -> bool:
         if self.weight is not None and not getattr(self.weight, "load_ok", False):
@@ -77,7 +79,9 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
             t_weight = weights[self.weight_name][start:end]
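The `self.weight += 1` in the GEMMA-style classes above encodes a checkpoint convention, not a numerical fix: Gemma-family checkpoints store the RMSNorm scale as `w` with an implicit plus-one, so the effective scale is `(1 + w)`. This series moves between baking the offset in at load time (the `+= 1` here) and applying it at runtime (patch 137, "model.norm.weight: add 1 during runtime"), and the same convention appears as `y = s * rstd * (w + 1.0)` in the fused Triton kernel deleted by patch 146. For reference, the unfused form:

import torch


def gemma_rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Reference Gemma-style RMSNorm: the effective scale is (1 + w)."""
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * (1.0 + w.float())).to(x.dtype)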
self.weight.copy_(t_weight.to(self.data_type_)) self.weight.load_ok = True + del weights[self.weight_name] if self.bias_name is not None and self.bias_name in weights: t_bias = weights[self.bias_name][start:end] self.bias.copy_(t_bias.to(self.data_type_)) self.bias.load_ok = True + del weights[self.bias_name] From 33bbfdaa9405dc8e49bd6877b649ec7e67a1d707 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 20 Mar 2026 10:26:09 +0000 Subject: [PATCH 136/180] async update weight --- lightllm/server/router/manager.py | 18 ++++++++++-------- .../server/router/model_infer/model_rpc.py | 7 ++++--- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index f6ac2c54aa..9d6de30746 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -566,7 +566,7 @@ async def _recv_new_reqs_and_schedule(self): # 当队列已经开始清空的时候,将一次接受的数量下调 self.recv_max_count = 64 - self._process_special_reqs(special_reqs) + await self._process_special_reqs(special_reqs) if self.is_multinode_tp: self._multinode_tp_generate_new_batch() @@ -575,12 +575,12 @@ async def _recv_new_reqs_and_schedule(self): self._generate_new_batch() return - def _process_special_reqs(self, special_reqs: List[BaseReq]): + async def _process_special_reqs(self, special_reqs: List[BaseReq]): if self.is_multinode_tp: special_reqs = self.broadcast_reqs_to_other_nodes(special_reqs) for req in special_reqs: assert isinstance(req, GeneralHttpToModelRpcReq), "special request must be GeneralHttpToModelRpcReq" - self.forward_to_model(req) + await self.forward_to_model(req) def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): req_num = len(reqs) @@ -599,12 +599,14 @@ def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) return reqs - def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> None: + async def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> None: + forward_to_model_tasks = [] for model_rpc_client in self.model_rpc_clients: - ret = model_rpc_client.forward_to_model(req) - if not ret.success: - ret = ret - break + forward_to_model_tasks.append(model_rpc_client.forward_to_model(req)) + all_ret = await asyncio.gather(*forward_to_model_tasks) + succes = all(ret.success for ret in all_ret) + ret = all_ret[0] + ret.success = succes if self.is_multinode_tp: output_list = [None for _ in self.nnodes] if self.node_rank == 0 else None dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group) diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index fb19ef93e9..2060753ae6 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -182,6 +182,7 @@ async def _func(*args, **kwargs): self._init_model = async_wrap(self.conn.root.init_model) self._get_max_total_token_num = async_wrap(self.conn.root.get_max_total_token_num) + self._forward_to_model = async_wrap(self.conn.root.forward_to_model) return async def init_model(self, kvargs): @@ -193,9 +194,9 @@ async def get_max_total_token_num(self): ans = self._get_max_total_token_num() return obtain(await ans) - def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: - ans = self.conn.root.forward_to_model(req) - return obtain(ans) + async def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: + ans = 
self._forward_to_model(req) + return obtain(await ans) def _init_env( From 6017484d7e1caec874db20632fdefe52c9365c5f Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 23 Mar 2026 13:52:37 +0800 Subject: [PATCH 137/180] model.norm.weight: add 1 during runtime --- lightllm/models/qwen3next/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index b00f57f3ec..6e1dfa44ee 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -10,6 +10,7 @@ from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( Qwen3NextTransformerLayerInfer, ) +from lightllm.models.qwen3next.layer_infer.post_layer_infer import Qwen3NextPostLayerInfer from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -26,11 +27,12 @@ class Qwen3NextTpPartModel(Qwen3MOEModel): # weight class - pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight + # pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight transformer_weight_class = Qwen3NextTransformerLayerWeight # infer class transformer_layer_infer_class = Qwen3NextTransformerLayerInfer + post_layer_infer_class = Qwen3NextPostLayerInfer # infer state class infer_state_class = Qwen3NextInferStateInfo From b98f6d72bbcbebc6eb26d6aee8a1d0f929942724 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 23 Mar 2026 06:39:42 +0000 Subject: [PATCH 138/180] fix r3 --- lightllm/models/qwen3_moe/model.py | 12 +++++++++--- lightllm/models/qwen3_moe_mtp/model.py | 1 + lightllm/models/qwen3next/model.py | 6 ------ 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index bd8474f2f5..0571d5a3d6 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -1,8 +1,12 @@ import torch from typing import final from lightllm.models.registry import ModelRegistry -from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer -from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import ( + Qwen3MOETransformerLayerInfer, +) +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import ( + Qwen3MOETransformerLayerWeight, +) from lightllm.models.qwen3.model import Qwen3TpPartModel from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger @@ -26,6 +30,8 @@ def __init__(self, kvargs): def _init_custom(self): super()._init_custom() - # Only initialize DeepEP group for MoE models with num_experts if "num_experts" in self.config and self.config["num_experts"] > 0: dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + if self.args.enable_return_routed_experts: + num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe) + init_routing_capture(self, num_moe_layers) diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index 9f83832a7e..b4be10d0d0 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -26,6 +26,7 @@ def _pre_init(self, kvargs: dict): return def _init_custom(self): + super()._init_custom() 
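Patch 136 above changes the router's `forward_to_model` from a sequential per-worker loop into a concurrent fan-out: every rank's rpyc call is started first, then awaited together with `asyncio.gather`, and the overall result succeeds only if every rank succeeded. A weight update therefore costs roughly the latency of the slowest rank rather than the sum over all ranks. Stripped to its shape (an illustrative stand-in, not the exact lightllm classes):

import asyncio
from dataclasses import dataclass
from typing import List


@dataclass
class RpcRsp:
    success: bool
    msg: str = ""


async def fan_out_to_model(clients: List, req) -> RpcRsp:
    # start every per-rank call before awaiting any of them
    tasks = [client.forward_to_model(req) for client in clients]
    all_ret = await asyncio.gather(*tasks)
    # fold: reuse rank 0's response object, success is the conjunction over ranks
    ret = all_ret[0]
    ret.success = all(r.success for r in all_ret)
    return ret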
self._cos_cached = self.main_model._cos_cached self._sin_cached = self.main_model._sin_cached return diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 6e1dfa44ee..35339d4075 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -59,12 +59,6 @@ def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) - def _init_custom(self): - super()._init_custom() - # Only initialize DeepEP group for MoE models with num_experts - if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) - def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 From c0cebbafa6d6bd3c8f315abcb041132a663e1e95 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 23 Mar 2026 08:09:27 +0000 Subject: [PATCH 139/180] fix qwen35 nrom --- .../layer_weights/pre_and_post_layer_weight.py | 12 ++++++++++++ lightllm/models/qwen3_5/model.py | 6 +++--- lightllm/models/qwen3next/model.py | 4 +--- 3 files changed, 16 insertions(+), 6 deletions(-) create mode 100644 lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py diff --git a/lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..41d6a02106 --- /dev/null +++ b/lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,12 @@ +from lightllm.common.basemodel.layer_weights.meta_weights import GEMMANormWeight +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight + + +class Qwen35PreAndPostLayerWeight(Qwen3VLPreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + self.final_norm_weight_ = GEMMANormWeight( + dim=network_config["hidden_size"], + weight_name="model.norm.weight", + data_type=self.data_type_, + ) diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py index 63503c77ba..bbe662a77b 100644 --- a/lightllm/models/qwen3_5/model.py +++ b/lightllm/models/qwen3_5/model.py @@ -9,8 +9,8 @@ from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import ( Qwen3VLMultimodalPreLayerInfer, ) -from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import ( - Qwen3VLPreAndPostLayerWeight, +from lightllm.models.qwen3_5.layer_weights.pre_and_post_layer_weight import ( + Qwen35PreAndPostLayerWeight, ) from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import ( Qwen35TransformerLayerInfer, @@ -52,7 +52,7 @@ class Qwen3_5TpPartModel(Qwen3NextTpPartModel): """ transformer_weight_class = Qwen35TransformerLayerWeight - pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight + pre_and_post_weight_class = Qwen35PreAndPostLayerWeight pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer transformer_layer_infer_class = Qwen35TransformerLayerInfer diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 35339d4075..ebd7752961 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -10,7 +10,6 @@ from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( Qwen3NextTransformerLayerInfer, ) -from lightllm.models.qwen3next.layer_infer.post_layer_infer import Qwen3NextPostLayerInfer from 
lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -27,12 +26,11 @@ class Qwen3NextTpPartModel(Qwen3MOEModel): # weight class - # pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight + pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight transformer_weight_class = Qwen3NextTransformerLayerWeight # infer class transformer_layer_infer_class = Qwen3NextTransformerLayerInfer - post_layer_infer_class = Qwen3NextPostLayerInfer # infer state class infer_state_class = Qwen3NextInferStateInfo From 5f4fa780a4b14e5511bf22cf1af274812b295a0d Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 23 Mar 2026 08:27:26 +0000 Subject: [PATCH 140/180] Revert "fix qwen35 nrom" This reverts commit c0cebbafa6d6bd3c8f315abcb041132a663e1e95. --- .../layer_weights/pre_and_post_layer_weight.py | 12 ------------ lightllm/models/qwen3_5/model.py | 6 +++--- lightllm/models/qwen3next/model.py | 4 +++- 3 files changed, 6 insertions(+), 16 deletions(-) delete mode 100644 lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py diff --git a/lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py deleted file mode 100644 index 41d6a02106..0000000000 --- a/lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py +++ /dev/null @@ -1,12 +0,0 @@ -from lightllm.common.basemodel.layer_weights.meta_weights import GEMMANormWeight -from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight - - -class Qwen35PreAndPostLayerWeight(Qwen3VLPreAndPostLayerWeight): - def __init__(self, data_type, network_config): - super().__init__(data_type, network_config) - self.final_norm_weight_ = GEMMANormWeight( - dim=network_config["hidden_size"], - weight_name="model.norm.weight", - data_type=self.data_type_, - ) diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py index bbe662a77b..63503c77ba 100644 --- a/lightllm/models/qwen3_5/model.py +++ b/lightllm/models/qwen3_5/model.py @@ -9,8 +9,8 @@ from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import ( Qwen3VLMultimodalPreLayerInfer, ) -from lightllm.models.qwen3_5.layer_weights.pre_and_post_layer_weight import ( - Qwen35PreAndPostLayerWeight, +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import ( + Qwen3VLPreAndPostLayerWeight, ) from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import ( Qwen35TransformerLayerInfer, @@ -52,7 +52,7 @@ class Qwen3_5TpPartModel(Qwen3NextTpPartModel): """ transformer_weight_class = Qwen35TransformerLayerWeight - pre_and_post_weight_class = Qwen35PreAndPostLayerWeight + pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer transformer_layer_infer_class = Qwen35TransformerLayerInfer diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index ebd7752961..35339d4075 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -10,6 +10,7 @@ from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( Qwen3NextTransformerLayerInfer, ) +from lightllm.models.qwen3next.layer_infer.post_layer_infer import Qwen3NextPostLayerInfer from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo from lightllm.utils.log_utils import init_logger from 
lightllm.distributed.communication_op import dist_group_manager @@ -26,11 +27,12 @@ class Qwen3NextTpPartModel(Qwen3MOEModel): # weight class - pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight + # pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight transformer_weight_class = Qwen3NextTransformerLayerWeight # infer class transformer_layer_infer_class = Qwen3NextTransformerLayerInfer + post_layer_infer_class = Qwen3NextPostLayerInfer # infer state class infer_state_class = Qwen3NextInferStateInfo From 64506b3c2e1393f981958a3f13cba49bd6ed104e Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 23 Mar 2026 08:33:44 +0000 Subject: [PATCH 141/180] fix --- .../layer_weights/pre_and_post_layer_weight.py | 11 +++++++++++ lightllm/models/qwen3_5/model.py | 7 +++---- 2 files changed, 14 insertions(+), 4 deletions(-) create mode 100644 lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py diff --git a/lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..5056f97b02 --- /dev/null +++ b/lightllm/models/qwen3_5/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,11 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, GEMMANormWeight +from lightllm.models.qwen3next.layer_weights.pre_and_post_layer_weight import Qwen3NextPreAndPostLayerWeight +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import rename_weight_keys + + +class Qwen35PreAndPostLayerWeight(Qwen3NextPreAndPostLayerWeight): + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py index 63503c77ba..a13238b6e5 100644 --- a/lightllm/models/qwen3_5/model.py +++ b/lightllm/models/qwen3_5/model.py @@ -9,9 +9,8 @@ from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import ( Qwen3VLMultimodalPreLayerInfer, ) -from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import ( - Qwen3VLPreAndPostLayerWeight, -) +from lightllm.models.qwen3_5.layer_weights.pre_and_post_layer_weight import Qwen35PreAndPostLayerWeight + from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import ( Qwen35TransformerLayerInfer, ) @@ -52,7 +51,7 @@ class Qwen3_5TpPartModel(Qwen3NextTpPartModel): """ transformer_weight_class = Qwen35TransformerLayerWeight - pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight + pre_and_post_weight_class = Qwen35PreAndPostLayerWeight pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer transformer_layer_infer_class = Qwen35TransformerLayerInfer From 73b10ca7940a3b574f25641f9ef5f3263b3917b9 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 23 Mar 2026 17:10:16 +0800 Subject: [PATCH 142/180] remove unused log --- lightllm/common/basemodel/basemodel.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 0c70f0a2c4..466cffa9de 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -196,8 +196,6 @@ def load_weights(self, weight_dict: dict): transformer_layer_list=self.trans_layers_weight, weight_dict=weight_dict, ) - if weight_dict is not None: - print(f"weight_dict keys: {weight_dict.keys()}") def _init_mem_manager(self): assert 
self.config["num_attention_heads"] % self.tp_world_size_ == 0 From a93509f24e568fdb9491e0ce936d605b4b17be20 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 23 Mar 2026 19:13:54 +0800 Subject: [PATCH 143/180] fix mamba_len --- lightllm/server/router/model_infer/infer_batch.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 91551b1053..64a4d531a8 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -466,7 +466,6 @@ def __init__( # 在开启radix cache的情况下,用于标记命中情况,用于插入算法 self.mamba_model_match_len = 0 - self.mamba_buffer_insert_len = 0 self.extra_need_to_free_token_index = [] # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache @@ -535,13 +534,6 @@ def _match_radix_cache(self): self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 - if g_infer_context.use_mamba_model: - MAMBA_PREFILL_BLOCK_SIZE = 128 - MAMBA_MIN_INSERT_LEN = 1024 - miss_prefix_len = miss_prefix_len - miss_prefix_len % MAMBA_PREFILL_BLOCK_SIZE - if miss_prefix_len > MAMBA_MIN_INSERT_LEN: - self.mamba_buffer_insert_len = miss_prefix_len - self.shm_req.shm_cur_kv_len = self.cur_kv_len return @@ -596,11 +588,6 @@ def get_chuncked_input_token_ids(self): def get_chuncked_input_token_len(self): chunked_start = self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) - - if self.mamba_buffer_insert_len > 0: - chunked_end = min(self.get_cur_total_len(), chunked_start + self.mamba_buffer_insert_len) - self.mamba_buffer_insert_len = 0 - return chunked_end def set_next_gen_token_id(self, next_token_id: int, logprob: float, output_len: int): From 1115543b990d6643f5da1561e5c1006f631095b8 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 23 Mar 2026 12:28:41 +0000 Subject: [PATCH 144/180] fix --- lightllm/server/router/model_infer/infer_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 4766ffbb6c..f1b9d38a65 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -318,7 +318,7 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo can_alloc_token_num -= prefill_need_token_num revovered_reqs.append(req) - self._alloc_and_copy_req_buffers(revovered_reqs) + self._alloc_and_copy_req_buffers(self.req_manager, self.radix_cache, revovered_reqs) g_infer_state_lock.release() return From d562b7bb4e760c5cd3d5f62c313299a8524d643a Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 23 Mar 2026 12:31:12 +0000 Subject: [PATCH 145/180] fix --- lightllm/server/router/model_infer/infer_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 64a4d531a8..f0e126bb88 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -333,7 +333,7 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo can_alloc_token_num -= prefill_need_token_num revovered_reqs.append(req) - self._alloc_and_copy_req_buffers(revovered_reqs) + self._alloc_and_copy_req_buffers(self.req_manager, 
self.radix_cache, revovered_reqs) g_infer_state_lock.release() return From 8f11f08bbea8c2298cf4c1b597a21a18146596fe Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 23 Mar 2026 12:42:09 +0000 Subject: [PATCH 146/180] fix and remove unused code --- .../layer_infer/transformer_layer_infer.py | 3 +- lightllm/models/qwen3next/mem_manager.py | 4 +- .../triton_kernel/fused_add_gemma_rmsnorm.py | 186 ------------------ .../triton_kernel/fused_qkv_gating.py | 163 --------------- .../server/router/model_infer/infer_batch.py | 2 +- 5 files changed, 3 insertions(+), 355 deletions(-) delete mode 100644 lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index ec07b38c5a..b8b31f44d2 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -200,6 +200,7 @@ def context_attention_forward( infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ): + # full attention layer if not self.is_linear_attention_layer: return super().context_attention_forward(input_embdings, infer_state, layer_weight) @@ -271,8 +272,6 @@ def _split_qkvzba(self, mixed_qkvzba, is_decode=False): return mixed_qkv, z, b, a def _rearrange_mixed_qkv(self, mixed_qkv, decode=False): - if mixed_qkv is None: - return None, None, None if decode: query, key, value = torch.split( mixed_qkv, diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py index 5c8d486edf..c8ecae9461 100644 --- a/lightllm/models/qwen3next/mem_manager.py +++ b/lightllm/models/qwen3next/mem_manager.py @@ -130,14 +130,12 @@ def __init__( def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # KV buffer layout: [None, None, None, kv_cache, None, None, None, kv_cache, ..., # None, kv_cache, mtp_kv_cache, mtp_kv_cache] - # Only full attention layers and MTP layers have KV cache. + # Only full attention layers have KV cache. self.kv_buffer = [None for _ in range(self.layer_num)] for layer_id in range(self.full_attn_layer_num): self.kv_buffer[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty( (size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda" ) - for _ in range(self.mtp_layer_num): - self.kv_buffer.append(torch.empty((size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda")) def free_all(self): super().free_all() diff --git a/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py deleted file mode 100644 index 6413158a66..0000000000 --- a/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py +++ /dev/null @@ -1,186 +0,0 @@ -import torch - -import triton -import triton.language as tl - -from lightllm.common.triton_utils.autotuner import autotune - - -@triton.jit -def _fused_add_gemma_rmsnorm_kernel( - x_ptr, - r_ptr, - w_ptr, - y_ptr, - x_stride0, - x_stride1, - r_stride0, - r_stride1, - y_stride0, - y_stride1, - N: tl.constexpr, - EPS: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """Fused in-place residual add + Gemma RMSNorm. - - For each row: - 1. sum = x + residual (written back to x in-place) - 2. rstd = 1 / sqrt(mean(sum²) + eps) - 3. 
-        y = sum * rstd * (w + 1.0)   (Gemma-style)
-    """
-    row = tl.program_id(0)
-    x_ptr = x_ptr + row * x_stride0
-    r_ptr = r_ptr + row * r_stride0
-    y_ptr = y_ptr + row * y_stride0
-
-    # Pass 1: compute sum = x + residual, write back to x, accumulate sum² for variance
-    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
-    for off in range(0, N, BLOCK_SIZE):
-        cols = off + tl.arange(0, BLOCK_SIZE)
-        mask = cols < N
-        x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32)
-        r = tl.load(r_ptr + cols * r_stride1, mask=mask, other=0.0).to(tl.float32)
-        s = x + r
-        # Write sum back to x (in-place residual add)
-        tl.store(x_ptr + cols * x_stride1, s.to(x_ptr.dtype.element_ty), mask=mask)
-        _var += s * s
-
-    var = tl.sum(_var, axis=0) / N
-    rstd = 1.0 / tl.sqrt(var + EPS)
-
-    # Pass 2: normalize and apply Gemma-style linear transformation
-    for off in range(0, N, BLOCK_SIZE):
-        cols = off + tl.arange(0, BLOCK_SIZE)
-        mask = cols < N
-        # Re-read x (now contains sum); hot in L2 from the write in pass 1
-        s = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32)
-        w = tl.load(w_ptr + cols, mask=mask).to(tl.float32)
-        y = s * rstd * (w + 1.0)
-        tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask)
-
-
-def _get_fused_add_gemma_rmsnorm_configs():
-    """Generate configurations for autotuning fused add + Gemma RMSNorm kernel."""
-    configs = []
-    for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]:
-        for num_warps in [1, 2, 4, 8]:
-            configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": 1})
-    return configs
-
-
-def _get_fused_add_gemma_rmsnorm_static_key(x: torch.Tensor, w: torch.Tensor):
-    """Generate static key for caching autotuned configurations."""
-    N = x.shape[-1]
-    return {
-        "x_dtype": str(x.dtype),
-        "weight_dtype": str(w.dtype),
-        "N": N,
-    }
-
-
-@autotune(
-    kernel_name="fused_add_gemma_rmsnorm:v1",
-    configs_gen_func=_get_fused_add_gemma_rmsnorm_configs,
-    static_key_func=_get_fused_add_gemma_rmsnorm_static_key,
-    run_key_func=lambda x: x.shape[-1],
-    mutates_args=["x"],
-)
-def fused_add_gemma_rmsnorm(x, residual, w, eps, out=None, run_config: dict = None):
-    """Fused in-place residual add + Gemma RMSNorm.
-
-    x: [M, N] - modified in-place (x += residual)
-    residual: [M, N] - residual to add (will be viewed as [-1, N])
-    w: [N] - norm weight (Gemma-style: applies w + 1.0)
-    eps: float
-    out: [M, N] - output buffer (allocated if None)
-    Returns: out
-    """
-    N = x.shape[-1]
-    y = torch.empty_like(x) if out is None else out
-    x_arg = x.view(-1, N)
-    r_arg = residual.view(-1, N)
-    y_arg = y.view(-1, N)
-
-    M = x_arg.shape[0]
-
-    # Default heuristic when autotune is disabled or no config provided
-    if not run_config:
-        MAX_FUSED_SIZE = 65536 // x.element_size()
-        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
-        if N > BLOCK_SIZE:
-            raise RuntimeError("This fused_add_gemma_rmsnorm doesn't support feature dim >= 64KB.")
-        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
-        run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1}
-
-    BLOCK_SIZE = run_config["BLOCK_SIZE"]
-    num_warps = run_config["num_warps"]
-    num_stages = run_config["num_stages"]
-
-    _fused_add_gemma_rmsnorm_kernel[(M,)](
-        x_arg,
-        r_arg,
-        w,
-        y_arg,
-        x_stride0=x_arg.stride(0),
-        x_stride1=x_arg.stride(1),
-        r_stride0=r_arg.stride(0),
-        r_stride1=r_arg.stride(1),
-        y_stride0=y_arg.stride(0),
-        y_stride1=y_arg.stride(1),
-        N=N,
-        EPS=eps,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
-
-    return y
-
-
-def _fused_add_gemma_rmsnorm_torch(x, residual, weight, eps):
-    """Reference implementation for correctness testing."""
-    original_dtype = x.dtype
-    x = x.to(torch.float32)
-    residual = residual.to(torch.float32)
-    s = x + residual
-    normed = s * torch.rsqrt(s.pow(2).mean(-1, keepdim=True) + eps)
-    out = normed * (1.0 + weight.float())
-    return s.to(original_dtype), out.to(original_dtype)
-
-
-def test_fused_add_gemma_rmsnorm(M=128, N=2048, dtype=torch.bfloat16, eps=1e-5, device="cuda"):
-    """Verify fused kernel matches separate add + gemma_rmsnorm."""
-    x_shape = (M, N)
-    w_shape = (N,)
-    weight = torch.rand(w_shape, dtype=dtype, device=device)
-    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
-    residual = 0.1 * torch.randn(x_shape, dtype=dtype, device=device)
-
-    # Clone x for reference (since fused modifies x in-place)
-    x_ref = x.clone()
-    x_fused = x.clone()
-
-    # Reference: separate add + norm
-    x_ref_sum, y_ref = _fused_add_gemma_rmsnorm_torch(x_ref, residual, weight, eps)
-
-    # Fused kernel
-    y_fused = fused_add_gemma_rmsnorm(x_fused, residual, weight, eps)
-
-    # Check x was modified in-place (x += residual)
-    print(f"Test: M={M}, N={N}, dtype={dtype}")
-    print(f"  x in-place max delta: {torch.max(torch.abs(x_fused - x_ref_sum)):.6e}")
-    print(f"  output max delta: {torch.max(torch.abs(y_fused - y_ref)):.6e}")
-
-    atol = 1e-2 if dtype == torch.float32 else 5e-2
-    assert torch.allclose(x_fused, x_ref_sum, atol=atol, rtol=0), "x in-place update mismatch!"
-    assert torch.allclose(y_fused, y_ref, atol=atol, rtol=0), "output mismatch!"
-    print("  PASSED")
-
-
-if __name__ == "__main__":
-    test_fused_add_gemma_rmsnorm(M=1, N=2048)
-    test_fused_add_gemma_rmsnorm(M=128, N=2048)
-    test_fused_add_gemma_rmsnorm(M=1, N=2048, dtype=torch.float16)
-    test_fused_add_gemma_rmsnorm(M=64, N=4096, dtype=torch.float32)
-    print("All tests passed!")
diff --git a/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py
deleted file mode 100644
index f37d4911af..0000000000
--- a/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py
+++ /dev/null
@@ -1,163 +0,0 @@
-"""
-Fused QKV projection and GDN gating computation.
-
-This kernel fuses:
-1. Linear projection (matmul with weight)
-2. Output reorganization (split and reshape)
-3. Gating computation (g and beta from a, b)
-
-This reduces kernel launches from 3 to 1 for the QKV+gating path.
-"""
-
-import torch
-import triton
-import triton.language as tl
-from typing import Tuple, Optional
-from lightllm.common.triton_utils.autotuner import autotune
-
-
-@triton.jit
-def _fused_gdn_gating_only_kernel(
-    # Output pointers
-    g_ptr,
-    beta_ptr,
-    # Input pointers
-    a_ptr,
-    b_ptr,
-    A_log_ptr,
-    dt_bias_ptr,
-    # Dimensions
-    batch_size,
-    num_heads,
-    # Constants
-    beta_const: tl.constexpr,
-    threshold: tl.constexpr,
-    BLOCK_BATCH: tl.constexpr,
-    BLOCK_HEADS: tl.constexpr,
-):
-    """
-    Fused kernel for GDN gating computation with better memory access patterns.
-
-    Computes:
-    - g = -exp(A_log) * softplus(a + dt_bias)
-    - beta = sigmoid(b)
-    """
-    pid_batch = tl.program_id(0)
-    pid_head = tl.program_id(1)
-
-    batch_offs = pid_batch * BLOCK_BATCH + tl.arange(0, BLOCK_BATCH)
-    head_offs = pid_head * BLOCK_HEADS + tl.arange(0, BLOCK_HEADS)
-
-    batch_mask = batch_offs < batch_size
-    head_mask = head_offs < num_heads
-    mask = batch_mask[:, None] & head_mask[None, :]
-
-    # Load A_log and dt_bias (broadcast across batch)
-    A_log = tl.load(A_log_ptr + head_offs, mask=head_mask, other=0.0)
-    dt_bias = tl.load(dt_bias_ptr + head_offs, mask=head_mask, other=0.0)
-
-    # Load a and b
-    offs = batch_offs[:, None] * num_heads + head_offs[None, :]
-    a = tl.load(a_ptr + offs, mask=mask, other=0.0)
-    b = tl.load(b_ptr + offs, mask=mask, other=0.0)
-
-    # Compute g = -exp(A_log) * softplus(a + dt_bias)
-    x = a.to(tl.float32) + dt_bias.to(tl.float32)
-    softplus_x = tl.where(beta_const * x <= threshold, (1.0 / beta_const) * tl.log(1.0 + tl.exp(beta_const * x)), x)
-    g = -tl.exp(A_log.to(tl.float32)) * softplus_x
-
-    # Compute beta = sigmoid(b)
-    beta_out = tl.sigmoid(b.to(tl.float32))
-
-    # Store outputs with layout [1, batch, num_heads]
-    out_offs = batch_offs[:, None] * num_heads + head_offs[None, :]
-    tl.store(g_ptr + out_offs, g.to(g_ptr.dtype.element_ty), mask=mask)
-    tl.store(beta_ptr + out_offs, beta_out.to(beta_ptr.dtype.element_ty), mask=mask)
-
-
-def _get_fused_gating_configs():
-    """Generate autotuning configurations."""
-    configs = []
-    for block_batch in [1, 4, 8, 16]:
-        for block_heads in [8, 16, 32]:
-            for num_warps in [2, 4, 8]:
-                configs.append(
-                    {
-                        "BLOCK_BATCH": block_batch,
-                        "BLOCK_HEADS": block_heads,
-                        "num_warps": num_warps,
-                    }
-                )
-    return configs
-
-
-def _get_fused_gating_static_key(a: torch.Tensor):
-    return {"dtype": str(a.dtype), "num_heads": a.shape[1]}
-
-
-def _get_fused_gating_run_key(a: torch.Tensor):
-    return a.shape[0]
-
-
-@autotune(
-    kernel_name="fused_gdn_gating_v2:v1",
-    configs_gen_func=_get_fused_gating_configs,
-    static_key_func=_get_fused_gating_static_key,
-    run_key_func=_get_fused_gating_run_key,
-    mutates_args=["g", "beta"],
-)
-def fused_gdn_gating_v2(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    A_log: torch.Tensor,
-    dt_bias: torch.Tensor,
-    g: torch.Tensor,
-    beta: torch.Tensor,
-    beta_const: float = 1.0,
-    threshold: float = 20.0,
-    run_config: Optional[dict] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Optimized GDN gating with pre-allocated output tensors.
-
-    Args:
-        a: Input tensor [batch, num_heads]
-        b: Input tensor [batch, num_heads]
-        A_log: Log of A parameter [num_heads]
-        dt_bias: Bias for dt [num_heads]
-        g: Output tensor [1, batch, num_heads] (pre-allocated)
-        beta: Output tensor [1, batch, num_heads] (pre-allocated)
-        beta_const: Beta constant for softplus (default: 1.0)
-        threshold: Threshold for softplus approximation (default: 20.0)
-        run_config: Optional autotuning configuration
-
-    Returns:
-        Tuple of (g, beta) - same tensors passed in, now filled
-    """
-    batch_size, num_heads = a.shape
-
-    if run_config is None:
-        run_config = {"BLOCK_BATCH": 8, "BLOCK_HEADS": 16, "num_warps": 4}
-
-    grid = (
-        triton.cdiv(batch_size, run_config["BLOCK_BATCH"]),
-        triton.cdiv(num_heads, run_config["BLOCK_HEADS"]),
-    )
-
-    _fused_gdn_gating_only_kernel[grid](
-        g,
-        beta,
-        a,
-        b,
-        A_log,
-        dt_bias,
-        batch_size,
-        num_heads,
-        beta_const,
-        threshold,
-        run_config["BLOCK_BATCH"],
-        run_config["BLOCK_HEADS"],
-        num_warps=run_config["num_warps"],
-    )
-
-    return g, beta
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 4766ffbb6c..f1b9d38a65 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -318,7 +318,7 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo
             can_alloc_token_num -= prefill_need_token_num
             revovered_reqs.append(req)
 
-        self._alloc_and_copy_req_buffers(revovered_reqs)
+        self._alloc_and_copy_req_buffers(self.req_manager, self.radix_cache, revovered_reqs)
         g_infer_state_lock.release()
         return
 

From 267412dd53b4535bbbf499cfe847fe54f1bdc26d Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 23 Mar 2026 12:58:21 +0000
Subject: [PATCH 147/180] fix format

---
 lightllm/server/core/objs/start_args_type.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 40086e43ab..8c4301b7dc 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -126,7 +126,9 @@ class StartArgs:
     vit_att_backend: List[str] = field(
         default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]}
     )
-    llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"]})
+    llm_kv_type: str = field(
+        default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"]}
+    )
     llm_kv_quant_group_size: int = field(default=8)
     sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]})
     penalty_counter_mode: str = field(

From f7bee0875f4099d5db2e8076f072623830f3a752 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 23 Mar 2026 14:39:56 +0000
Subject: [PATCH 148/180] gated rmsnorm weight and mamba profile_size

---
 .../layer_weights/meta_weights/__init__.py    |   1 +
 .../layer_weights/meta_weights/norm_weight.py |  50 ++++++++
 .../triton_kernel/norm}/gated_rmsnorm.py      |   0
 .../mamba_cache_mem_manager/cache_manager.py  | 104 ++++++++++++++---
 lightllm/models/qwen3_5/model.py              |   9 +-
 .../layer_infer/transformer_layer_infer.py    |  36 +++---
 .../layer_weights/transformer_layer_weight.py |   3 +-
 lightllm/models/qwen3next/mem_manager.py      | 108 +++---------------
 lightllm/models/qwen3next/model.py            |  53 ++-------
 9 files changed, 183 insertions(+), 181 deletions(-)
 rename lightllm/{models/qwen3next/triton_kernel => common/basemodel/triton_kernel/norm}/gated_rmsnorm.py (100%)

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
index 21b5b7959e..66390bfe0f 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
@@ -11,6 +11,7 @@
     TpRMSNormWeight,
     RMSNormWeight,
     GEMMANormWeight,
+    GatedRMSNormWeight,
     LayerNormWeight,
     NoTpGEMMANormWeight,
     QKRMSNORMWeight,
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
index 89a3d24119..62eab2229b 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
@@ -5,6 +5,7 @@
 from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward
 from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward
 from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward
+from lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward
 from .platform_op import PlatformAwareOp
 
 
@@ -79,6 +80,55 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
             self.weight.load_ok = True
 
 
+class GatedRMSNormWeight(RMSNormWeight):
+    def _triton_forward(
+        self,
+        input: torch.Tensor,
+        gate_value: torch.Tensor,
+        eps: float,
+        out: Optional[torch.Tensor] = None,
+        alloc_func=torch.empty,
+    ) -> torch.Tensor:
+        assert (
+            input.ndim in [2, 3] and self.weight.ndim == 1
+        ), f"input.ndim: {input.ndim} != 2 or weight.ndim: {self.weight.ndim} != 1"
+        if out is None:
+            out = alloc_func(input.shape, dtype=input.dtype, device=input.device)
+        return gated_rmsnorm_forward(x=input, weight=self.weight, bias=None, eps=eps, z=gate_value, out=out)
+
+    def _cuda_forward(
+        self,
+        input: torch.Tensor,
+        gate_value: torch.Tensor,
+        eps: float,
+        out: Optional[torch.Tensor] = None,
+        alloc_func=torch.empty,
+    ) -> torch.Tensor:
+        # only the triton implementation is supported for rmsnorm on the cuda platform
+        return self._triton_forward(input=input, gate_value=gate_value, eps=eps, out=out, alloc_func=alloc_func)
+
+    def _musa_forward(
+        self,
+        input: torch.Tensor,
+        gate_value: torch.Tensor,
+        eps: float,
+        out: Optional[torch.Tensor] = None,
+        alloc_func=torch.empty,
+    ) -> torch.Tensor:
+        # the triton implementation is supported by musa.
+        return self._triton_forward(input=input, gate_value=gate_value, eps=eps, out=out, alloc_func=alloc_func)
+
+    def __call__(
+        self,
+        input: torch.Tensor,
+        gate_value: torch.Tensor,
+        eps: float,
+        out: Optional[torch.Tensor] = None,
+        alloc_func=torch.empty,
+    ) -> torch.Tensor:
+        return self._forward(input=input, gate_value=gate_value, eps=eps, out=out, alloc_func=alloc_func)
+
+
 class LayerNormWeight(BaseWeightTpl, PlatformAwareOp):
     def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None):
         super().__init__(tp_rank=0, tp_world_size=1)
diff --git a/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py
similarity index 100%
rename from lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py
rename to lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py
diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py
index 9d2d372e17..dc06ccf859 100644
--- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py
+++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py
@@ -20,7 +20,6 @@ def __init__(self, size: int, dtype: torch.dtype, shape: Tuple[int, ...], layer_
         self.dtype = dtype
         self.shape = shape
         self.layer_num = layer_num
-
         self.buffer = torch.zeros((self.layer_num, size + 1, *shape), dtype=dtype, device="cuda")
 
     def get_cell_size(self):
@@ -33,12 +32,33 @@ def __init__(
         size: int,
         layer_num: int,
         conv_state_dtype: torch.dtype,
-        conv_state_shape: Tuple[int, ...],
         ssm_state_dtype: torch.dtype,
-        ssm_state_shape: Tuple[int, ...],
+        conv_kernel_size: int,
+        num_linear_k_heads: int,
+        num_linear_v_heads: int,
+        head_linear_k_dim: int,
+        head_linear_v_dim: int,
    ):
         # init the mem state
         self.size = size
+        self.num_linear_k_heads = num_linear_k_heads
+        self.num_linear_v_heads = num_linear_v_heads
+        self.head_linear_k_dim = head_linear_k_dim
+        self.head_linear_v_dim = head_linear_v_dim
+        self.conv_dim = (
+            self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads
+        )
+        self.layer_num = layer_num
+        self.conv_kernel_size = conv_kernel_size
+        conv_state_shape = (self.conv_dim, conv_kernel_size - 1)
+        ssm_state_shape = (
+            self.num_linear_v_heads,
+            self.head_linear_k_dim,
+            self.head_linear_v_dim,
+        )
+        self.ssm_state_dtype = ssm_state_dtype
+        self.conv_state_dtype = conv_state_dtype
+        self.profile_size()
         self.mem_state = torch.arange(
             0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
         )
@@ -51,20 +71,11 @@ def __init__(
         self.can_use_mem_size = self.size
         self.shared_can_use_token_num = SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}")
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
         self.HOLD_TOKEN_MEMINDEX = self.size
 
         # init the layer cache
-        self.conv_state_cache = LayerCache(size, conv_state_dtype, conv_state_shape, layer_num)
-        self.ssm_state_cache = LayerCache(size, ssm_state_dtype, ssm_state_shape, layer_num)
-        self.HOLD_BUFFER_INDEX = size
-
-        logger.warning(
-            f"Linear attention state cache size: {size}\n"
-            f"Conv state use : "
-            f"{self.conv_state_cache.get_cell_size() * size / 1024 ** 3} GB Memory.\n"
-            f"Ssm state use : "
-            f"{self.ssm_state_cache.get_cell_size() * size / 1024 ** 3} GB Memory.\n"
-        )
+        self.conv_state_cache = LayerCache(self.size, conv_state_dtype, conv_state_shape, layer_num)
+        self.ssm_state_cache = LayerCache(self.size, ssm_state_dtype, ssm_state_shape, layer_num)
+        self.HOLD_BUFFER_INDEX = self.size
 
     def get_mamba_cache(self, layer_idx: int):
         conv_state = self.conv_state_cache.buffer[layer_idx]
@@ -183,6 +194,69 @@ def resize_mem(self, new_size):
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
         return
 
+    def profile_size(
+        self,
+    ):
+        start_args = get_env_start_args()
+        if self.size is not None:
+            assert self.size >= start_args.running_max_req_size * 2, (
+                f"error mamba_cache_size {self.size} < running_max_req_size * 2 {start_args.running_max_req_size * 2}",
+                f"mamba_cache_size should be at least running_max_req_size * 2",
+            )
+            return
+        from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
+        import torch.distributed as dist
+
+        mem_fraction = start_args.mem_fraction
+        world_size = dist.get_world_size()
+        total_memory = get_total_gpu_memory()
+        available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction)
+        conv_cell_size = (
+            self.layer_num
+            * self.conv_dim
+            * (self.conv_kernel_size - 1)
+            * torch._utils._element_size(self.conv_state_dtype)
+        )
+        ssm_cell_size = (
+            self.layer_num
+            * (self.num_linear_v_heads)
+            * self.head_linear_k_dim
+            * self.head_linear_v_dim
+            * torch._utils._element_size(self.ssm_state_dtype)
+        )
+        total_cell_size = conv_cell_size + ssm_cell_size
+        mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5
+        mamba_memory_gb = available_memory * mamba_cache_ratio
+        mamba_cache_size = int(mamba_memory_gb * 1024 ** 3 / total_cell_size)
+
+        if mamba_cache_size < start_args.running_max_req_size * 2:
+            ratio = mamba_cache_ratio if mamba_cache_ratio is not None else 0.5
+            raise ValueError(
+                f"Insufficient memory for mamba cache allocation!\n\n"
+                f"mamba_cache_size should be at least running_max_req_size * 2\n"
+                f"Calculated mamba_cache_size ({mamba_cache_size}) < "
+                f"running_max_req_size * 2 ({start_args.running_max_req_size * 2})\n\n"
+                f"Memory budget:\n"
+                f"  Available for mamba cache: {mamba_memory_gb:.2f} GB\n"
+                f"  Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n"
+                f"  Calculated buffers: {mamba_cache_size}\n"
+                f"  Required buffers: {start_args.running_max_req_size}\n\n"
+                f"Solutions:\n"
+                f"  1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n"
+                f"  2. Increase --mamba_cache_ratio from {ratio} to "
+                f"{start_args.running_max_req_size / mamba_cache_size * ratio:.3f} or higher\n"
+                f"  3. Increase --mem_fraction to leave more memory for caches\n"
+            )
+
+        logger.info(
+            f"Mamba cache allocation:\n"
+            f"  Available memory: {mamba_memory_gb:.2f} GB\n"
+            f"  Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n"
+            f"  Calculated mamba_cache_size: {mamba_cache_size}"
+        )
+        self.size = mamba_cache_size
+        return
+
 
 class ReadOnlyStaticsMambaCacheManager:
     """
diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py
index a13238b6e5..da79cf32b2 100644
--- a/lightllm/models/qwen3_5/model.py
+++ b/lightllm/models/qwen3_5/model.py
@@ -70,30 +70,23 @@ def _init_config(self):
         if self.vision_config is None:
             logger.warning("No vision_config found in checkpoint. " "Multimodal features may not work correctly.")
 
-        # Apply standard config repairs
         repair_config(self.config, same_names=["num_attention_heads", "n_head"])
         repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
         repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
-
-        # Qwen3.5 stores RoPE config under text_config.rope_parameters.
         rope_parameters = self.config.get("rope_parameters")
         if isinstance(rope_parameters, dict):
             if "rope_theta" in rope_parameters and "rope_theta" not in self.config:
                 self.config["rope_theta"] = rope_parameters["rope_theta"]
             if "partial_rotary_factor" in rope_parameters and "partial_rotary_factor" not in self.config:
                 self.config["partial_rotary_factor"] = rope_parameters["partial_rotary_factor"]
-            # Preserve the richer RoPE metadata in the expected field when absent.
             if "rope_scaling" not in self.config:
                 self.config["rope_scaling"] = rope_parameters
 
         # MoE routing parameters - set defaults for Qwen3.5 compatibility
         if "norm_topk_prob" not in self.config:
-            self.config["norm_topk_prob"] = True  # Standard default for MoE models
+            self.config["norm_topk_prob"] = True
 
-        # Handle fine-tuning config if present
         if self.finetune_config:
             self.config["vocab_size"] = self.finetune_config.vocab_size
-
-        # Calculate num_kv_heads for KV cache memory management
-        # Required by parent class _init_mem_manager() in Qwen3NextTpPartModel
         self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
index b8b31f44d2..dce5e96b31 100644
--- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
@@ -12,7 +12,6 @@
 from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 from typing import Tuple
-from lightllm.models.qwen3next.triton_kernel.gated_rmsnorm import gated_rmsnorm_forward
 from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating
 from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule
@@ -162,7 +161,6 @@ def _get_qkv(
             dim=-1,
         )
         o_gate = layer_weight._o_gate_proj.mm(input)
-        # In-place sigmoid saves one allocation (gate_value is consumed once in _get_o)
         infer_state.gate_value = o_gate.sigmoid_()
         layer_weight.qk_norm_weight_(
             q,
@@ -238,25 +236,24 @@ def gdn_forward(
         mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill)
 
         if is_prefill:
-            g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight)
             core_attn_out = self._gdn_prefill_kernel(
-                mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight
+                mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight
             )
         else:
-            core_attn_out = self._gdn_decode_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight)
+            core_attn_out = self._gdn_decode_kernel(
+                mixed_qkv,
+                conv_states,
+                ssm_states,
+                a,
+                b,
+                infer_state,
+                layer_weight,
+            )
 
         num_tokens = z.shape[0]
-        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
-        z = z.reshape(-1, z.shape[-1])
-        norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device)
-        gated_rmsnorm_forward(
-            core_attn_out,
-            layer_weight.linear_norm.weight,
-            None,
-            self.eps_,
-            z,
-            out=norm_out,
-        )
+        core_attn_out = core_attn_out.view(-1, core_attn_out.shape[-1])
+        z = z.contiguous().view(-1, z.shape[-1])
+        norm_out = layer_weight.linear_norm(core_attn_out, z, self.eps_)
         core_attn_out = norm_out.view(num_tokens, -1)
         output = layer_weight.linear_out_proj.mm(core_attn_out)
         return output
@@ -300,13 +297,12 @@ def _gdn_prefill_kernel(
         mixed_qkv: torch.Tensor,
         conv_states: torch.Tensor,
         ssm_states: torch.Tensor,
-        g: torch.Tensor,
-        beta: torch.Tensor,
+        a: torch.Tensor,
+        b: torch.Tensor,
         infer_state: Qwen3NextInferStateInfo,
         layer_weight: Qwen3NextTransformerLayerWeight,
     ):
-        """Prefill kernel for GDN forward pass."""
-        # Conv1D processing
+        g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight)
         mixed_qkv = mixed_qkv.transpose(0, 1)
         out_tensor = causal_conv1d_fn(
             mixed_qkv,
diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py
index 31dae85ec8..8e251a4f50 100644
--- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py
@@ -5,6 +5,7 @@
     COLMMWeight,
     RMSNormWeight,
     GEMMANormWeight,
+    GatedRMSNormWeight,
     TpParameterWeight,
     QKVROWNMMWeight,
     QKGEMMANormWeight,
@@ -185,7 +186,7 @@ def _init_gdn_weight(self):
         # Norm is applied per-head across head_dim, not across all heads
         linear_norm_dim = self.linear_v_head_dim
 
-        self.linear_norm = RMSNormWeight(
+        self.linear_norm = GatedRMSNormWeight(
             dim=linear_norm_dim,
             weight_name=f"{prefix}.norm.weight",
             data_type=self.data_type_,
diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py
index c8ecae9461..12a6d56b8c 100644
--- a/lightllm/models/qwen3next/mem_manager.py
+++ b/lightllm/models/qwen3next/mem_manager.py
@@ -9,87 +9,6 @@
 
 
 class Qwen3NextHybridMemManager(MemoryManager):
-    @staticmethod
-    def calculate_mamba_cache_size(
-        start_args: StartArgs,
-        max_total_token_num: int,
-        mem_fraction: float,
-        config: dict,
-        head_linear_k_dim: int,
-        num_linear_k_heads: int,
-        head_linear_v_dim: int,
-        num_linear_v_heads: int,
-        tp_world_size: int,
-        data_type: torch.dtype,
-    ) -> int:
-        """Calculate mamba cache size based on available memory and mamba_cache_ratio."""
-        from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
-        import torch.distributed as dist
-
-        use_ratio = max_total_token_num is None and start_args.mamba_cache_size is None
-
-        world_size = dist.get_world_size()
-        total_memory = get_total_gpu_memory()
-        available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction)
-
-        conv_kernel_size = config["linear_conv_kernel_dim"]
-        conv_dim = (
-            head_linear_k_dim * num_linear_k_heads * 2 + head_linear_v_dim * num_linear_v_heads
-        ) // tp_world_size
-
-        num_linear_layers = config["n_layer"] - (config["n_layer"] // config["full_attention_interval"])
-
-        conv_cell_size = num_linear_layers * conv_dim * (conv_kernel_size - 1) * torch._utils._element_size(data_type)
-
-        ssm_dtype = torch.bfloat16 if start_args.mamba_ssm_data_type == "bfloat16" else torch.float32
-        ssm_cell_size = (
-            num_linear_layers
-            * (num_linear_v_heads // tp_world_size)
-            * head_linear_k_dim
-            * head_linear_v_dim
-            * torch._utils._element_size(ssm_dtype)
-        )
-
-        total_cell_size = conv_cell_size + ssm_cell_size
-
-        if use_ratio:
-            # mamba_cache_ratio = mamba_memory / total_cache_memory
-            mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5
-            mamba_memory_gb = available_memory * mamba_cache_ratio
-        else:
-            mamba_memory_gb = available_memory
-            mamba_cache_ratio = None
-
-        mamba_cache_size = int(mamba_memory_gb * 1024 ** 3 / total_cell_size)
-
-        if mamba_cache_size < start_args.running_max_req_size * 2:
-            ratio = mamba_cache_ratio if mamba_cache_ratio is not None else 0.5
-            raise ValueError(
-                f"Insufficient memory for mamba cache allocation!\n\n"
-                f"mamba_cache_size should be at least running_max_req_size * 2\n"
-                f"Calculated mamba_cache_size ({mamba_cache_size}) < "
-                f"running_max_req_size * 2 ({start_args.running_max_req_size * 2})\n\n"
-                f"Memory budget:\n"
-                f"  Available for mamba cache: {mamba_memory_gb:.2f} GB\n"
-                f"  Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n"
-                f"  Calculated buffers: {mamba_cache_size}\n"
-                f"  Required buffers: {start_args.running_max_req_size}\n\n"
-                f"Solutions:\n"
-                f"  1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n"
-                f"  2. Increase --mamba_cache_ratio from {ratio} to "
-                f"{start_args.running_max_req_size / mamba_cache_size * ratio:.3f} or higher\n"
-                f"  3. Increase --mem_fraction to leave more memory for caches\n"
-            )
-
-        logger.info(
-            f"Mamba cache allocation:\n"
-            f"  Available memory: {mamba_memory_gb:.2f} GB\n"
-            f"  Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n"
-            f"  Calculated mamba_cache_size: {mamba_cache_size}"
-        )
-
-        return mamba_cache_size
-
     def __init__(
         self,
         full_attn_cache_size,
@@ -98,31 +17,36 @@ def __init__(
         linear_attn_cache_size,
         dtype,
         num_kv_heads,
         head_dim,
         layer_num,
-        mtp_layer_num,
         full_attention_interval: int,
         conv_state_dtype: torch.dtype,
-        conv_state_shape: Tuple[int, ...],
         ssm_state_dtype: torch.dtype,
-        ssm_state_shape: Tuple[int, ...],
+        conv_kernel_size: int,
+        num_linear_k_heads: int,
+        num_linear_v_heads: int,
+        head_linear_k_dim: int,
+        head_linear_v_dim: int,
         max_req_num: int,
         always_copy=False,
         mem_fraction=0.9,
+        network_config: dict = None,
     ):
         self.full_attention_interval = full_attention_interval
         assert layer_num % full_attention_interval == 0
         self.layer_num = layer_num
-        self.mtp_layer_num = mtp_layer_num
         self.full_attn_layer_num = layer_num // full_attention_interval
         self.linear_attn_layer_num = layer_num - self.full_attn_layer_num
 
         self.mamba_cache_mem_manager = MambaCacheManager(
-            linear_attn_cache_size,
-            self.linear_attn_layer_num,
-            conv_state_dtype,
-            conv_state_shape,
-            ssm_state_dtype,
-            ssm_state_shape,
+            size=linear_attn_cache_size,
+            layer_num=self.linear_attn_layer_num,
+            conv_state_dtype=conv_state_dtype,
+            ssm_state_dtype=ssm_state_dtype,
+            conv_kernel_size=conv_kernel_size,
+            num_linear_k_heads=num_linear_k_heads,
+            num_linear_v_heads=num_linear_v_heads,
+            head_linear_k_dim=head_linear_k_dim,
+            head_linear_v_dim=head_linear_v_dim,
         )
 
         super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction)
@@ -144,7 +68,7 @@ def free_all(self):
 
     def get_cell_size(self):
         # Only full attention layers and MTP layers have KV cache
-        kv_cache_layer_num = self.full_attn_layer_num + self.mtp_layer_num
+        kv_cache_layer_num = self.full_attn_layer_num
         return 2 * self.head_num * self.head_dim * kv_cache_layer_num * torch._utils._element_size(self.dtype)
 
     def get_mamba_cache(self, layer_idx: int):
diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py
index b00f57f3ec..bb9452c463 100644
--- a/lightllm/models/qwen3next/model.py
+++ b/lightllm/models/qwen3next/model.py
@@ -65,65 +65,28 @@ def _init_custom(self):
 
     def _init_mem_manager(self):
         assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
-
         start_args: StartArgs = get_env_start_args()
-        mamba_cache_size = start_args.mamba_cache_size
-
-        self.num_linear_k_heads = self.config["linear_num_key_heads"]
-        self.num_linear_v_heads = self.config["linear_num_value_heads"]
+        self.num_linear_k_heads = self.config["linear_num_key_heads"] // self.tp_world_size_
+        self.num_linear_v_heads = self.config["linear_num_value_heads"] // self.tp_world_size_
         self.head_linear_k_dim = self.config["linear_key_head_dim"]
         self.head_linear_v_dim = self.config["linear_value_head_dim"]
-
-        if mamba_cache_size is None:
-            mamba_cache_size = Qwen3NextHybridMemManager.calculate_mamba_cache_size(
-                start_args=start_args,
-                max_total_token_num=self.max_total_token_num,
-                mem_fraction=self.mem_fraction,
-                config=self.config,
-                head_linear_k_dim=self.head_linear_k_dim,
-                num_linear_k_heads=self.num_linear_k_heads,
-                head_linear_v_dim=self.head_linear_v_dim,
-                num_linear_v_heads=self.num_linear_v_heads,
-                tp_world_size=self.tp_world_size_,
-                data_type=self.data_type,
-            )
-        else:
-            if mamba_cache_size < start_args.running_max_req_size * 2:
-                raise ValueError(
-                    f"Explicitly set mamba_cache_size ({mamba_cache_size}) < "
-                    f"running_max_req_size * 2 ({start_args.running_max_req_size * 2})\n"
-                    f"Please increase mamba_cache_size to at least {start_args.running_max_req_size * 2}"
-                )
-
         conv_kernel_size = self.config["linear_conv_kernel_dim"]
-        conv_dim = (
-            self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads
-        )
-
         ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32}
         if start_args.mamba_ssm_data_type not in ssm_dtype_dict:
             raise ValueError(
                 f"Invalid mamba_ssm_data_type: {start_args.mamba_ssm_data_type}."
                 f" Must be one of {list(ssm_dtype_dict.keys())}"
             )
-
         self.mem_manager = Qwen3NextHybridMemManager(
             full_attn_cache_size=self.max_total_token_num,
-            linear_attn_cache_size=mamba_cache_size,
+            linear_attn_cache_size=start_args.mamba_cache_size,
             dtype=self.data_type,
             num_kv_heads=self.num_kv_heads,
             head_dim=self.config["head_dim"],
             layer_num=self.config["n_layer"],
-            mtp_layer_num=start_args.mtp_step,
             full_attention_interval=self.config["full_attention_interval"],
             conv_state_dtype=self.data_type,
-            conv_state_shape=(conv_dim // self.tp_world_size_, conv_kernel_size - 1),
             ssm_state_dtype=ssm_dtype_dict[start_args.mamba_ssm_data_type],
-            ssm_state_shape=(
-                self.num_linear_v_heads // self.tp_world_size_,
-                self.head_linear_k_dim,
-                self.head_linear_v_dim,
-            ),
+            conv_kernel_size=conv_kernel_size,
+            num_linear_k_heads=self.num_linear_k_heads,
+            num_linear_v_heads=self.num_linear_v_heads,
+            head_linear_k_dim=self.head_linear_k_dim,
+            head_linear_v_dim=self.head_linear_v_dim,
             max_req_num=self.max_req_num,
             mem_fraction=self.mem_fraction,
         )

From b85b6caa2756acba8e6d71144cb241b2e3e97479 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Tue, 24 Mar 2026 07:15:07 +0000
Subject: [PATCH 149/180] simplify code

---
 .../dynamic_prompt/hybrid_radix_cache.py      | 25 ++-----------------
 test_gsmk.py => test/acc/test_gsmk.py         |  2 +-
 2 files changed, 3 insertions(+), 24 deletions(-)
 rename test_gsmk.py => test/acc/test_gsmk.py (99%)

diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
index 08f6ba3fff..c5f401eea9 100644
--- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
@@ -20,43 +20,22 @@ def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_mana
 
     def free_radix_cache_to_get_enough_buffer(self, need_buffer_num):
         if need_buffer_num > self.buffer_mem_manager.can_use_mem_size:
             need_evict_buffer_num = need_buffer_num - self.buffer_mem_manager.can_use_mem_size
-
-            release_mems = []
-
-            def release_mem(mem_index):
-                release_mems.append(mem_index)
-                return
-
             release_buffers = []
 
             def release_buffer(buffer_idx):
                 release_buffers.append(buffer_idx)
                 return
 
-            self._evict_buffer(need_evict_buffer_num, release_buffer, release_mem)
-
-            self.buffer_mem_manager.free(release_buffers)
-            if len(release_mems) > 0:
-                mem_index = torch.concat(release_mems)
-                self.mem_manager.free(mem_index)
+            self._evict_buffer(need_evict_buffer_num, release_buffer)
         return
 
-    def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token_callback):
+    def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback):
         while need_evict_buffer_num > 0:
             node = self.evict_buffer_set.pop(0)
             assert node.buffer_idx is not None
             evict_buffer_callback(node.buffer_idx)
             node.buffer_idx = None
             need_evict_buffer_num -= 1
-            # Once a node's buffer_idx becomes None it can no longer be matched later,
-            # but it must be kept while it still has children or a nonzero ref count;
-            # otherwise it should be deleted.
-            if node.is_leaf() and node.ref_counter == 0:
-                self.evict_tree_set.discard(node)
-                evict_token_callback(node.token_mem_index_value)
-                self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value)
-                parent_node: TreeNode = node.parent
-                parent_node.remove_child(node)
-                if parent_node.is_leaf():
-                    self.evict_tree_set.add(parent_node)
         return
 
     def match_prefix(self, key, update_refs=False):
diff --git a/test_gsmk.py b/test/acc/test_gsmk.py
similarity index 99%
rename from test_gsmk.py
rename to test/acc/test_gsmk.py
index 78a5aa467f..16ebc6e095 100644
--- a/test_gsmk.py
+++ b/test/acc/test_gsmk.py
@@ -139,7 +139,7 @@ def parse_args():
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--num-shots", type=int, default=5)
-    parser.add_argument("--num-questions", type=int, default=200)
+    parser.add_argument("--num-questions", type=int, default=1000)
     parser.add_argument("--result-file", type=str, default="result.jsonl")
     parser.add_argument("--data-path", type=str, default="test.jsonl")
     return parser.parse_args()

From 396584581ecddf4972b227b2e05c25cd7d5c6ab8 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Tue, 24 Mar 2026 09:24:02 +0000
Subject: [PATCH 150/180] update tp param

---
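The TpParameterWeight rework below derives the shard size from the full weight_shape
and slices along an arbitrary dim with narrow instead of assuming dim 0. A small
illustration with hypothetical shapes:

    import torch

    full = torch.arange(8.0)                         # e.g. a per-head dt_bias for 8 heads
    tp_world_size, tp_rank = 2, 1
    split_n_embed = full.shape[0] // tp_world_size   # 4 heads per rank
    start = split_n_embed * tp_rank
    shard = full.narrow(0, start, split_n_embed)     # rank 1 gets heads 4..7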
 .../meta_weights/parameter_weight.py          | 24 +++++++++++++------
 .../layer_weights/transformer_layer_weight.py |  3 ---
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py
index 0afb0ecab2..80d733394b 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py
@@ -1,6 +1,7 @@
 import torch
 from typing import Dict, Optional, Tuple
 from .base_weight import BaseWeightTpl
+from lightllm.utils.dist_utils import get_dp_world_size
 
 
 class ParameterWeight(BaseWeightTpl):
@@ -54,19 +55,28 @@ def __init__(
         self,
         weight_name: str,
         data_type: torch.dtype,
-        split_n_embed: int,
         bias_name: Optional[str] = None,
         weight_shape: Optional[Tuple[int, ...]] = None,
         bias_shape: Optional[Tuple[int, ...]] = None,
+        dim: int = 0,  # the default split dimension is 0
     ):
-        self.split_n_embed = split_n_embed
-        # Calculate TP-split shapes if full shapes are provided
+
+        assert (
+            0 <= dim < len(weight_shape)
+        ), f"split dimension: {dim} must be less than the length of weight_shape: {weight_shape}"
+        n_embed = weight_shape[dim]
+        tp_world_size = get_dp_world_size()
+        assert (
+            n_embed % tp_world_size == 0
+        ), f"weight_shape[{dim}]={weight_shape[dim]} must be divisible by tp_world_size_: {tp_world_size}"
+        self.dim = dim
+        self.split_n_embed = n_embed // tp_world_size
         tp_weight_shape = None
         tp_bias_shape = None
         if weight_shape is not None:
-            tp_weight_shape = (split_n_embed,) + weight_shape[1:]
+            tp_weight_shape = weight_shape[:dim] + (self.split_n_embed,) + weight_shape[dim + 1 :]
         if bias_shape is not None:
-            tp_bias_shape = (split_n_embed,) + bias_shape[1:]
+            tp_bias_shape = bias_shape[:dim] + (self.split_n_embed,) + bias_shape[dim + 1 :]
         super().__init__(weight_name, data_type, tp_weight_shape, bias_name, tp_bias_shape)
 
     def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
@@ -74,10 +84,10 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
         end = self.split_n_embed * (self.tp_rank_ + 1)
 
         if self.weight_name in weights:
-            t_weight = weights[self.weight_name][start:end]
+            t_weight = weights[self.weight_name].narrow(self.dim, start, end - start)
             self.weight.copy_(t_weight.to(self.data_type_))
             self.weight.load_ok = True
         if self.bias_name is not None and self.bias_name in weights:
-            t_bias = weights[self.bias_name][start:end]
+            t_bias = weights[self.bias_name].narrow(self.dim, start, end - start)
             self.bias.copy_(t_bias.to(self.data_type_))
             self.bias.load_ok = True
diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py
index 8e251a4f50..7b5496ad16 100644
--- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py
@@ -165,11 +165,9 @@ def _init_gdn_weight(self):
             quant_method=self.get_quant_method("out_proj_weight"),
         )
 
-        split_n_embed = self.linear_num_v_heads // self.tp_world_size_
         self.linear_dt_bias = TpParameterWeight(
             weight_name=f"{prefix}.dt_bias",
             data_type=torch.float32,
-            split_n_embed=split_n_embed,
             bias_name=None,
             weight_shape=(self.linear_num_v_heads,),  # Full shape before TP split
             bias_shape=None,
@@ -178,7 +176,6 @@ def _init_gdn_weight(self):
         self.linear_A_log = TpParameterWeight(
             weight_name=f"{prefix}.A_log",
             data_type=torch.float32,
-            split_n_embed=split_n_embed,
             bias_name=None,
             weight_shape=(self.linear_num_v_heads,),  # Full shape before TP split
             bias_shape=None,

From ef41d77c246a745fe6c3189f54b9e4ec40493bd4 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Tue, 24 Mar 2026 15:48:35 +0000
Subject: [PATCH 151/180] fix: restore tool_calls arguments JSON string to dict conversion

Qwen's chat template uses the |items filter, which expects arguments to be
a dict, but the OpenAI format sends arguments as a JSON string. This
conversion is needed for Jinja template compatibility.

---
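A small illustration of the shape mismatch this fix addresses (the tool name and
fields below are invented for the example):

    import json

    msg = {
        "role": "assistant",
        "tool_calls": [{"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}],
    }
    func = msg["tool_calls"][0]["function"]
    func["arguments"] = json.loads(func["arguments"])  # now a dict: {"city": "Paris"}
    # After the conversion, Jinja's `arguments|items` can iterate [("city", "Paris")].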
 lightllm/server/build_prompt.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py
index a38008af6f..ae3ba1b7cf 100755
--- a/lightllm/server/build_prompt.py
+++ b/lightllm/server/build_prompt.py
@@ -47,6 +47,23 @@ async def build_prompt(request, tools) -> str:
 
     # Convert pydantic objects to dicts; otherwise Jinja cannot evaluate them
     # when the chat template from tokenizer_config.json is assembled
     messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages]
 
+    # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility
+    # Qwen's chat template expects arguments to be a dict (uses |items filter)
+    # but OpenAI format sends arguments as a JSON string
+    for msg in messages:
+        tool_calls = msg.get("tool_calls")
+        if tool_calls and isinstance(tool_calls, list):
+            for tool_call in tool_calls:
+                func = tool_call.get("function")
+                if func and isinstance(func, dict):
+                    args = func.get("arguments")
+                    if isinstance(args, str) and args:
+                        try:
+                            func["arguments"] = json.loads(args)
+                        except (json.JSONDecodeError, TypeError):
+                            # Keep original string if not valid JSON
+                            pass
+
     kwargs = {"conversation": messages}
     if request.character_settings:
         kwargs["character_settings"] = request.character_settings

From 7d0458fd9de51a8b46b622d4f78ca88cd476022a Mon Sep 17 00:00:00 2001
From: sufubao
Date: Tue, 24 Mar 2026 15:48:35 +0000
Subject: [PATCH 152/180] fix: restore tool_calls arguments JSON string to dict conversion

Qwen's chat template uses the |items filter, which expects arguments to be
a dict, but the OpenAI format sends arguments as a JSON string. This
conversion is needed for Jinja template compatibility.

---
 lightllm/server/build_prompt.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py
index 783411dadf..3b2fdc7096 100644
--- a/lightllm/server/build_prompt.py
+++ b/lightllm/server/build_prompt.py
@@ -49,6 +49,23 @@ async def build_prompt(request, tools) -> str:
     # Convert pydantic objects to dicts; otherwise Jinja cannot evaluate them
     # when the chat template from tokenizer_config.json is assembled
     messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages]
 
+    # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility
+    # Qwen's chat template expects arguments to be a dict (uses |items filter)
+    # but OpenAI format sends arguments as a JSON string
+    for msg in messages:
+        tool_calls = msg.get("tool_calls")
+        if tool_calls and isinstance(tool_calls, list):
+            for tool_call in tool_calls:
+                func = tool_call.get("function")
+                if func and isinstance(func, dict):
+                    args = func.get("arguments")
+                    if isinstance(args, str) and args:
+                        try:
+                            func["arguments"] = json.loads(args)
+                        except (json.JSONDecodeError, TypeError):
+                            # Keep original string if not valid JSON
+                            pass
+
     kwargs = {"conversation": messages}
     if request.character_settings:
         kwargs["character_settings"] = request.character_settings

From 8f1212a6e764a2a720180ac09f80812d79843b29 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 25 Mar 2026 06:45:17 +0000
Subject: [PATCH 153/180] fix build_prompt too

---
 lightllm/server/build_prompt.py | 42 ++++++++++++++++++++++-----------
 1 file changed, 28 insertions(+), 14 deletions(-)

diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py
index 3b2fdc7096..ded959e5ca 100644
--- a/lightllm/server/build_prompt.py
+++ b/lightllm/server/build_prompt.py
@@ -49,22 +49,36 @@ async def build_prompt(request, tools) -> str:
     # Convert pydantic objects to dicts; otherwise Jinja cannot evaluate them
     # when the chat template from tokenizer_config.json is assembled
     messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages]
 
-    # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility
-    # Qwen's chat template expects arguments to be a dict (uses |items filter)
-    # but OpenAI format sends arguments as a JSON string
+    # When there are tool calls, content is set to None and then dropped by
+    # exclude_none=True, which breaks later template handling; backfill content
+    # with "" here to stay compatible with the original template logic.
     for msg in messages:
+        if "content" not in msg:
+            msg["content"] = ""
+
+    # For tool-call messages, make sure the tool_calls field exists and is well
+    # formed, to stay compatible with the template's tool-call handling logic.
+    for msg in messages:
+        if msg.get("role") != "assistant" or "tool_calls" not in msg:
+            continue
         tool_calls = msg.get("tool_calls")
-        if tool_calls and isinstance(tool_calls, list):
-            for tool_call in tool_calls:
-                func = tool_call.get("function")
-                if func and isinstance(func, dict):
-                    args = func.get("arguments")
-                    if isinstance(args, str) and args:
-                        try:
-                            func["arguments"] = json.loads(args)
-                        except (json.JSONDecodeError, TypeError):
-                            # Keep original string if not valid JSON
-                            pass
+        if not isinstance(tool_calls, list):
+            continue
+        # Drop empty tool_calls so templates take the normal assistant path
+        if len(tool_calls) == 0:
+            msg.pop("tool_calls", None)
+            continue
+        for tool_call in tool_calls:
+            func = tool_call.get("function")
+            if not func or not isinstance(func, dict):
+                continue
+            args = func.get("arguments")
+            if args and not isinstance(args, (dict, list)):
+                try:
+                    func["arguments"] = json.loads(args)
+                except (json.JSONDecodeError, TypeError):
+                    func["arguments"] = {}
+            elif not args:
+                # Missing or empty arguments default to empty dict
+                func["arguments"] = {}
 
     kwargs = {"conversation": messages}
     if request.character_settings:
         kwargs["character_settings"] = request.character_settings

From 77bfcbaf819075f41def3111ee9a9649ab151e55 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 25 Mar 2026 22:06:38 +0800
Subject: [PATCH 154/180] fix

---
 lightllm/models/qwen3_5/model.py |  7 ++++---
 lightllm/server/build_prompt.py  | 16 ++++++++++++++++
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py
index da79cf32b2..510e95f0bd 100644
--- a/lightllm/models/qwen3_5/model.py
+++ b/lightllm/models/qwen3_5/model.py
@@ -9,8 +9,9 @@
 from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import (
     Qwen3VLMultimodalPreLayerInfer,
 )
-from lightllm.models.qwen3_5.layer_weights.pre_and_post_layer_weight import Qwen35PreAndPostLayerWeight
-
+from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import (
+    Qwen3VLPreAndPostLayerWeight,
+)
 from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import (
     Qwen35TransformerLayerInfer,
 )
@@ -51,7 +52,7 @@ class Qwen3_5TpPartModel(Qwen3NextTpPartModel):
     """
 
     transformer_weight_class = Qwen35TransformerLayerWeight
-    pre_and_post_weight_class = Qwen35PreAndPostLayerWeight
+    pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight
 
     pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer
     transformer_layer_infer_class = Qwen35TransformerLayerInfer
diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py
index ae3ba1b7cf..974b8a2ee1 100644
--- a/lightllm/server/build_prompt.py
+++ b/lightllm/server/build_prompt.py
@@ -46,6 +46,22 @@ async def build_prompt(request, tools) -> str:
     global tokenizer
     # Convert pydantic objects to dicts; otherwise Jinja cannot evaluate them
     # when the chat template from tokenizer_config.json is assembled
     messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages]
+    # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility
+    # Qwen's chat template expects arguments to be a dict (uses |items filter)
+    # but OpenAI format sends arguments as a JSON string
+    for msg in messages:
+        tool_calls = msg.get("tool_calls")
+        if tool_calls and isinstance(tool_calls, list):
+            for tool_call in tool_calls:
+                func = tool_call.get("function")
+                if func and isinstance(func, dict):
+                    args = func.get("arguments")
+                    if isinstance(args, str) and args:
+                        try:
+                            func["arguments"] = json.loads(args)
+                        except (json.JSONDecodeError, TypeError):
+                            # Keep original string if not valid JSON
+                            pass
 
     # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility
     # Qwen's chat template expects arguments to be a dict (uses |items filter)
     # but OpenAI format sends arguments as a JSON string

From 358543298d23183179fca35de6461cc1626db629 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Thu, 26 Mar 2026 07:03:35 +0000
Subject: [PATCH 155/180] fix buffer idx

---
 lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
index c5f401eea9..43e275f7d3 100644
--- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
@@ -27,6 +27,8 @@ def release_buffer(buffer_idx):
                 return
 
             self._evict_buffer(need_evict_buffer_num, release_buffer)
+            if len(release_buffers) > 0:
+                self.buffer_mem_manager.free(release_buffers)
         return
 
     def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback):

From 334e3c453e5d65d4c957b0de40fda8837c5d8e0e Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Thu, 26 Mar 2026 13:58:42 +0000
Subject: [PATCH 156/180] fix

---
 lightllm/models/qwen3_5/model.py | 6 +++---
 lightllm/server/api_openai.py    | 5 ++---
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py
index 510e95f0bd..73334be729 100644
--- a/lightllm/models/qwen3_5/model.py
+++ b/lightllm/models/qwen3_5/model.py
@@ -9,8 +9,8 @@
 from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import (
     Qwen3VLMultimodalPreLayerInfer,
 )
-from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import (
-    Qwen3VLPreAndPostLayerWeight,
+from lightllm.models.qwen3_5.layer_weights.pre_and_post_layer_weight import (
+    Qwen35PreAndPostLayerWeight,
 )
 from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import (
     Qwen35TransformerLayerInfer,
@@ -51,7 +52,7 @@ class Qwen3_5TpPartModel(Qwen3NextTpPartModel):
     """
 
     transformer_weight_class = Qwen35TransformerLayerWeight
-    pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight
+    pre_and_post_weight_class = Qwen35PreAndPostLayerWeight
 
     pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer
     transformer_layer_infer_class = Qwen35TransformerLayerInfer
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index c4cdfa4ba3..9bc4d26eb0 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -329,8 +329,6 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
     tool_choice = request.tool_choice
     tools = request.tools
     if tool_choice != "none" and any([i in full_text for i in TOOLS_TAG_LIST]):
-        if finish_reason == "stop":
-            finish_reason = "tool_calls"
         try:
             # provide a default value for tool_call_parser
             tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
@@ -353,7 +351,8 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
                 HTTPStatus.BAD_REQUEST,
                 "Failed to parse fc related info to json format!",
             )
-        if finish_reason == "tool_calls":
+        if tool_calls and finish_reason == "stop":
+            finish_reason = "tool_calls"
             text = ""
         chat_message = ChatMessage(
             role="assistant",

From 0974ba98886d938b255b716326e388203a6826e7 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Fri, 27 Mar 2026 10:06:56 +0000
Subject: [PATCH 157/180] fix

---
 lightllm/models/qwen3next/model.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py
index 4de03eced7..e31da60e4d 100644
--- a/lightllm/models/qwen3next/model.py
+++ b/lightllm/models/qwen3next/model.py
@@ -32,7 +32,6 @@ class Qwen3NextTpPartModel(Qwen3MOEModel):
 
     # infer class
     transformer_layer_infer_class = Qwen3NextTransformerLayerInfer
-    post_layer_infer_class = Qwen3NextPostLayerInfer
 
     # infer state class
     infer_state_class = Qwen3NextInferStateInfo

From fe91aa3626850359962b2c8b780ff0390adb58f9 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Tue, 31 Mar 2026 06:49:45 +0000
Subject: [PATCH 158/180] add instance_id with improved robustness and code quality

Adds multi-instance port isolation to allow multiple LightLLM servers
on the same machine without port conflicts. Each instance gets a
dedicated 1000-port range (instance 0: 10000-10999, etc.).

Changes:
- Added --lightllm_instance_id CLI arg (0-7) for instance selection
- Refactored port allocation to use deterministic ranges instead of
  random selection via portpicker
- Removed portpicker dependency from requirements.txt
- Base port configurable via LIGHTLLM_BASE_PORT env var
- Removed SO_REUSEADDR from port probe to avoid false positives
- Simplified to single linear scan (removed ineffective retry logic)

---
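The deterministic layout replaces random port picking: every instance owns a fixed
1000-port window above the base port, so servers started with different IDs cannot
collide. The arithmetic, as a small sketch:

    BASE_PORT = 10000            # default; overridable via LIGHTLLM_BASE_PORT
    PORTS_PER_INSTANCE = 1000

    instance_id = 2
    range_start = BASE_PORT + instance_id * PORTS_PER_INSTANCE  # 12000
    range_end = range_start + PORTS_PER_INSTANCE                # 13000 (exclusive)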
Each ID maps to a dedicated port range.", + ) parser.add_argument( "--use_config_server_to_init_nccl", action="store_true", diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 7508e4265e..6e3177e302 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -92,7 +92,6 @@ def _set_envs_and_config(args: StartArgs): def _launch_subprocesses(args: StartArgs): _set_envs_and_config(args) - set_unique_server_name(args) if args.enable_mps: from lightllm.utils.device_utils import enable_mps @@ -146,12 +145,6 @@ def _launch_subprocesses(args: StartArgs): check_recommended_shm_size(args) assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] - # 确保单机上多实列不冲突 - if args.zmq_mode == "ipc:///tmp/": - zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_" - args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功 - args.zmq_mode = zmq_mode - logger.info(f"zmq mode head: {args.zmq_mode}") logger.info(f"use tgi api: {args.use_tgi_api}") @@ -280,6 +273,8 @@ def _launch_subprocesses(args: StartArgs): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), + used_nccl_ports=already_uesd_ports, + instance_id=args.lightllm_instance_id, ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -311,6 +306,16 @@ def _launch_subprocesses(args: StartArgs): args.nccl_port = nccl_port if args.pd_decode_rpyc_port is None: args.pd_decode_rpyc_port = pd_decode_rpyc_port + + set_unique_server_name(args) + + # 确保单机上多实列不冲突 + if args.zmq_mode == "ipc:///tmp/": + zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_" + args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功 + args.zmq_mode = zmq_mode + logger.info(f"zmq mode head: {args.zmq_mode}") + args.router_port = router_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port @@ -462,10 +467,7 @@ def pd_master_start(args: StartArgs): logger.info(f"all start args:{args}") can_use_ports = alloc_can_use_network_port( - num=1, - used_ports=[ - args.port, - ], + num=1, used_nccl_ports=[args.nccl_port, args.port], instance_id=args.lightllm_instance_id ) metric_port = can_use_ports[0] diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index b818b08d4a..91f5af2c7c 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -68,6 +68,7 @@ class StartArgs: max_req_total_len: int = field(default=16384) nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=None) + lightllm_instance_id: int = field(default=0) use_config_server_to_init_nccl: bool = field(default=False) trust_remote_code: bool = field(default=False) detail_log: bool = field(default=False) diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index 486414e88e..9eb0a1d11c 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -1,43 +1,60 @@ import socket import subprocess import ipaddress -import random -import portpicker +import os from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) +DEFAULT_BASE_PORT = 10000 +PORTS_PER_INSTANCE = 1000 +MAX_INSTANCE_ID = 7 -def alloc_can_use_network_port(num=3, used_nccl_ports=None, from_port_num=10000): - if used_nccl_ports is None: - used_nccl_ports = [] - port_list = [] - max_attempts = num * 50 # Allow more attempts to find ports in range +def _is_port_available(port: int) -> bool: + """Check if a port is available by 
attempting to bind it.""" + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return True + except OSError: + return False - for _ in range(max_attempts): - if len(port_list) >= num: - break - try: - port = portpicker.pick_unused_port() +def alloc_can_use_network_port(num=3, used_nccl_ports=None, instance_id=0): + """ + Allocate available network ports within an instance-specific range. - if port >= from_port_num and port not in used_nccl_ports: - port_list.append(port) - logger.debug(f"Allocated port: {port}") - else: - logger.debug(f"Port {port} is out of range or in used_nccl_ports, skipping") + Each instance gets a dedicated 1000-port range starting from BASE_PORT + (default 10000, override via LIGHTLLM_BASE_PORT env var). + Instance 0: 10000-10999, Instance 1: 11000-11999, etc. + """ + if instance_id < 0 or instance_id > MAX_INSTANCE_ID: + raise ValueError(f"instance_id must be in range [0, {MAX_INSTANCE_ID}], got {instance_id}") - except Exception as e: - logger.warning(f"Failed to allocate port: {e}") - continue + base_port = int(os.environ.get("LIGHTLLM_BASE_PORT", DEFAULT_BASE_PORT)) + range_start = base_port + instance_id * PORTS_PER_INSTANCE + range_end = range_start + PORTS_PER_INSTANCE + used_set = set(used_nccl_ports) if used_nccl_ports else set() - if len(port_list) < num: - logger.error(f"Failed to allocate {num} ports, only got {len(port_list)}") - return None - - logger.info(f"Successfully allocated {len(port_list)} ports: {port_list}") - return port_list + port_list = [] + for port in range(range_start, range_end): + if len(port_list) >= num: + break + if port in used_set: + continue + if _is_port_available(port): + port_list.append(port) + used_set.add(port) + + if len(port_list) >= num: + logger.info(f"Instance {instance_id}: allocated {len(port_list)} ports in [{range_start}, {range_end}): {port_list}") + return port_list + + raise RuntimeError( + f"Failed to allocate {num} ports for instance {instance_id} in range [{range_start}, {range_end}). " + f"Only found {len(port_list)} available. Try a different instance_id or set LIGHTLLM_BASE_PORT." 
+ ) def alloc_can_use_port(min_port, max_port): diff --git a/requirements.txt b/requirements.txt index ebf5fd09cc..bd02e48100 100644 --- a/requirements.txt +++ b/requirements.txt @@ -94,5 +94,4 @@ partial_json_parser==0.2.1.1.post6 websockets==15.0.1 cupy-cuda12x==13.6.0 nixl==0.8.0 -torch_memory_saver==0.0.9 -portpicker==1.6.0 \ No newline at end of file +torch_memory_saver==0.0.9 \ No newline at end of file From f4a0cb741f11073895852db79742ed61be768f5c Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 1 Apr 2026 16:40:55 +0000 Subject: [PATCH 159/180] fix: occasional accuracy drop in rollout --- lightllm/common/basemodel/basemodel.py | 1 + lightllm/common/req_manager.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 466cffa9de..144d7dcf61 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1103,6 +1103,7 @@ def resume_kv_cache(self): gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) self.mem_manager.free_all() + self.req_manager.resume() torch.cuda.synchronize() def resume_graph(self): diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index bbe2bb4a3b..61855044a2 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -95,6 +95,9 @@ def free_all(self): self.req_list = _ReqLinkedList(self.max_request_num) return + def resume(self): + return + @property def has_recurrent_state(self): """Whether this model uses per-request recurrent state buffers (e.g. Mamba/linear attention).""" @@ -264,3 +267,9 @@ def alloc_buffer_for_req(self, req_index: torch.Tensor): if not buffer_indexes.is_cuda: buffer_indexes = buffer_indexes.cuda() self.req_to_buffer_index[req_index] = buffer_indexes.view(num_reqs, num_buffers_per_req) + + def resume(self, req_index: torch.Tensor): + # for rl kv cache resume + self.req_to_buffer_index.zero_() + self.req_to_buffer_index[req_index] = self.buffer_mem_manager.HOLD_BUFFER_INDEX + return From f4caa8f76e04219c742ffb96218e679d4b4e4763 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 1 Apr 2026 17:27:27 +0000 Subject: [PATCH 160/180] reset req manager --- lightllm/common/req_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 61855044a2..ec5d46872c 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -268,8 +268,8 @@ def alloc_buffer_for_req(self, req_index: torch.Tensor): buffer_indexes = buffer_indexes.cuda() self.req_to_buffer_index[req_index] = buffer_indexes.view(num_reqs, num_buffers_per_req) - def resume(self, req_index: torch.Tensor): + def resume(self): # for rl kv cache resume self.req_to_buffer_index.zero_() - self.req_to_buffer_index[req_index] = self.buffer_mem_manager.HOLD_BUFFER_INDEX + self.req_to_buffer_index[self.HOLD_REQUEST_ID, :] = self.buffer_mem_manager.HOLD_BUFFER_INDEX return From 8794f432055cf84b00f28938520418b8b5d533e0 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 2 Apr 2026 02:54:26 +0000 Subject: [PATCH 161/180] fix typo --- .../common/basemodel/layer_weights/meta_weights/norm_weight.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 9d82cbd029..648f43442f 100644 --- 
a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -282,7 +282,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.k_weight_name in weights: self.k_weight.copy_(weights[self.k_weight_name]) self.k_weight.load_ok = True - del weights[self.q_weight_name] + del weights[self.k_weight_name] def verify_load(self): return self.q_weight.load_ok and self.k_weight.load_ok From 1abf95a2a1485f9f248c93a9aa5facf5985d1832 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 10 Apr 2026 19:59:06 +0800 Subject: [PATCH 162/180] add fp8 rl for qwen35 --- .../layer_weights/transformer_layer_weight.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py index fe4b1883bd..204700fa34 100644 --- a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py @@ -30,6 +30,24 @@ def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_s weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx] weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx] + if "mlp.experts.gate_proj" in k: + gate_weight = weights.pop(k) # [num_experts, hidden_size, inter_size] + num_experts = gate_weight.shape[0] + + prefix = k.rsplit(".gate_proj", 1)[0] + + for expert_idx in range(num_experts): + weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx] + + elif "mlp.experts.up_proj" in k: + up_weight = weights.pop(k) # [num_experts, hidden_size, inter_size] + num_experts = up_weight.shape[0] + + prefix = k.rsplit(".up_proj", 1)[0] + + for expert_idx in range(num_experts): + weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx] + elif "mlp.experts.down_proj" in k: down_weight = weights.pop(k) # [num_experts, hidden_size, inter_size] num_experts = down_weight.shape[0] From 901bd132996ae8fee83cce6092dc2c98d7d718ef Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 15 Apr 2026 14:18:53 +0000 Subject: [PATCH 163/180] fix abort --- lightllm/server/router/model_infer/mode_backend/base_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 7e0c793c34..b61b1ee25d 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -802,7 +802,7 @@ def _get_classed_reqs( paused_reqs.append(req_obj) continue - if req_obj.infer_aborted or req_obj.finish_status.is_finished(): + if req_obj.finish_status.is_finished(): if support_overlap: # 延迟处理 req_obj.filter_mark = True From 8de8baf1e0236813ff9ba99e1b17c8a11fbcaf88 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 16 Apr 2026 13:50:41 +0000 Subject: [PATCH 164/180] add logs for detoken --- .../server/core/objs/out_token_circlequeue.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/lightllm/server/core/objs/out_token_circlequeue.py b/lightllm/server/core/objs/out_token_circlequeue.py index ea99dae5f6..da6c7330f3 100644 --- a/lightllm/server/core/objs/out_token_circlequeue.py +++ 
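The fused-expert split added for qwen3.5 above turns one stacked checkpoint tensor per projection into `num_experts` individual keys. A compact sketch of the same transformation; key names follow the patch, tensor sizes are made up for the demo.

```python
import torch

weights = {"model.layers.0.mlp.experts.gate_proj": torch.randn(4, 16, 8)}

for k in list(weights.keys()):
    if "mlp.experts.gate_proj" in k:
        gate_weight = weights.pop(k)            # stacked: [num_experts, ...]
        prefix = k.rsplit(".gate_proj", 1)[0]   # "...mlp.experts"
        for expert_idx in range(gate_weight.shape[0]):
            weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx]

assert "model.layers.0.mlp.experts.0.gate_proj.weight" in weights
```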
b/lightllm/server/core/objs/out_token_circlequeue.py @@ -4,6 +4,9 @@ LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 1280)) LIGHTLLM_OUT_TOKEN_QUEUE_SIZE = int(os.getenv("LIGHTLLM_OUT_TOKEN_QUEUE_SIZE", 8)) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class QueueItem(ctypes.Structure): @@ -24,9 +27,18 @@ def __init__(self): def set(self, token_str: str, src_index: int, special: bool, count_output_tokens: int): str_bytes = token_str.encode("utf-8") - assert ( - len(str_bytes) <= LIGHTLLM_TOKEN_MAX_BYTES - ), f"Token string {len(str_bytes)} exceeds maximum length of {LIGHTLLM_TOKEN_MAX_BYTES} bytes." + if len(str_bytes) > LIGHTLLM_TOKEN_MAX_BYTES: + logger.error( + "Token string exceeds max bytes: bytes=%d limit=%d src_index=%d count_output_tokens=%d preview=%s", + len(str_bytes), + LIGHTLLM_TOKEN_MAX_BYTES, + src_index, + count_output_tokens, + token_str, + ) + raise ValueError( + f"Token string {len(str_bytes)} exceeds maximum length of {LIGHTLLM_TOKEN_MAX_BYTES} bytes." + ) ctypes.memmove(self.data, str_bytes, len(str_bytes)) self.data_len = len(str_bytes) self.src_index = src_index From 2dc39fa2f82608b13ffe41aaadc0f9fae2300b07 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 17 Apr 2026 08:55:34 +0000 Subject: [PATCH 165/180] fix decode overflow --- lightllm/server/core/objs/out_token_circlequeue.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lightllm/server/core/objs/out_token_circlequeue.py b/lightllm/server/core/objs/out_token_circlequeue.py index da6c7330f3..8e856d432c 100644 --- a/lightllm/server/core/objs/out_token_circlequeue.py +++ b/lightllm/server/core/objs/out_token_circlequeue.py @@ -36,9 +36,7 @@ def set(self, token_str: str, src_index: int, special: bool, count_output_tokens count_output_tokens, token_str, ) - raise ValueError( - f"Token string {len(str_bytes)} exceeds maximum length of {LIGHTLLM_TOKEN_MAX_BYTES} bytes." - ) + str_bytes = str_bytes[: LIGHTLLM_TOKEN_MAX_BYTES - 1] ctypes.memmove(self.data, str_bytes, len(str_bytes)) self.data_len = len(str_bytes) self.src_index = src_index From 8c2036994babe151e70494f56329e447535e7bfb Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Sat, 18 Apr 2026 12:55:31 +0000 Subject: [PATCH 166/180] fix bytes decode --- lightllm/server/core/objs/out_token_circlequeue.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lightllm/server/core/objs/out_token_circlequeue.py b/lightllm/server/core/objs/out_token_circlequeue.py index 8e856d432c..8019c9a1a1 100644 --- a/lightllm/server/core/objs/out_token_circlequeue.py +++ b/lightllm/server/core/objs/out_token_circlequeue.py @@ -27,16 +27,19 @@ def __init__(self): def set(self, token_str: str, src_index: int, special: bool, count_output_tokens: int): str_bytes = token_str.encode("utf-8") - if len(str_bytes) > LIGHTLLM_TOKEN_MAX_BYTES: + max_data_len = max(LIGHTLLM_TOKEN_MAX_BYTES - 1, 0) + if len(str_bytes) > max_data_len: logger.error( "Token string exceeds max bytes: bytes=%d limit=%d src_index=%d count_output_tokens=%d preview=%s", len(str_bytes), - LIGHTLLM_TOKEN_MAX_BYTES, + max_data_len, src_index, count_output_tokens, token_str, ) - str_bytes = str_bytes[: LIGHTLLM_TOKEN_MAX_BYTES - 1] + str_bytes = str_bytes[:max_data_len] + # Ensure truncation never leaves an incomplete UTF-8 sequence. 
+ str_bytes = str_bytes.decode("utf-8", errors="ignore").encode("utf-8") ctypes.memmove(self.data, str_bytes, len(str_bytes)) self.data_len = len(str_bytes) self.src_index = src_index From 1f466c746809ccd460508584fefaa5846dc475ab Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 8 May 2026 13:23:58 +0000 Subject: [PATCH 167/180] remove neo --- .../common/basemodel/attention/base_att.py | 1 - lightllm/common/basemodel/attention/fa3/fp.py | 7 +- lightllm/models/__init__.py | 2 - lightllm/models/neo_chat/__init__.py | 0 .../models/neo_chat/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 159 ------- .../models/neo_chat/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 23 - .../layer_weights/transformer_layer_weight.py | 67 --- lightllm/models/neo_chat/model.py | 53 --- lightllm/models/neo_chat_moe/__init__.py | 0 lightllm/models/neo_chat_moe/infer_struct.py | 103 ----- .../neo_chat_moe/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 201 -------- .../neo_chat_moe/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 23 - .../layer_weights/transformer_layer_weight.py | 86 ---- lightllm/models/neo_chat_moe/model.py | 150 ------ lightllm/models/neo_chat_moe/neo_visual.py | 281 ------------ .../neo_chat_moe/triton_kernel/__init__.py | 0 .../context_attention_fwd_neo.py | 430 ------------------ .../triton_kernel/get_neo_position.py | 191 -------- .../models/neo_chat_moe/vision_process.py | 141 ------ lightllm/server/tokenizer.py | 3 - .../visualserver/model_infer/model_rpc.py | 3 - 25 files changed, 2 insertions(+), 1922 deletions(-) delete mode 100644 lightllm/models/neo_chat/__init__.py delete mode 100644 lightllm/models/neo_chat/layer_infer/__init__.py delete mode 100644 lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py delete mode 100644 lightllm/models/neo_chat/layer_weights/__init__.py delete mode 100644 lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py delete mode 100644 lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py delete mode 100644 lightllm/models/neo_chat/model.py delete mode 100644 lightllm/models/neo_chat_moe/__init__.py delete mode 100644 lightllm/models/neo_chat_moe/infer_struct.py delete mode 100644 lightllm/models/neo_chat_moe/layer_infer/__init__.py delete mode 100644 lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py delete mode 100644 lightllm/models/neo_chat_moe/layer_weights/__init__.py delete mode 100644 lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py delete mode 100644 lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py delete mode 100644 lightllm/models/neo_chat_moe/model.py delete mode 100644 lightllm/models/neo_chat_moe/neo_visual.py delete mode 100644 lightllm/models/neo_chat_moe/triton_kernel/__init__.py delete mode 100644 lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py delete mode 100644 lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py delete mode 100644 lightllm/models/neo_chat_moe/vision_process.py diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 7e6856eeb7..1286a46ec2 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -60,7 +60,6 @@ class AttControl: sliding_window: Tuple[int, int] = (-1, -1) use_att_sink: bool = False sink_weight: torch.Tensor = None - scale: float = None # mla 专用传参项 
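Why the decode/encode round-trip in the byte-decode fix above matters: a byte-level cut can land in the middle of a multi-byte UTF-8 sequence, and `decode(errors="ignore")` followed by re-encoding drops the dangling partial character. A standalone demonstration:

```python
s = "缓存"                    # two 3-byte UTF-8 characters
raw = s.encode("utf-8")       # 6 bytes total
cut = raw[:4]                 # slices into the middle of the second character

# cut.decode("utf-8") would raise UnicodeDecodeError here.
safe = cut.decode("utf-8", errors="ignore").encode("utf-8")
assert safe == "缓".encode("utf-8")   # only the complete character survives
```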
mla_prefill: bool = False mla_prefill_dict: Dict = None diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 2f5fccd57b..952bb39d91 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -220,11 +220,8 @@ def _normal_decode_att( sink_weight = None k_descale, v_descale = None, None # disable quantization - if att_control.scale is not None: - sm_scale = att_control.scale - else: - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) o = flash_attn_with_kvcache( q=q, k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index ecc1cbf491..2caee91709 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -39,8 +39,6 @@ Tarsier2LlamaTpPartModel, ) from lightllm.models.gpt_oss.model import GptOssTpPartModel -from lightllm.models.neo_chat_moe.model import NeoTpMOEPartModel -from lightllm.models.neo_chat.model import NeoTpPartModel from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel from lightllm.models.qwen3_5_moe.model import Qwen3_5MOETpPartModel diff --git a/lightllm/models/neo_chat/__init__.py b/lightllm/models/neo_chat/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/neo_chat/layer_infer/__init__.py b/lightllm/models/neo_chat/layer_infer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py deleted file mode 100644 index a3436b28ee..0000000000 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,159 +0,0 @@ -import torch -from functools import partial -from typing import Tuple -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo -from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo -from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd -from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer -from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight -from lightllm.distributed import all_reduce -import torch.distributed as dist -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer - - -class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer): - def __init__(self, data_type, network_config): - super().__init__(data_type, network_config) - return - - def _bind_attention(self): - self._context_attention_kernel = self._context_attention_kernel - self._token_attention_kernel = self._token_decode_attention_normal - self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal - return - - def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight): - input = input.view(-1, self.embed_dim_) - q = layer_weight.q_proj.mm(input) # [T, Hq*D] - - q_hw = layer_weight.q_hw_proj.mm(input) - q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) - q_h, q_w = q_hw.chunk(2, dim=-1) - - k_hw = layer_weight.k_hw_proj.mm(input) - k_hw = k_hw.view(-1, 
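With the `scale` field removed from `AttControl`, the FA3 decode path above always derives its softmax scale from the query head dim. The value is the standard 1/sqrt(d_k); a quick check with made-up shapes:

```python
import math

import torch

q = torch.randn(2, 8, 128)       # [tokens, heads, head_dim], illustrative only
Lq = q.shape[-1]
sm_scale = 1.0 / (Lq ** 0.5)     # what the kernel now computes unconditionally
assert math.isclose(sm_scale, 1.0 / math.sqrt(128))
```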
self.tp_k_head_num_, self.head_dim_) - k_h, k_w = k_hw.chunk(2, dim=-1) - - cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - - layer_weight.q_norm_weight_(q, eps=self.eps_) - - q_h_2d = q_h.reshape(q.shape[0], -1) - q_w_2d = q_w.reshape(q.shape[0], -1) - layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_) - layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_) - q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - - layer_weight.k_norm_weight_( - cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - eps=self.eps_, - ) - - k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] - k_w_2d = k_w.reshape(q.shape[0], -1) - layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_) - layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_) - k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) - k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) - - cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - rotary_emb_fwd( - q_h, - k_h, - infer_state.position_cos_h, - infer_state.position_sin_h, - ) - rotary_emb_fwd( - q_w, - k_w, - infer_state.position_cos_w, - infer_state.position_sin_w, - ) - - q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) - q3 = torch.cat([q3, q_h, q_w], dim=-1) - q = q3.reshape(q3.shape[0], -1) - - k = cache_kv[:, : self.tp_k_head_num_, :] - k = torch.cat([k, k_h, k_w], dim=-1) - - v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] - v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) - v = torch.cat([v, v_pad], dim=-1) - - cache_kv = torch.cat([k, v], dim=1) - return q, cache_kv - - def _context_attention_kernel( - self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd_neo( - q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), - infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - ) - o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) - o3 = o3[:, :, : self.head_dim_].contiguous() - return o3.view(o3.shape[0], -1) - - def _token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - - q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) - - att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) - - k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - token_att_fwd( - q_3d, - k_3d, - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_kv_start_loc, - infer_state.b_seq_len, - infer_state.max_kv_seq_len, - ) - - from 
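The deleted NeoChat attention applies separate rotary embeddings per axis: a temporal one plus one for each spatial direction (h, w), each over its own slice of the head dim. A toy rotate-half RoPE for one such slice; the convention and theta here are assumptions, since the removed path went through `rotary_emb_fwd`.

```python
import torch


def rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [T, H, D]; cos/sin: [T, D // 2] (rotate-half convention).
    x1, x2 = x.chunk(2, dim=-1)
    c, s = cos[:, None, :], sin[:, None, :]
    return torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)


T, H, D = 4, 2, 8
pos = torch.arange(T, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
ang = torch.outer(pos, inv_freq)                  # [T, D // 2]

# The same helper would run three times: on the temporal slice with
# position_ids[0], and on the h / w slices with position_ids[1] / [2].
q_h = rope(torch.randn(T, H, D), torch.cos(ang), torch.sin(ang))
```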
lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( - token_attention_softmax_and_reducev, - ) - - token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd - - v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ - ] - - o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) - - token_softmax_reducev_fwd( - att_m_tensor, - v_3d, - o_3d, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_kv_start_loc, - infer_state.b_seq_len, - ) - return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat/layer_weights/__init__.py b/lightllm/models/neo_chat/layer_weights/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py deleted file mode 100644 index e6489f39af..0000000000 --- a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -import numpy as np -from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight - -# add key: language_model.xxx -> xxx -# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now -def rename_weight_keys(weights): - prefix = "language_model." - keys = list(weights.keys()) - for k in keys: - if prefix in k: - weights[k.replace(prefix, "")] = weights.pop(k) - - -class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config): - super().__init__(data_type, network_config) - return - - def load_hf_weights(self, weights): - rename_weight_keys(weights) - super().load_hf_weights(weights) - return diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py deleted file mode 100644 index e62afae9bc..0000000000 --- a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py +++ /dev/null @@ -1,67 +0,0 @@ -from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ( - QKRMSNORMWeight, - ROWMMWeight, -) - - -class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, quant_cfg=None): - super().__init__(layer_num, data_type, network_config, quant_cfg) - return - - def _init_weight_names(self): - super()._init_weight_names() - self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight" - self._q_bias_hw_name = None - self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" - self._k_bias_hw_name = None - - self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" - self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" - - self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" - self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" - - def _init_qkv(self): - super()._init_qkv() - self.q_hw_proj = ROWMMWeight( - in_dim=self.network_config_["hidden_size"], - out_dims=[self.q_head_num_ * self.head_dim], - weight_names=self._q_weight_hw_name, - data_type=self.data_type_, - 
bias_names=self._q_bias_hw_name, - quant_method=self.get_quant_method("q_hw_proj"), - ) - self.k_hw_proj = ROWMMWeight( - in_dim=self.network_config_["hidden_size"], - out_dims=[self.k_head_num_ * self.head_dim], - weight_names=self._k_weight_hw_name, - data_type=self.data_type_, - bias_names=self._k_bias_hw_name, - quant_method=self.get_quant_method("k_hw_proj"), - ) - - def _init_norm(self): - super()._init_norm() - - self.q_norm_h_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, - weight_name=self._q_norm_h_name, - data_type=self.data_type_, - ) - self.q_norm_w_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, - weight_name=self._q_norm_w_name, - data_type=self.data_type_, - ) - self.k_norm_h_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, - weight_name=self._k_norm_h_name, - data_type=self.data_type_, - ) - self.k_norm_w_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, - weight_name=self._k_norm_w_name, - data_type=self.data_type_, - ) diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py deleted file mode 100644 index 14d9f96dc7..0000000000 --- a/lightllm/models/neo_chat/model.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -import json -from lightllm.common.build_utils import repair_config -from lightllm.models.registry import ModelRegistry, llm_model_type_is -from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo -from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer -from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer -from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight -from lightllm.models.qwen2_vl.model import QWen2VLTokenizer -from lightllm.models.qwen3.model import Qwen3TpPartModel -from lightllm.server.core.objs import SamplingParams -from lightllm.models.qwen3_moe.model import Qwen3MOEModel -from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem -from lightllm.models.neo_chat_moe.vision_process import smart_resize -from lightllm.models.internvl.model import InternvlTokenizer -from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer -from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatTransformerLayerInfer -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight -from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight -from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer -from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo - - -@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3")) -class NeoTpPartModel(Qwen3TpPartModel): - - pre_layer_infer_class = LlamaMultimodalPreLayerInfer - transformer_layer_infer_class = NeoChatTransformerLayerInfer - - pre_and_post_weight_class = NeoChatPreAndPostLayerWeight - transformer_weight_class = NeoChatTransformerLayerWeight - - infer_state_class = NeoChatInferStateInfo - - def __init__(self, kvargs): - super().__init__(kvargs) - return - - def _init_inferstate_cls(self): - pass - - def _init_config(self): - with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: - all_config = json.load(json_file) - self.config = all_config["llm_config"] - # rename keys - 
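The per-direction norms above (`q_norm_h` / `q_norm_w`, each with `dim = head_dim // 2`) are ordinary RMSNorms applied per head to one slice of the head dim. A reference version of that math; the eps value is an assumption.

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize over the last axis only; x: [tokens, heads, dim_slice].
    var = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps) * weight


tokens, heads, half = 3, 4, 64
q_h = torch.randn(tokens, heads, half)
q_h = rms_norm(q_h, torch.ones(half))   # same math the fused kernel applies in place
```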
repair_config(self.config, same_names=["num_attention_heads", "n_head"]) - repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) - repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) - if self.finetune_config: - self.config["vocab_size"] = self.finetune_config.vocab_size - return diff --git a/lightllm/models/neo_chat_moe/__init__.py b/lightllm/models/neo_chat_moe/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py deleted file mode 100644 index 961ed2a61d..0000000000 --- a/lightllm/models/neo_chat_moe/infer_struct.py +++ /dev/null @@ -1,103 +0,0 @@ -from typing import Optional, List -import torch -import numpy as np -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager -from lightllm.models.neo_chat_moe.triton_kernel.get_neo_position import get_neo_position_triton -from lightllm.models.llama.model import LlamaTpPartModel - - -class NeoChatInferStateInfo(LlamaInferStateInfo): - def __init__(self): - super().__init__() - self.position_cos = None - self.position_sin = None - self.position_cos_h = None - self.position_sin_h = None - self.position_cos_w = None - self.position_sin_w = None - - def init_some_extra_state(self, model: LlamaTpPartModel): - LlamaInferStateInfo.init_some_extra_state(self, model) - if self.is_prefill: - self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda( - non_blocking=True - ) - self.position_ids = self.get_neo_position(self.multimodal_params) - else: - b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] - for batch_idx, p in enumerate(self.multimodal_params): - position_delta = 0 - for image in p["images"]: - position_delta += image["grid_thwd"][3] - b_position_delta[batch_idx] = position_delta - position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) - self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone() - self.position_ids[1:].zero_() - - self.position_ids = self.position_ids.contiguous() - self.position_cos = model._cos_cached[self.position_ids[0]] - self.position_sin = model._sin_cached[self.position_ids[0]] - self.position_cos_h = model._hw_cos_cached[self.position_ids[1]] - self.position_sin_h = model._hw_sin_cached[self.position_ids[1]] - self.position_cos_w = model._hw_cos_cached[self.position_ids[2]] - self.position_sin_w = model._hw_sin_cached[self.position_ids[2]] - return - - def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: - if len(multimodal_params) == 0: - position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) - position_ids[0].copy_(self.position_ids) - return position_ids - b_image_start_idx = [] - b_image_nums = [] - b_image_start_num = [] - b_image_len = [] - image_start_num = 0 - b_image_thwd = [] - - # pad multimodal_params to batch size. 
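Decode-time position bookkeeping in the deleted `init_some_extra_state` above: each image records a `grid_thwd` whose last entry is `1 - token_num` (set in the tokenizer code removed later in this patch), so summing that entry over a request's images gives the shift between the raw token count and the effective text position. Illustrative numbers only:

```python
# Each image contributes (1 - token_num): its token span collapses to a
# single position step along the temporal axis. Values are made up.
images = [
    {"grid_thwd": (1, 16, 24, 1 - 96)},   # 96 image tokens -> one step
    {"grid_thwd": (1, 8, 8, 1 - 16)},     # 16 image tokens -> one step
]
position_delta = sum(img["grid_thwd"][3] for img in images)   # -110
raw_next_pos = 500            # what plain b_seq_len bookkeeping would give
effective_pos = raw_next_pos + position_delta                 # 390
print(position_delta, effective_pos)
```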
- batch_size = self.b_q_seq_len.shape[0] - multimodal_params = multimodal_params + [ - {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params)) - ] - - for _, p in enumerate(multimodal_params): - images = p.get("images", []) - for img in images: - b_image_start_idx.append(img["start_idx"]) - a = img["start_idx"] - print(f"img start_idx: {a}") - b_image_len.append(img["token_num"]) - b_image_thwd.append(img["grid_thwd"]) - b_image_nums.append(len(images)) - b_image_start_num.append(image_start_num) - image_start_num += len(images) - - # 没有任何图片 - if image_start_num == 0: - position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) - position_ids[0].copy_(self.position_ids) - return position_ids.contiguous() - b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True) - b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4 - b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True) - b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True) - b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True) - - position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) - position_ids[0].copy_(self.position_ids) - - get_neo_position_triton( - b_image_start_idx=b_image_start_idx, - b_image_thwd=b_image_thwd, - b_image_nums=b_image_nums, - b_image_start_num=b_image_start_num, - b_image_len=b_image_len, - position_ids=position_ids, - b_ready_cache_len=self.b_ready_cache_len, - b_q_seq_len=self.b_q_seq_len, - b_start_loc=self.b_q_start_loc, - b_image_token_tag=self.b_image_token_tag, - ) - return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/__init__.py b/lightllm/models/neo_chat_moe/layer_infer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py deleted file mode 100644 index 6a5259cafd..0000000000 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,201 +0,0 @@ -import torch -from functools import partial -from typing import Tuple -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo -from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo -from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd -from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer -from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight -from lightllm.distributed import all_reduce -import torch.distributed as dist -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.common.basemodel.attention.base_att import AttControl - - -class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): - def __init__(self, data_type, network_config): - self._is_merge_kv = network_config.get("merge_kv", True) - super().__init__(data_type, network_config) - return - - def _bind_attention(self): - self._context_attention_kernel = self._context_attention_kernel - self._token_attention_kernel = self._token_attention_kernel - return - - def _get_qkv(self, input, infer_state: 
NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): - if self._is_merge_kv: - return self._get_qkv_mergekv(input, infer_state, layer_weight) - else: - return self._get_qkv_not_mergekv(input, infer_state, layer_weight) - - def _get_qkv_not_mergekv( - self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight - ): - pass - # input = input.view(-1, self.embed_dim_) - # q = layer_weight.q_proj.mm(input) # [T, Hq*D] - - # q_hw = layer_weight.q_hw_proj.mm(input) - # q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) - # q_h, q_w = q_hw.chunk(2, dim=-1) - - # k_hw = layer_weight.k_hw_proj.mm(input) - # k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) - # k_h, k_w = k_hw.chunk(2, dim=-1) - - # cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - - # layer_weight.q_norm_weight_(q, eps=self.eps_) - - # q_h_2d = q_h.reshape(q.shape[0], -1) - # q_w_2d = q_w.reshape(q.shape[0], -1) - # layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_) - # layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_) - # q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - # q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - - # layer_weight.k_norm_weight_( - # cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - # eps=self.eps_, - # ) - - # k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] - # k_w_2d = k_w.reshape(q.shape[0], -1) - # layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_) - # layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_) - # k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) - # k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) - - # cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - - # rotary_emb_fwd( - # q.view(-1, self.tp_q_head_num_, self.head_dim_), - # cache_kv[:, : self.tp_k_head_num_, :], - # infer_state.position_cos, - # infer_state.position_sin, - # ) - # rotary_emb_fwd( - # q_h, - # k_h, - # infer_state.position_cos_h, - # infer_state.position_sin_h, - # ) - # rotary_emb_fwd( - # q_w, - # k_w, - # infer_state.position_cos_w, - # infer_state.position_sin_w, - # ) - - # q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) - # q3 = torch.cat([q3, q_h, q_w], dim=-1) - # q = q3.reshape(q3.shape[0], -1) - - # k = cache_kv[:, : self.tp_k_head_num_, :] - # k = torch.cat([k, k_h, k_w], dim=-1) - - # v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] - # v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) - # v = torch.cat([v, v_pad], dim=-1) - - # cache_kv = torch.cat([k, v], dim=1) - # return q, cache_kv - - def _get_qkv_mergekv( - self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight - ): - input = input.view(-1, self.embed_dim_) - - qkv = layer_weight.qkv_proj.mm(input) - q, cache_kv = qkv.split( - [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 - ) - q = q.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) - q_t, q_hw = q.chunk(2, dim=-1) - - cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - k = cache_kv[:, : self.tp_k_head_num_, :] - v = cache_kv[:, self.tp_k_head_num_ :, :] - k_t, k_hw = k.chunk(2, dim=-1) - - q_t_2d = q_t.reshape(q.shape[0], -1) - q_hw_2d = q_hw.reshape(q.shape[0], -1) - layer_weight.q_norm_weight_(q_t_2d, eps=self.eps_) - 
layer_weight.q_norm_hw_weight_(q_hw_2d, eps=self.eps_) - - q_t = q_t_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - q_hw = q_hw_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - q_h, q_w = q_hw.chunk(2, dim=-1) - - k_t_2d = k_t.reshape(k.shape[0], -1) - k_hw_2d = k_hw.reshape(k.shape[0], -1) - layer_weight.k_norm_weight_(k_t_2d, eps=self.eps_) - layer_weight.k_norm_hw_weight_(k_hw_2d, eps=self.eps_) - k_t = k_t_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) - k_hw = k_hw_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) - k_h, k_w = k_hw.chunk(2, dim=-1) - - rotary_emb_fwd( - q_t, - k_t, - infer_state.position_cos, - infer_state.position_sin, - ) - rotary_emb_fwd( - q_h, - k_h, - infer_state.position_cos_h, - infer_state.position_sin_h, - ) - rotary_emb_fwd( - q_w, - k_w, - infer_state.position_cos_w, - infer_state.position_sin_w, - ) - - q = torch.cat([q_t, q_h, q_w], dim=-1) - q = q.reshape(q.shape[0], -1) - - k = torch.cat([k_t, k_h, k_w], dim=-1) - cache_kv = torch.cat([k, v], dim=1) - return q, cache_kv - - def _context_attention_kernel( - self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd_neo( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] - infer_state.b_req_idx, - infer_state.b_q_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_q_seq_len, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_image_token_tag, - ) - return o_tensor - - def _token_attention_kernel( - self, - q: torch.Tensor, - infer_state: NeoChatInferStateInfo, - layer_weight: NeoChatMOETransformerLayerWeight, - ) -> torch.Tensor: - _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) - _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - att_control = AttControl() - # att_control.mla_decode_dict["softmax_scale"] = 1.0 / (self.head_dim_ ** 0.5) - o_tensor = infer_state.decode_att_state.decode_att( - q=_q, k=_k, v=_v, att_control=att_control, alloc_func=self.alloc_tensor - ) - o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_)[:, :, : self.head_dim_].contiguous() - return o_tensor diff --git a/lightllm/models/neo_chat_moe/layer_weights/__init__.py b/lightllm/models/neo_chat_moe/layer_weights/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py deleted file mode 100644 index 4b0eae91c3..0000000000 --- a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -import numpy as np -from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight - -# add key: language_model.xxx -> xxx -# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now -def rename_weight_keys(weights): - prefix = "language_model." 
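Shape bookkeeping for the merge-kv path above: a single GEMM output is split into Q and a stacked KV block by head counts, and each head dim is then halved into a temporal part and an (h, w) spatial part. Toy sizes, no tensor parallelism:

```python
import torch

T, Hq, Hkv, D = 5, 8, 2, 128
qkv = torch.randn(T, (Hq + 2 * Hkv) * D)

q, cache_kv = qkv.split([Hq * D, 2 * Hkv * D], dim=-1)
q = q.view(T, Hq, D)
cache_kv = cache_kv.view(T, 2 * Hkv, D)
k, v = cache_kv[:, :Hkv, :], cache_kv[:, Hkv:, :]

q_t, q_hw = q.chunk(2, dim=-1)      # temporal half and spatial half
q_h, q_w = q_hw.chunk(2, dim=-1)    # spatial half split into h and w quarters
assert q_t.shape == (T, Hq, D // 2) and q_h.shape == (T, Hq, D // 4)
```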
- keys = list(weights.keys()) - for k in keys: - if prefix in k: - weights[k.replace(prefix, "")] = weights.pop(k) - - -class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config): - super().__init__(data_type, network_config) - return - - def load_hf_weights(self, weights): - rename_weight_keys(weights) - super().load_hf_weights(weights) - return diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py deleted file mode 100644 index 83ec33060c..0000000000 --- a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py +++ /dev/null @@ -1,86 +0,0 @@ -from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ( - QKRMSNORMWeight, - ROWMMWeight, - RMSNormWeight, -) - - -class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, quant_cfg=None): - self._is_merge_kv = network_config.get("merge_kv", True) - super().__init__(layer_num, data_type, network_config, quant_cfg) - return - - def _init_weight_names(self): - super()._init_weight_names() - - if self._is_merge_kv: - self._q_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_hw.weight" - self._k_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_hw.weight" - else: - self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" - self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" - - self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" - self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" - - def _init_qkv(self): - super()._init_qkv() - - def _init_norm(self): - hidden_size = self.network_config_["hidden_size"] - self.att_norm_weight_ = RMSNormWeight( - dim=hidden_size, - weight_name=self._att_norm_weight_name, - data_type=self.data_type_, - ) - self.ffn_norm_weight_ = RMSNormWeight( - dim=hidden_size, - weight_name=self._ffn_norm_weight_name, - data_type=self.data_type_, - ) - - self.q_norm_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, - weight_name=self._q_norm_name, - data_type=self.data_type_, - ) - self.k_norm_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, - weight_name=self._k_norm_name, - data_type=self.data_type_, - ) - - if self._is_merge_kv: - self.q_norm_hw_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, - weight_name=self._q_norm_hw_name, - data_type=self.data_type_, - ) - self.k_norm_hw_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 2, - weight_name=self._k_norm_hw_name, - data_type=self.data_type_, - ) - else: - self.q_norm_h_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 4, - weight_name=self._q_norm_h_name, - data_type=self.data_type_, - ) - self.q_norm_w_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 4, - weight_name=self._q_norm_w_name, - data_type=self.data_type_, - ) - self.k_norm_h_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 4, - weight_name=self._k_norm_h_name, - data_type=self.data_type_, - ) - self.k_norm_w_weight_ = QKRMSNORMWeight( - dim=self.head_dim // 4, - weight_name=self._k_norm_w_name, - data_type=self.data_type_, - ) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py deleted file mode 100644 index cf4404090f..0000000000 --- 
a/lightllm/models/neo_chat_moe/model.py +++ /dev/null @@ -1,150 +0,0 @@ -import os -import json -from lightllm.common.build_utils import repair_config -from lightllm.models.registry import ModelRegistry, llm_model_type_is -from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo -from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer -from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer -from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight -from lightllm.models.qwen2_vl.model import QWen2VLTokenizer -from lightllm.models.qwen3.model import Qwen3TpPartModel -from lightllm.server.core.objs import SamplingParams -from lightllm.models.qwen3_moe.model import Qwen3MOEModel -from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem -from lightllm.models.neo_chat_moe.vision_process import smart_resize -from lightllm.models.internvl.model import InternvlTokenizer -from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer -from lightllm.models.neo_chat_moe.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight -from lightllm.models.neo_chat_moe.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight -from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer -from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo - -IMG_START_TOKEN = "" -IMG_END_TOKEN = "" -IMG_TOKEN = "" -AUDIO_START_TOKEN = "" - - -class NeoChatTokenizer(BaseMultiModalTokenizer): - def __init__(self, tokenizer, model_cfg, **kwargs): - super().__init__(tokenizer) - self.tokenizer = tokenizer - self.min_pixel = model_cfg.get("vision_config").get("min_pixels") - self.max_pixel = model_cfg.get("vision_config").get("max_pixels") - self.patch_size = model_cfg.get("vision_config").get("patch_size") - self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio") - - self.image_token_id = model_cfg.get("image_token_id") - self.image_start_tag = IMG_START_TOKEN - self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag) - self.image_end_tag = IMG_END_TOKEN - self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag) - - def init_imageitem_extral_params( - self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams - ): - img.extra_params["min_pixels"] = ( - sampling_params.min_pixels if sampling_params.min_pixels > 0 else self.min_pixel - ) - img.extra_params["max_pixels"] = ( - sampling_params.max_pixels if sampling_params.max_pixels > 0 else self.max_pixel - ) - assert ( - img.extra_params["min_pixels"] <= img.extra_params["max_pixels"] - ), "min_pixels should be less than or equal to max_pixels" - return - - def init_audioitem_extral_params( - self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams - ): - raise NotImplementedError - - def get_audio_token_length(self, audio: AudioItem): - raise NotImplementedError - - def get_image_token_length(self, img: ImageItem): - width, height = img.image_w, img.image_h - resized_height, resized_width = smart_resize( - height=height, - width=width, - factor=int(self.patch_size // self.downsample_ratio), - 
min_pixels=img.extra_params["min_pixels"], - max_pixels=img.extra_params["max_pixels"], - ) - grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size - token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2)) - # 这里的grid_h和grid_w需要* self.downsample_ratio么?再仔细看下代码 - img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num) - return token_num - - # only change the impl of the encode func: - def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): - # TEXTTEXTTEXT --> TEXTTEXTTEXT - image_tokens = IMG_START_TOKEN + IMG_END_TOKEN - if multimodal_params is None: - add_special_tokens = kwargs.get("add_special_tokens", True) - return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) - image_count = len(multimodal_params.images) - if not kwargs.get("already_tokenized", False): - prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) - origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) - else: - origin_ids = prompt - # --> id,id+1...id+num - input_ids = [] - image_id = 0 - start_idx = 0 - while True: - try: - start_idx = origin_ids.index(self.image_start_id) - if start_idx + 1 >= len(origin_ids): - break - if origin_ids[start_idx + 1] == self.image_end_id: - input_ids.extend(origin_ids[: start_idx + 1]) - token_id = multimodal_params.images[image_id].token_id - token_num = multimodal_params.images[image_id].token_num - multimodal_params.images[image_id].start_idx = len(input_ids) - input_ids.extend(range(token_id, token_id + token_num)) - input_ids.append(self.image_end_id) - origin_ids = origin_ids[start_idx + 2 :] - image_id += 1 - else: - raise ValueError("image token error") - except ValueError: - break - input_ids.extend(origin_ids) - return input_ids - - -@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe")) -class NeoTpMOEPartModel(Qwen3MOEModel): - - pre_layer_infer_class = LlamaMultimodalPreLayerInfer - transformer_layer_infer_class = NeoChatMOETransformerLayerInfer - - pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight - transformer_weight_class = NeoChatMOETransformerLayerWeight - - infer_state_class = NeoChatInferStateInfo - - def __init__(self, kvargs): - super().__init__(kvargs) - return - - def _init_inferstate_cls(self): - pass - - def _init_config(self): - with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: - all_config = json.load(json_file) - self.config = all_config["llm_config"] - # rename keys - repair_config(self.config, same_names=["num_attention_heads", "n_head"]) - repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) - repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) - if self.finetune_config: - self.config["vocab_size"] = self.finetune_config.vocab_size - return diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py deleted file mode 100644 index 60fa82f2b9..0000000000 --- a/lightllm/models/neo_chat_moe/neo_visual.py +++ /dev/null @@ -1,281 +0,0 @@ -import os -import torch -import torch.nn.functional as F -from PIL import Image -from typing import List -from io import BytesIO -import torch.nn as nn -from transformers.activations import ACT2FN -from safetensors import safe_open -from lightllm.server.multimodal_params import ImageItem -from transformers.modeling_outputs import BaseModelOutputWithPooling -from transformers.modeling_utils 
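The `encode()` above rewrites every adjacent image start/end token pair into a contiguous range of per-image token ids and records each image's start offset. (The literal tag strings were lost to markup stripping in this copy of the patch, so the ids below are stand-ins.) A compact re-creation:

```python
IMG_START_ID, IMG_END_ID = 1000, 1001   # hypothetical special-token ids


def expand(origin_ids, images):
    out, i, img = [], 0, 0
    while i < len(origin_ids):
        if (
            origin_ids[i] == IMG_START_ID
            and i + 1 < len(origin_ids)
            and origin_ids[i + 1] == IMG_END_ID
        ):
            out.append(IMG_START_ID)
            start, num = images[img]["token_id"], images[img]["token_num"]
            images[img]["start_idx"] = len(out)   # where the image span begins
            out.extend(range(start, start + num))
            out.append(IMG_END_ID)
            i, img = i + 2, img + 1
        else:
            out.append(origin_ids[i])
            i += 1
    return out


ids = expand([7, 1000, 1001, 9], [{"token_id": 5000, "token_num": 3}])
assert ids == [7, 1000, 5000, 5001, 5002, 1001, 9]
```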
import PreTrainedModel -from lightllm.models.neo_chat_moe.vision_process import load_image_native -from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data - - -def apply_rotary_emb_1d( - x: torch.Tensor, - cos_cached: torch.Tensor, - sin_cached: torch.Tensor, - positions: torch.Tensor, -): - """对输入张量的一部分应用1D RoPE。""" - # x: (..., seq_len, dim_part) - # positions: (..., seq_len) - # cos_cached: (max_pos, dim_part / 2) - cos_cached = cos_cached.to(device=positions.device) - sin_cached = sin_cached.to(device=positions.device) - - cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2) - sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2) - - x1 = x[..., 0::2] - x2 = x[..., 1::2] - - rotated_x1 = x1 * cos - x2 * sin - rotated_x2 = x1 * sin + x2 * cos - - x_rotated = torch.empty_like(x) - x_rotated[..., 0::2] = rotated_x1 - x_rotated[..., 1::2] = rotated_x2 - return x_rotated - - -def apply_2d_rotary_pos_emb( - x: torch.Tensor, - cos_cached_x: torch.Tensor, - sin_cached_x: torch.Tensor, - cos_cached_y: torch.Tensor, - sin_cached_y: torch.Tensor, - abs_positions_x: torch.Tensor, - abs_positions_y: torch.Tensor, -): - """应用2D RoPE到输入张量x。""" - dim = x.shape[-1] - dim_half = dim // 2 - - # 假设我们将embedding的前半部分用于一个方向的RoPE,后半部分用于另一个方向 - # 例如,前一半给X坐标,后一半给Y坐标 (或者反过来,但要保持一致) - x_part_1 = x[..., :dim_half] - x_part_2 = x[..., dim_half:] - - # 将与 abs_positions_x 相关的旋转应用于 x_part_1 - rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x) - # 将与 abs_positions_y 相关的旋转应用于 x_part_2 - rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y) - - # 将它们重新拼接起来。确保顺序与你分割时一致。 - return torch.cat((rotated_part_1, rotated_part_2), dim=-1) - - -def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): - """ - Compute patch coordinates (x, y) - - Args: - grid_hw: (B, 2) tensor representing (H, W) per image - """ - device = grid_hw.device - B = grid_hw.shape[0] - - # Get the number of patches per image - H = grid_hw[:, 0] - W = grid_hw[:, 1] - N = H * W - N_total = N.sum() - - # Create the batch index for each patch (B x patch count) - patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,) - - # Generate intra-image patch index (row-major order) - patch_id_within_image = torch.arange(N_total, device=device) - patch_id_within_image = ( - patch_id_within_image - - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample] - ) - - # Get H/W for each patch according to its image - W_per_patch = W[patch_to_sample] - abs_x = patch_id_within_image % W_per_patch - abs_y = patch_id_within_image // W_per_patch - - return abs_x, abs_y - - -class NeoVisionTransformerPretrainedModel(nn.Module): - def __init__( - self, - kvargs, - hidden_size: int = 1024, - llm_hidden_size: int = 2048, - downsample_ratio: float = 0.5, - patch_size: int = 16, - num_channels: int = 3, - max_position_embeddings_vision: int = 10000, - rope_theta_vision: float = 10000.0, - min_pixels: int = 65536, - max_pixels: int = 2408448, - **kwargs, - ): - super().__init__() - self.weight_dir = kvargs["weight_dir"] - self.data_type = kvargs.get("data_type", "bfloat16") - self.embed_dim = hidden_size - self.llm_hidden_size = llm_hidden_size - self.patch_size = patch_size - self.num_channels = num_channels - self.downsample_ratio = downsample_ratio - self.downsample_factor = int(1 / downsample_ratio) - self.max_position_embeddings_vision = max_position_embeddings_vision - 
self.rope_theta_vision = rope_theta_vision - self.rope_dim_part = self.embed_dim // 2 - self.min_pixels = min_pixels - self.max_pixels = max_pixels - - self.patch_embedding = nn.Conv2d( - in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size - ) - - self.dense_embedding = nn.Conv2d( - in_channels=self.embed_dim, - out_channels=self.llm_hidden_size, - kernel_size=self.downsample_factor, - stride=self.downsample_factor, - ) - self.gelu = nn.GELU() - - self.repe_dim_part = self.embed_dim // 2 - self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos() - self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos() - self._init_datatype() - - def _init_datatype(self): - if isinstance(self.data_type, torch.dtype): - return - if self.data_type in ["fp16", "float16"]: - self.data_type = torch.float16 - elif self.data_type in ["bf16", "bfloat16"]: - self.data_type = torch.bfloat16 - elif self.data_type in ["fp32", "float32"]: - self.data_type = torch.float32 - else: - raise ValueError(f"Unsupport datatype {self.data_type}!") - return - - def load_model(self, weight_dir): - bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] - if bin_weight_files: - weight_dict = {} - for file_ in bin_weight_files: - f = torch.load(os.path.join(weight_dir, file_), "cpu") - for k, v in f.items(): - if "vision_model" in k: - weight_dict[k[len("vision_model.embeddings.") :]] = v - else: - hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] - weight_dict = {} - for file_ in hf_weight_files: - f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") - for k in f.keys(): - if "vision_model" in k: - weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k) - self.load_state_dict(weight_dict) - - def precompute_rope_freqs_sincos(self): - inv_freq = 1.0 / ( - self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part) - ) - t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq) - freqs = torch.outer(t, inv_freq) - return torch.cos(freqs), torch.sin(freqs) - - def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw): - """ - Apply 2D Rotary Position Embedding to the patch embeddings. - """ - abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device) - embeddings = apply_2d_rotary_pos_emb( - patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 - self.cos_x, - self.sin_x, - self.cos_y, - self.sin_y, - abs_pos_x, - abs_pos_y, - ).to(self.patch_embedding.weight.dtype) - return embeddings - - def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: - pixel_values = pixel_values.view( - -1, - 3, - self.patch_size, - self.patch_size, - ) - patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim) - patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) - assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[ - 0 - ], "Grid size and patch embeds size mismatch." 
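The vision RoPE above needs a row-major (x, y) coordinate for every patch: for an image with grid (H, W), patch i sits at `(i % W, i // W)`. A compact loop-based equivalent of `build_abs_positions_from_grid_hw`, useful for sanity-checking the vectorized version:

```python
import torch

grid_hw = torch.tensor([[2, 3], [1, 2]])   # two images: 2x3 and 1x2 patches

xs, ys = [], []
for h, w in grid_hw.tolist():
    idx = torch.arange(h * w)
    xs.append(idx % w)
    ys.append(idx // w)
abs_x, abs_y = torch.cat(xs), torch.cat(ys)

assert abs_x.tolist() == [0, 1, 2, 0, 1, 2, 0, 1]
assert abs_y.tolist() == [0, 0, 0, 1, 1, 1, 0, 0]
```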
- - patches_list = [] - cur_position = 0 - for i in range(grid_hw.shape[0]): - h, w = grid_hw[i] - patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0) - patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2)) - patches_per_img = patches_per_img.permute(0, 2, 3, 1) - patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1])) - cur_position += h * w - - embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C) - assert cur_position == patch_embeds.shape[0] - assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2) - - return embeddings - - def encode(self, images: List[ImageItem]): - img_tensors = [] - valid_ids = [] - valid_id = 0 - img_grids = [] - uuids = [] - - for i, img in enumerate(images): - if isinstance(img, ImageItem): - uuids.append(img.uuid) - image_data = read_shm(get_shm_name_data(img.uuid)) - image_data = Image.open(BytesIO(image_data)) - # a = img.extra_params["min_pixels"] - # b = img.extra_params["max_pixels"] - # print(f"self.min_pixels is {a} ,max_pixelx is {b}") - pixel_values, image_grid_hw = load_image_native( - image_data, - patch_size=self.patch_size, - downsample_ratio=self.downsample_ratio, - min_pixels=img.extra_params["min_pixels"], - max_pixels=img.extra_params["max_pixels"], - ) - img_tensors.append(pixel_values) - img_grids.append(image_grid_hw) - else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) - - # must devide merge_length - cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2)) - valid_ids.append([valid_id, valid_id + cur_num]) - valid_id += cur_num - - if len(img_tensors) <= 0: - return None - - imgs = torch.cat(img_tensors, dim=0) - grid_hw = torch.cat(img_grids, dim=0) - - pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) - image_grid_hw = grid_hw.to("cuda", non_blocking=True) - - all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw) - - return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/neo_chat_moe/triton_kernel/__init__.py b/lightllm/models/neo_chat_moe/triton_kernel/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py deleted file mode 100644 index 74ff82cae4..0000000000 --- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py +++ /dev/null @@ -1,430 +0,0 @@ -import math -import torch -import triton -import triton.language as tl - -from lightllm.utils.device_utils import is_tesla - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - Out, - position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0] - B_Start_Loc, - B_Seqlen, - Req_to_tokens, - B_req_idx, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - kv_group_num, - b_prompt_cache_len, - b_image_token_tag, - H: tl.constexpr, - QK_HEAD_DIM: tl.constexpr, - V_HEAD_DIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - start_m = tl.program_id(0) - cur_bh = tl.program_id(1) - cur_batch = cur_bh // H - cur_head = cur_bh % H - - cur_kv_head = cur_head // kv_group_num - - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - prompt_cache_len = 
tl.load(b_prompt_cache_len + cur_batch) - total_len = tl.load(B_Seqlen + cur_batch) - cur_batch_seq_len = total_len - prompt_cache_len # NEW len - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - block_start_loc = BLOCK_M * start_m - if block_start_loc >= cur_batch_seq_len: - return - - offs_n = tl.arange(0, BLOCK_N) - offs_d_qk = tl.arange(0, QK_HEAD_DIM) - offs_d_v = tl.arange(0, V_HEAD_DIM) - offs_m = block_start_loc + tl.arange(0, BLOCK_M) - - # Q pointers - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d_qk[None, :] * stride_qd - ) - - q_valid = offs_m < cur_batch_seq_len - q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) - - # online softmax state - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32) - block_end_loc = total_len - - # absolute q positions in the request - q_pos = prompt_cache_len + offs_m # [M] - q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False) - - for start_n in range(0, block_end_loc, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - k_pos = start_n + offs_n # [N] - k_valid = k_pos < block_end_loc - - # map logical pos -> mem_index (for K/V) - kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, - mask=k_valid, - other=0, - ).to(tl.int64) - - # load K - off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d_qk[:, None] * stride_kd - k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - - # mask: causal OR same gid (only possible inside NEW part) - mask = (q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None] - qk = tl.where(mask, qk * sm_scale, -1.0e8) - - # online softmax - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - - alpha = tl.math.exp2(m_i - m_ij) - l_i = l_i * alpha + l_ij - acc = acc * alpha[:, None] - - # load V - off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d_v[None, :] * stride_vd - v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) - - p = p.to(v.dtype) - acc = tl.dot(p, v, acc) - - m_i = m_ij - - acc = acc / l_i[:, None] - - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d_v[None, :] * stride_od - ) - tl.store(Out + off_o, acc, mask=q_valid[:, None]) - - -@torch.no_grad() -def context_attention_fwd_neo( - q, - k, - v, - o, - position_ids, # 1D packed like q (only NEW tokens) - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_input_len, - req_to_token_indexs, - b_image_token_tag, -): - # minimal safety: position_ids must cover packed q rows - assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) - - BLOCK_M = 128 if not is_tesla() else 64 - - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} - sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634 - - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) - - BLOCK_N = BLOCK_M - num_warps = 4 if Lk <= 64 else 8 - num_stages = 1 - - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - o, - position_ids, - b_start_loc, 
- b_seq_len, - req_to_token_indexs, - b_req_idx, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - req_to_token_indexs.stride(0), - req_to_token_indexs.stride(1), - kv_group_num=kv_group_num, - b_prompt_cache_len=b_prompt_cache_len, - b_image_token_tag=b_image_token_tag, - H=head, - QK_HEAD_DIM=Lk, - V_HEAD_DIM=Lk, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - ) - - -def reference_attention( - q, - k, - v, - position_ids_q, # 1D packed like q (only NEW tokens) - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - req_to_token_indexs, -): - device = q.device - dtype = q.dtype - sum_q, Hq, D = q.shape - Hk = k.shape[1] - kv_group_num = Hq // Hk - - batch = b_seq_len.shape[0] - out = torch.empty_like(q) - scale = 1.0 / math.sqrt(D) - - for b in range(batch): - req = int(b_req_idx[b].item()) - total_len = int(b_seq_len[b].item()) - prompt_len = int(b_prompt_cache_len[b].item()) - new_len = total_len - prompt_len - - q_start = int(b_start_loc[b].item()) - q_blk = q[q_start : q_start + new_len] # [M, Hq, D] - gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] - - # gather K/V for full request by logical pos -> mem_index - token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] - k_blk = k[token_locs] # [L, Hk, D] - v_blk = v[token_locs] # [L, Hk, D] - - # expand kv heads to q heads (GQA) - k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] - v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] - - # positions - q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] - k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L] - - # build allow mask: - # causal always - allow = k_pos[None, :] <= q_pos[:, None] - - # full-attn only inside NEW part by gid - # compare only when k_pos in NEW - k_in_new = k_pos >= prompt_len - k_rel = (k_pos - prompt_len).clamp_min(0) # [L] - # map k_rel to gid_new, but only valid where k_in_new - k_gid = torch.empty((total_len,), device=device, dtype=torch.int64) - k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new - k_gid[k_in_new] = gid_new[k_rel[k_in_new]] - - allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) - - # scores: [Hq, M, L] - q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] - k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] - scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] - - neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) - scores = torch.where(allow[None, :, :], scores, neg) - - p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] - v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] - out_hq = torch.matmul(p, v_t) # [Hq, M, D] - out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] - - out[q_start : q_start + new_len] = out_blk - - return out - - -def make_test_case( - device="cuda", - dtype=torch.float16, - batch=3, - Hq=8, - Hk=4, - D=64, - seed=0, - base_index=50000, -): - torch.manual_seed(seed) - - # prompt (cached) len and new len - prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device) - new_lens = torch.randint(low=1, high=8, size=(batch,), device=device) - total_lens = (prompt_lens + new_lens).to(torch.int32) - - max_total_len = int(total_lens.max().item()) - max_new_len = int(new_lens.max().item()) - - # packed q start - b_start_loc = 
torch.zeros((batch,), device=device, dtype=torch.int32) - cur = 0 - for b in range(batch): - b_start_loc[b] = cur - cur += int(new_lens[b].item()) - sum_q = cur - - b_seq_len = total_lens - b_prompt_cache_len = prompt_lens.to(torch.int32) - - # one req per batch - num_req = batch - b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) - - # global KV space large, indices not small - sum_kv = int(total_lens.sum().item()) - kv_size = base_index + sum_kv + 1024 - pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index - - # Req_to_tokens [num_req, max_total_len] - req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) - p = 0 - for r in range(num_req): - L = int(total_lens[r].item()) - req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) - p += L - - # position_ids_q: only NEW tokens, packed like q - position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) - for b in range(batch): - M = int(new_lens[b].item()) - start = int(b_start_loc[b].item()) - - gid = torch.arange(M, device=device, dtype=torch.int32) - - # make one repeated block inside NEW part to simulate image tokens - if M >= 4 and torch.rand((), device=device).item() > 0.3: - s = int(torch.randint(0, M - 2, (1,), device=device).item()) - e = min(M, s + 3) - gid[s:e] = gid[s] - - position_ids_q[start : start + M] = gid - - q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) - k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) - v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) - o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) - - return ( - q, - k, - v, - o, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_new_len, - req_to_token_indexs, - ) - - -def check_once(device="cuda", dtype=torch.float16, seed=0): - ( - q, - k, - v, - o, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_new_len, - req_to_token_indexs, - ) = make_test_case(device=device, dtype=dtype, seed=seed) - - context_attention_fwd_neo( - q, - k, - v, - o, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_new_len, - req_to_token_indexs, - ) - - ref = reference_attention( - q, - k, - v, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - req_to_token_indexs, - ) - - diff = (o - ref).abs() - max_abs = diff.max().item() - denom = ref.abs().max().item() + 1e-6 - max_rel = max_abs / denom - - print(f"seed={seed}, dtype={dtype}") - print(f"max_abs_error = {max_abs:.6e}") - print(f"max_rel_error = {max_rel:.6e}") - print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2)) - - -if __name__ == "__main__": - if not torch.cuda.is_available(): - print("No CUDA, skip.") - else: - torch.cuda.synchronize() - check_once(dtype=torch.bfloat16, seed=0) - check_once(dtype=torch.bfloat16, seed=1) - check_once(dtype=torch.bfloat16, seed=2) diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py deleted file mode 100644 index 1a3d4af73b..0000000000 --- a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py +++ /dev/null @@ -1,191 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _get_neo_position_triton( - b_image_start_idx: torch.Tensor, - b_image_thwd: torch.Tensor, - b_image_thwd_stride0: torch.Tensor, - b_image_nums: torch.Tensor, - 
b_image_start_num: torch.Tensor,
-    b_image_len: torch.Tensor,
-    position_ids: torch.Tensor,
-    position_ids_stride0: torch.Tensor,
-    b_ready_cache_len: torch.Tensor,
-    b_q_seq_len: torch.Tensor,
-    b_start_loc: torch.Tensor,
-    b_image_token_tag: torch.Tensor,
-    BLOCK_SIZE: tl.constexpr,
-) -> torch.Tensor:
-    cur_batch = tl.program_id(0)
-    cache_len = tl.load(b_ready_cache_len + cur_batch)
-    q_seq_len = tl.load(b_q_seq_len + cur_batch)
-    image_num = tl.load(b_image_nums + cur_batch)
-    image_start_num = tl.load(b_image_start_num + cur_batch)
-    start_loc = tl.load(b_start_loc + cur_batch)
-    for i in range(image_num):
-        local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
-        image_start_idx = start_loc + local_image_start_idx - cache_len
-        image_len = tl.load(b_image_len + image_start_num + i)
-        # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
-        image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
-        for j in range(0, image_len, BLOCK_SIZE):
-            off = j + tl.arange(0, BLOCK_SIZE)
-            # Video is not handled yet, so t is always 0
-            t_pos = local_image_start_idx + off * 0
-            h_pos = off // image_w
-            w_pos = off % image_w
-            tl.store(
-                b_image_token_tag + off + image_start_idx,
-                True,
-                mask=(off < image_len)
-                & (off + local_image_start_idx - cache_len < q_seq_len)
-                & (local_image_start_idx - cache_len + off >= 0),
-            )
-            tl.store(
-                position_ids + off + image_start_idx,
-                t_pos,
-                mask=(off < image_len)
-                & (off + local_image_start_idx - cache_len < q_seq_len)
-                & (local_image_start_idx - cache_len + off >= 0),
-            )
-            tl.store(
-                position_ids + position_ids_stride0 + off + image_start_idx,
-                h_pos,
-                mask=(off < image_len)
-                & (off + local_image_start_idx - cache_len < q_seq_len)
-                & (local_image_start_idx - cache_len + off >= 0),
-            )
-            tl.store(
-                position_ids + position_ids_stride0 * 2 + off + image_start_idx,
-                w_pos,
-                mask=(off < image_len)
-                & (off + local_image_start_idx - cache_len < q_seq_len)
-                & (local_image_start_idx - cache_len + off >= 0),
-            )
-
-    for i in range(image_num):
-        local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
-        image_len = tl.load(b_image_len + image_start_num + i)
-        image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3)
-        image_end = local_image_start_idx + image_len - cache_len
-        text_start = tl.maximum(0, image_end)
-        for j in range(text_start, q_seq_len, BLOCK_SIZE):
-            off = j + tl.arange(0, BLOCK_SIZE)
-            t_pos = tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta
-            h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0)
-            w_pos = tl.load(
-                position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0
-            )
-            tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len))
-            tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len))
-            tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len))
-    return
-
-
-def get_neo_position_triton(
-    b_image_start_idx: torch.Tensor,
-    b_image_thwd: torch.Tensor,
-    b_image_nums: torch.Tensor,
-    b_image_start_num: torch.Tensor,
-    b_image_len: torch.Tensor,
-    position_ids: torch.Tensor,
-    b_ready_cache_len: torch.Tensor,
-    b_q_seq_len: torch.Tensor,
-    b_start_loc: torch.Tensor,
-    b_image_token_tag: torch.Tensor,
-) -> torch.Tensor:
-
-    batch_size = b_q_seq_len.shape[0]
-    assert batch_size == b_image_nums.shape[0]
-    grid = 
(batch_size,) - BLOCK_SIZE = 64 - _get_neo_position_triton[grid]( - b_image_start_idx=b_image_start_idx, - b_image_thwd=b_image_thwd, - b_image_thwd_stride0=b_image_thwd.stride(0), - b_image_nums=b_image_nums, - b_image_start_num=b_image_start_num, - b_image_len=b_image_len, - position_ids=position_ids, - position_ids_stride0=position_ids.stride(0), - b_ready_cache_len=b_ready_cache_len, - b_q_seq_len=b_q_seq_len, - b_start_loc=b_start_loc, - b_image_token_tag=b_image_token_tag, - BLOCK_SIZE=BLOCK_SIZE, - ) - - -def test(): - b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda") - b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda") - b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda") - b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda") - b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda") - position_ids = ( - torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") - .unsqueeze(0) - .expand(3, -1) - .contiguous() - ) - b_image_token_tag = torch.zeros([position_ids.size(1)], dtype=torch.bool, device="cuda") - position_ids[1:].zero_() - b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") - b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") - b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") - get_neo_position_triton( - b_image_start_idx, - b_image_thwd, - b_image_nums, - b_image_start_num, - b_image_len, - position_ids, - b_ready_cache_len, - b_q_seq_len, - b_start_loc, - b_image_token_tag, - ) - - print(b_image_token_tag) - print(position_ids) - # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) - - # position_ids = ( - # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") - # .unsqueeze(0) - # .expand(3, -1) - # .contiguous() - # ) - # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda") - # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda") - # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda") - - # get_neo_position_triton( - # b_image_start_idx, - # b_image_thwd, - # b_image_nums, - # b_image_start_num, - # b_image_len, - # position_ids, - # b_ready_cache_len, - # b_q_seq_len, - # b_start_loc, - # ) - - # print(f"old_value:\n{old_value}") - # print(f"position_ids:\n{position_ids}") - # assert torch.equal(old_value, position_ids) - - """ - tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8], - [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8], - [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]], - device='cuda:0', dtype=torch.int32) - """ - - -if __name__ == "__main__": - test() diff --git a/lightllm/models/neo_chat_moe/vision_process.py b/lightllm/models/neo_chat_moe/vision_process.py deleted file mode 100644 index fbd57a5e9c..0000000000 --- a/lightllm/models/neo_chat_moe/vision_process.py +++ /dev/null @@ -1,141 +0,0 @@ -import re -import math -import torch -import string -import numpy as np -import pandas as pd -from PIL import Image -import torch.distributed as dist -import torchvision.transforms as T - -IMAGENET_MEAN = (0.485, 0.456, 0.406) -IMAGENET_STD = (0.229, 0.224, 0.225) - - -def round_by_factor(number: int, factor: int) -> int: - """Returns the closest integer to 'number' that is divisible by 'factor'.""" - return round(number / factor) * factor 
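-    # For instance, with factor=32: round_by_factor(100, 32) == 96 and
-    # round_by_factor(113, 32) == 128; the ceil/floor variants below always bias
-    # up or down to a multiple instead of rounding to the nearest one.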
- - -def ceil_by_factor(number: int, factor: int) -> int: - """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" - return math.ceil(number / factor) * factor - - -def floor_by_factor(number: int, factor: int) -> int: - """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" - return math.floor(number / factor) * factor - - -# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60 -def smart_resize( - height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304 -) -> tuple[int, int]: - """ - Rescales the image so that the following conditions are met: - - 1. Both dimensions (height and width) are divisible by 'factor'. - - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. - - 3. The aspect ratio of the image is maintained as closely as possible. - """ - if max(height, width) / min(height, width) > 200: - raise ValueError( - f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}" - ) - h_bar = max(factor, round_by_factor(height, factor)) - w_bar = max(factor, round_by_factor(width, factor)) - if h_bar * w_bar > max_pixels: - beta = math.sqrt((height * width) / max_pixels) - h_bar = max(factor, floor_by_factor(height / beta, factor)) - w_bar = max(factor, floor_by_factor(width / beta, factor)) - elif h_bar * w_bar < min_pixels: - beta = math.sqrt(min_pixels / (height * width)) - h_bar = ceil_by_factor(height * beta, factor) - w_bar = ceil_by_factor(width * beta, factor) - return h_bar, w_bar - - -def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs): - width, height = image.size - resized_height, resized_width = smart_resize( - height, - width, - factor=size_factor, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - image = image.resize((resized_width, resized_height)) - - return image - - -def preprocess_pixel_values(pixel_values, patch_size=16): - c, h, w = pixel_values.shape - grid_h = h // patch_size - grid_w = w // patch_size - - flatten_pixel_values = ( - pixel_values.view(c, grid_h, patch_size, grid_w, patch_size) - .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size] - .reshape(grid_h * grid_w, c * patch_size ** 2) - ) - - grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device) - - return flatten_pixel_values, grid_hw - - -def get_contrasting_background(image): - """ - Calculate the color (white or black) that is different from the average foreground color - to use as the background color - """ - image_np = np.array(image) - if (image_np[:, :, 3] == 0).any(): - non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0] - if non_transparent_pixels.size == 0: - return None - pixel_mean = non_transparent_pixels.mean() - contrasting_color = (0, 0, 0) if pixel_mean > 382.5 else (255, 255, 255) - return contrasting_color - else: - return None - - -def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False): - """ - Load and preprocess an image file, converting it to RGB mode, - resizing, normalizing, and optionally adding a thumbnail version. 
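-
-    RGBA inputs are first composited onto a contrasting solid background (see
-    get_contrasting_background above) when one can be determined, before the
-    RGB conversion.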
- """ - if image.mode == "RGBA": - bg_color = get_contrasting_background(image) - if bg_color: - background = Image.new("RGB", image.size, bg_color) - background.paste(image, mask=image.split()[3]) - image = background.convert("RGB") - else: - image = image.convert("RGB") - else: - image = image.convert("RGB") - - if upscale: - image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) - - transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), - T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), - ] - ) - - new_image = dynamic_preprocess_native_resolution( - image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels - ) - pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size) - - # print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") - - return pixel_values, grid_hw diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 8ddddc5428..25726b2578 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -30,7 +30,6 @@ from ..models.qwen2_vl.model import QWen2VLTokenizer from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer -from ..models.neo_chat_moe.model import NeoChatTokenizer from ..models.gemma3.model import Gemma3Tokenizer from ..models.qwen3_omni_moe_thinker.model import QWen3OmniTokenizer @@ -131,7 +130,5 @@ def get_tokenizer( tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) - elif model_type == "neo_chat": - tokenizer = NeoChatTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) return tokenizer diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index c56a801706..92ca2e3836 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -19,7 +19,6 @@ from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel -from lightllm.models.neo_chat_moe.neo_visual import NeoVisionTransformerPretrainedModel from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env @@ -98,8 +97,6 @@ def exposed_init_model(self, kvargs): # self.model = InternVLVisionModel() elif self.model_type == "gemma3": self.model = Gemma3VisionModel() - elif self.model_type == "neo_chat": - self.model = NeoVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() elif ( model_cfg.get("thinker_config", {}).get("vision_config", {}).get("model_type") == "qwen3_omni_moe_vision_encoder" From 9e54f209f9932e4714c3574ae8924b9ba267a614 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 8 May 2026 13:36:53 +0000 Subject: [PATCH 168/180] remove unused code --- lightllm/common/basemodel/basemodel.py | 7 ------- .../common/basemodel/layer_weights/hf_load_utils.py | 3 +-- .../layer_weights/meta_weights/embedding_weight.py | 3 --- 
.../meta_weights/fused_moe/ep_redundancy.py | 4 +--- .../meta_weights/fused_moe/fused_moe_weight.py | 12 ++---------- .../meta_weights/mm_weight/mm_weight.py | 2 -- .../layer_weights/meta_weights/norm_weight.py | 6 ------ .../layer_weights/meta_weights/parameter_weight.py | 4 ++++ 8 files changed, 8 insertions(+), 33 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index de4ba34c5d..241ff07780 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -59,9 +59,6 @@ class TpPartBaseModel: # infer state class infer_state_class = InferStateInfo - def get_radix_class(self): - return RadixCache - def __init__(self, kvargs): self.args = get_env_start_args() self.run_mode = kvargs["run_mode"] @@ -1223,9 +1220,6 @@ def resume_kv_cache(self): torch.cuda.empty_cache() gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) - self.mem_manager.free_all() - self.req_manager.resume() - torch.cuda.synchronize() def resume_graph(self): torch.cuda.empty_cache() @@ -1238,4 +1232,3 @@ def resume_all(self): self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) - self.mem_manager.free_all() diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index ec0e282844..8cf66a5ad6 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -5,8 +5,6 @@ from tqdm import tqdm import lightllm.utils.petrel_helper as utils from lightllm.utils.dist_utils import get_current_device_id -from queue import Queue -from threading import Thread def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None): @@ -67,6 +65,7 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" iterator = tqdm(iterator, total=len(candidate_files), desc=desc_str) + for _ in iterator: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index 9fef2b9084..d94a4c709b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -36,7 +36,6 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) self.weight.load_ok = True - del weights[self.weight_name] def verify_load(self): return self.weight.load_ok @@ -116,7 +115,6 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) self.weight.load_ok = True - del weights[self.weight_name] def verify_load(self): return self.weight.load_ok @@ -175,7 +173,6 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): ), f"max_position_embeddings: {loaded_max_position_embeddings} != expected: {self.max_position_embeddings}" 
self.weight.copy_(t_weight.to(self.data_type_)) self.weight.load_ok = True - del weights[self.weight_name] def verify_load(self): return self.weight.load_ok diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py index 98aa4b71bb..749400c8d8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py @@ -69,13 +69,11 @@ def load_hf_weights(self, weights): w3_weight = f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w3_weight_name}.weight" if w1_weight in weights: self.experts_gate_projs[i] = weights[w1_weight] - del weights[w1_weight] if w3_weight in weights: self.experts_up_projs[i] = weights[w3_weight] - del weights[w3_weight] if w2_weight in weights: self.w2_list[i] = weights[w2_weight] - del weights[w2_weight] + self._load_weight_scale(weights) self._fuse() diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 7bb31a6454..6ca48299f0 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -268,7 +268,6 @@ def load_hf_weights(self, weights): # Load bias if self.e_score_correction_bias_name in weights: self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) - del weights[self.e_score_correction_bias_name] self._load_weight(self.expert_idx_to_local_idx, weights) if self.redundancy_expert_num > 0: self._load_weight(self.redundancy_expert_idx_to_local_idx, weights) @@ -314,13 +313,12 @@ def _get_expert_weight_list(self, weight_pack: WeightPack): return weight_list def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]): - + # for merged weights + self._load_merge_weight(weights) # Load each expert with TP slicing for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): with self.lock: self._load_expert(expert_idx, local_expert_idx, weights) - # for rl updated weight - self._load_merge_weight(weights) self._load_expert_scale( expert_idx, local_expert_idx, @@ -346,13 +344,10 @@ def _load_expert( col_slice_func = self.col_slicer._slice_weight if w1_weight in weights: self.quant_method.load_weight(row_slice_func(weights[w1_weight]), self.w1_list[local_expert_idx]) - del weights[w1_weight] if w3_weight in weights: self.quant_method.load_weight(row_slice_func(weights[w3_weight]), self.w3_list[local_expert_idx]) - del weights[w3_weight] if w2_weight in weights: self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx]) - del weights[w2_weight] def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}" @@ -362,13 +357,10 @@ def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): col_slice_func = self.col_slicer._slice_weight if w1_merge_weight in weights: self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1) - del weights[w1_merge_weight] if w2_merge_weight in weights: self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2) - del weights[w2_merge_weight] if w3_merge_weight in weights: self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), 
self.w3) - del weights[w3_merge_weight] def _load_expert_scale( self, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index fd4b395811..5021699143 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -126,7 +126,6 @@ def _load_weight( slicer = self._get_param_slicer(sub_child_index) weight = slicer._slice_weight(weights[param_name]) self.quant_method.load_weight(weight, self.mm_param_list[sub_child_index]) - del weights[param_name] return def _load_bias( @@ -137,7 +136,6 @@ def _load_bias( bias = slicer._slice_bias(weights[param_name]) self.bias_list[sub_child_index].copy_(bias) self.bias_list[sub_child_index].load_ok = True - del weights[param_name] return def _load_weight_scale( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 2b01918b0e..ee9d1923c3 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -25,7 +25,6 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights: self.weight.copy_(weights[self.weight_name]) self.weight.load_ok = True - del weights[self.weight_name] def verify_load(self): return self.weight.load_ok @@ -141,11 +140,9 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights: self.weight.copy_(weights[self.weight_name]) self.weight.load_ok = True - del weights[self.weight_name] if self.bias_name in weights: self.bias.copy_(weights[self.bias_name]) self.bias.load_ok = True - del weights[self.bias_name] def verify_load(self): return self.weight.load_ok and self.bias.load_ok @@ -236,7 +233,6 @@ def load_hf_weights(self, weights): # the padding part is zero self.weight[end - start :].zero_() self.weight.load_ok = True - del weights[self.weight_name] class NoTpGEMMANormWeight(RMSNormWeight): @@ -269,11 +265,9 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.q_weight_name in weights: self.q_weight.copy_(weights[self.q_weight_name]) self.q_weight.load_ok = True - del weights[self.q_weight_name] if self.k_weight_name in weights: self.k_weight.copy_(weights[self.k_weight_name]) self.k_weight.load_ok = True - del weights[self.k_weight_name] def verify_load(self): return self.q_weight.load_ok and self.k_weight.load_ok diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py index 3dc732e821..80d733394b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py @@ -41,6 +41,10 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: t_bias = weights[self.bias_name] self.bias.copy_(t_bias.to(self.data_type_)) self.bias.load_ok = True + + def verify_load(self) -> bool: + if self.weight is not None and not getattr(self.weight, "load_ok", False): + return False if self.bias is not None and not getattr(self.bias, "load_ok", False): return False return True From 46d2ee2a4801ddd35551d5f49db03ddbb09ff0ae Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 8 May 2026 13:39:48 +0000 Subject: [PATCH 169/180] 
remove unused code

---
 .../common/kv_cache_mem_manager/mem_utils.py  |   4 -
 .../kv_cache_mem_manager/neo_mem_manager.py   |  46 ---
 .../mamba_cache_mem_manager/cache_manager.py  | 281 ------------------
 3 files changed, 331 deletions(-)
 delete mode 100755 lightllm/common/kv_cache_mem_manager/neo_mem_manager.py
 delete mode 100644 lightllm/common/mamba_cache_mem_manager/cache_manager.py

diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py
index 0d96b59ed4..79ea448794 100644
--- a/lightllm/common/kv_cache_mem_manager/mem_utils.py
+++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py
@@ -39,10 +39,6 @@ def select_mem_manager_class():
             mem_class = Deepseek2MemoryManager
         logger.info(f"Model kv cache using default, mem_manager class: {mem_class}")
         return mem_class
-    # # Check whether this is a neo-family model
-    # elif issubclass(model_class, NeoTpMOEPartModel) or issubclass(model_class, NeoTpPartModel):
-    #     mem_class = NeoMemoryManager
-    #     return mem_class
 
     # case normal
     logger.info(f"mode setting params: {get_env_start_args().llm_kv_type}")
diff --git a/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py b/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py
deleted file mode 100755
index 1101386f6c..0000000000
--- a/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# import torch
-# from lightllm.utils.dist_utils import get_current_rank_in_node
-# from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
-# from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager


-# class NeoMemoryManager(MemoryManager):
-#     def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
-#         self.size = size
-#         self.head_num = head_num
-#         self.head_dim = head_dim * 2  # neo kv concatenates [k, k_h, k_w]
-#         self.layer_num = layer_num
-#         self.always_copy = always_copy
-#         self.dtype = dtype
-#         # profile the max total token num if the size is None
-#         self.profile_size(mem_fraction)

-#         self.mem_state = torch.arange(
-#             0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-#         )
-#         self._mem_state_return = torch.arange(
-#             0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-#         )
-#         self._return_start = 0
-#         self.mark_start = 0
-#         self.mark_end = self.size

-#         self.can_use_mem_size = self.size

-#         # Shared via shared memory so the router module can read it for precise scheduling estimates; the nccl port marks a single instance on one machine and prevents conflicts.
-#         from lightllm.utils.envs_utils import get_unique_server_name

-#         rank_in_node = get_current_rank_in_node()
-#         self.shared_can_use_token_num = SharedInt(
-#             f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
-#         )

-#         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-#         self._init_buffers(
-#             self.size,
-#             dtype,
-#             head_num,
-#             self.head_dim,
-#             layer_num,
-#         )
-#         self.HOLD_TOKEN_MEMINDEX = self.size
diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py
deleted file mode 100644
index 9a3a65869a..0000000000
--- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py
+++ /dev/null
@@ -1,281 +0,0 @@
-from typing import List, Tuple, Union
-
-import torch
-import numpy as np
-
-from lightllm.utils.dist_utils import get_current_rank_in_node
-from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
-from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import copy_mamba_buffer, fork_mamba_buffer
-from 
lightllm.utils.log_utils import init_logger -from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt - -logger = init_logger(__name__) - - -class LayerCache: - def __init__(self, size: int, dtype: torch.dtype, shape: Tuple[int, ...], layer_num: int): - self.size = size - self.dtype = dtype - self.shape = shape - self.layer_num = layer_num - self.buffer = torch.zeros((self.layer_num, size + 1, *shape), dtype=dtype, device="cuda") - - def get_cell_size(self): - return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype) - - -class MambaCacheManager: - def __init__( - self, - size: int, - layer_num: int, - conv_state_dtype: torch.dtype, - ssm_state_dtype: torch.dtype, - conv_kernel_size: int, - num_linear_k_heads: int, - num_linear_v_heads: int, - head_linear_k_dim: int, - head_linear_v_dim: int, - ): - # init the mem state - self.size = size - self.num_linear_k_heads = num_linear_k_heads - self.num_linear_v_heads = num_linear_v_heads - self.head_linear_k_dim = head_linear_k_dim - self.head_linear_v_dim = head_linear_v_dim - self.conv_dim = ( - self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads - ) - self.layer_num = layer_num - self.conv_kernel_size = conv_kernel_size - conv_state_shape = (self.conv_dim, conv_kernel_size - 1) - ssm_state_shape = ( - self.num_linear_v_heads, - self.head_linear_k_dim, - self.head_linear_v_dim, - ) - self.ssm_state_dtype = ssm_state_dtype - self.conv_state_dtype = conv_state_dtype - self.profile_size() - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._mem_state_return = torch.arange( - 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._return_start = 0 - self.mark_start = 0 - self.mark_end = self.size - self.can_use_mem_size = self.size - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mamba_cache_can_use_num_{get_current_rank_in_node()}" - ) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - # init the layer cache - self.conv_state_cache = LayerCache(self.size, conv_state_dtype, conv_state_shape, layer_num) - self.ssm_state_cache = LayerCache(self.size, ssm_state_dtype, ssm_state_shape, layer_num) - self.HOLD_BUFFER_INDEX = self.size - - def get_mamba_cache(self, layer_idx: int): - conv_state = self.conv_state_cache.buffer[layer_idx] - ssm_state = self.ssm_state_cache.buffer[layer_idx] - return conv_state, ssm_state - - def copy_state_buffers(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor): - copy_mamba_buffer( - self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes - ) - copy_mamba_buffer( - self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes - ) - - def fork_state_buffers(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): - fork_mamba_buffer( - self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_index, dst_buffer_indexes - ) - fork_mamba_buffer( - self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes - ) - - def fork_ssm_buffers(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): - """ - Fork ONLY SSM states (not conv states) from source indices to destination indices. 
-
-        This is used for MTP mode where each buffer maintains its own independent conv state,
-        but SSM states need to be synchronized.
-        """
-        fork_mamba_buffer(
-            self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes
-        )
-
-    def alloc(self, need_size) -> torch.Tensor:
-        if need_size > self.mark_end - self.mark_start:
-            logger.error(f"warn: not enough cache, need_size {need_size} left_size {self.can_use_mem_size}")
-            assert False, "error alloc state"
-
-        start = self.mark_start
-        end = self.mark_start + need_size
-        self.mark_start += need_size
-
-        self.can_use_mem_size -= need_size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        # Return through a staging buffer to avoid memory races under asynchronous execution
-        if self._return_start + need_size > self._mem_state_return.shape[0]:
-            self._return_start = 0
-        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
-        ans.copy_(self.mem_state[start:end])
-        self._return_start += need_size
-        return ans
-
-    def free(self, free_index: Union[torch.Tensor, List[int]]):
-        """
-        Free the allocated cache buffers and clear them.
-
-        Args:
-            free_index: Buffer indices to free (tensor or list of ints)
-        """
-        # Convert to tensor if needed for indexing
-        if isinstance(free_index, list):
-            free_index_tensor = torch.tensor(free_index, dtype=torch.long, device="cuda")
-        else:
-            free_index_tensor = free_index.to(device="cuda", dtype=torch.long)
-
-        # Clear the buffers for the freed indices
-        # Shape: [layer_num, buffer_index, *shape]
-        self.conv_state_cache.buffer[:, free_index_tensor, ...] = 0
-        self.ssm_state_cache.buffer[:, free_index_tensor, ...] = 0
-
-        # update the mem state
-        end = self.mark_start
-        start = self.mark_start - len(free_index)
-        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
-
-        if isinstance(free_index, list):
-            free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device)
-            self.mem_state[start:end] = free_index_tensor
-        else:
-            # Copies from GPU to CPU are blocking operations within the stream
-            self.mem_state[start:end] = free_index
-
-        self.mark_start -= len(free_index)
-
-        self.can_use_mem_size += len(free_index)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        if self.can_use_mem_size == len(self.mem_state):
-            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
-
-        return
-
-    def free_all(self):
-        self.conv_state_cache.buffer.fill_(0)
-        self.ssm_state_cache.buffer.fill_(0)
-        self.can_use_mem_size = len(self.mem_state)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
-        self.mark_start = 0
-        self.mark_end = len(self.mem_state)
-
-        return
-
-    def resize_mem(self, new_size):
-        """
-        just for test code
-        """
-        self.size = new_size
-        self.mem_state = torch.arange(
-            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self.mark_start = 0
-        self.mark_end = self.size
-        self.can_use_mem_size = self.size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-        return
-
-    def profile_size(
-        self,
-    ):
-        start_args = get_env_start_args()
-        if self.size is not None:
-            assert self.size >= start_args.running_max_req_size * 2, (
-                f"error mamba_cache_size {self.size} < running_max_req_size * 2 {start_args.running_max_req_size * 2}",
-                f"mamba_cache_size should be at least running_max_req_size * 2",
-            )
-            return
-        from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
-        import torch.distributed as dist
-
-        mem_fraction = start_args.mem_fraction
-        world_size = dist.get_world_size()
-        total_memory = get_total_gpu_memory()
-        available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction)
-        conv_cell_size = (
-            self.layer_num
-            * self.conv_dim
-            * (self.conv_kernel_size - 1)
-            * torch._utils._element_size(self.conv_state_dtype)
-        )
-        ssm_cell_size = (
-            self.layer_num
-            * (self.num_linear_v_heads)
-            * self.head_linear_k_dim
-            * self.head_linear_v_dim
-            * torch._utils._element_size(self.ssm_state_dtype)
-        )
-        total_cell_size = conv_cell_size + ssm_cell_size
-        mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5
-        mamba_memory_gb = available_memory * mamba_cache_ratio
-        mamba_cache_size = int(mamba_memory_gb * 1024 ** 3 / total_cell_size)
-
-        if mamba_cache_size < start_args.running_max_req_size * 2:
-            ratio = mamba_cache_ratio if mamba_cache_ratio is not None else 0.5
-            raise ValueError(
-                f"Insufficient memory for mamba cache allocation!\n\n"
-                f"mamba_cache_size should be at least running_max_req_size * 2\n"
-                f"Calculated mamba_cache_size ({mamba_cache_size}) < "
-                f"running_max_req_size * 2 ({start_args.running_max_req_size * 2})\n\n"
-                f"Memory budget:\n"
-                f"  Available for mamba cache: {mamba_memory_gb:.2f} GB\n"
-                f"  Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n"
-                f"  Calculated buffers: {mamba_cache_size}\n"
-                f"  Required buffers: {start_args.running_max_req_size}\n\n"
-                f"Solutions:\n"
-                f"  1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n"
-                f"  2. Increase --mamba_cache_ratio from {ratio} to "
-                f"{start_args.running_max_req_size / mamba_cache_size * ratio:.3f} or higher\n"
-                f"  3. Increase --mem_fraction to leave more memory for caches\n"
-            )
-
-        logger.info(
-            f"Mamba cache allocation:\n"
-            f"  Available memory: {mamba_memory_gb:.2f} GB\n"
-            f"  Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n"
-            f"  Calculated mamba_cache_size: {mamba_cache_size}"
-        )
-        self.size = mamba_cache_size
-        return
-
-
-class ReadOnlyStaticsMambaCacheManager:
-    """
-    Read some statistics.
-    """
-
-    def __init__(self) -> None:
-        args = get_env_start_args()
-        self.global_world_size = args.tp
-        self.node_world_size = args.tp // args.nnodes
-        self.dp_world_size = self.global_world_size // args.dp
-        # Compatibility for the multi-node pure-TP case where dp size == 1
-        self.is_multinode_tp = args.dp == 1 and args.nnodes > 1
-        self.shared_tp_can_use_token_nums = [
-            SharedInt(f"{get_unique_server_name()}_mamba_cache_can_use_num_{rank_in_node}")
-            for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
-        ]
-
-    def get_unrefed_token_num(self, dp_rank_in_node: int):
-        if self.is_multinode_tp:
-            return self.shared_tp_can_use_token_nums[0].get_value()
-        return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value()

From f2c1a3ed8828e3755276f9b2941412873f865dbb Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Fri, 8 May 2026 13:41:17 +0000
Subject: [PATCH 170/180] remove unused code

---
 lightllm/common/req_manager.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py
index 134b9f4bbe..11d6fbc562 100644
--- a/lightllm/common/req_manager.py
+++ b/lightllm/common/req_manager.py
@@ -101,14 +101,6 @@ def free_all(self):
         self.req_list = _ReqLinkedList(self.max_request_num)
         return
 
-    def resume(self):
-        return
-
-    @property
-    def has_recurrent_state(self):
-        """Whether this model uses per-request recurrent state buffers (e.g. 
Mamba/linear attention).""" - return self.req_to_buffer_index is not None - class ReqSamplingParamsManager: """ From e7c1475b9ed3fd49b6509b5b0e2f0b42cbdae6fd Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 8 May 2026 13:48:44 +0000 Subject: [PATCH 171/180] slime code --- lightllm/common/req_manager.py | 1 - lightllm/models/qwen3next/mem_manager.py | 76 - .../triton_kernel/fused_split_copy.py | 400 ----- .../qwen3next/triton_kernel/gdn_decode_mtp.py | 1333 ----------------- .../qwen3next/triton_kernel/gemma_rmsnorm.py | 141 -- .../dynamic_prompt/hybrid_radix_cache.py | 158 -- 6 files changed, 2109 deletions(-) delete mode 100644 lightllm/models/qwen3next/mem_manager.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fused_split_copy.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py delete mode 100644 lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 11d6fbc562..01e9c4ad35 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -74,7 +74,6 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num - self.req_to_buffer_index = None def alloc(self): return self.req_list.alloc() diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py deleted file mode 100644 index 12a6d56b8c..0000000000 --- a/lightllm/models/qwen3next/mem_manager.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -from typing import Tuple -from lightllm.utils.log_utils import init_logger -from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager -from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager -from lightllm.server.core.objs.start_args_type import StartArgs - -logger = init_logger(__name__) - - -class Qwen3NextHybridMemManager(MemoryManager): - def __init__( - self, - full_attn_cache_size, - linear_attn_cache_size, - dtype, - num_kv_heads, - head_dim, - layer_num, - full_attention_interval: int, - conv_state_dtype: torch.dtype, - ssm_state_dtype: torch.dtype, - conv_kernel_size: int, - num_linear_k_heads: int, - num_linear_v_heads: int, - head_linear_k_dim: int, - head_linear_v_dim: int, - max_req_num: int, - always_copy=False, - mem_fraction=0.9, - network_config: dict = None, - ): - - self.full_attention_interval = full_attention_interval - assert layer_num % full_attention_interval == 0 - self.layer_num = layer_num - self.full_attn_layer_num = layer_num // full_attention_interval - self.linear_attn_layer_num = layer_num - self.full_attn_layer_num - - self.mamba_cache_mem_manager = MambaCacheManager( - size=linear_attn_cache_size, - layer_num=self.linear_attn_layer_num, - conv_state_dtype=conv_state_dtype, - ssm_state_dtype=ssm_state_dtype, - conv_kernel_size=conv_kernel_size, - num_linear_k_heads=num_linear_k_heads, - num_linear_v_heads=num_linear_v_heads, - head_linear_k_dim=head_linear_k_dim, - head_linear_v_dim=head_linear_v_dim, - ) - - super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction) - - def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - # KV buffer layout: [None, None, None, kv_cache, None, None, None, kv_cache, ..., - 
# None, kv_cache, mtp_kv_cache, mtp_kv_cache] - # Only full attention layers have KV cache. - self.kv_buffer = [None for _ in range(self.layer_num)] - for layer_id in range(self.full_attn_layer_num): - self.kv_buffer[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty( - (size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda" - ) - - def free_all(self): - super().free_all() - self.mamba_cache_mem_manager.free_all() - return - - def get_cell_size(self): - # Only full attention layers and MTP layers have KV cache - kv_cache_layer_num = self.full_attn_layer_num - return 2 * self.head_num * self.head_dim * kv_cache_layer_num * torch._utils._element_size(self.dtype) - - def get_mamba_cache(self, layer_idx: int): - layer_idx_in_linear = layer_idx - (layer_idx // self.full_attention_interval) - return self.mamba_cache_mem_manager.get_mamba_cache(layer_idx_in_linear) diff --git a/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py b/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py deleted file mode 100644 index 5f4433fb34..0000000000 --- a/lightllm/models/qwen3next/triton_kernel/fused_split_copy.py +++ /dev/null @@ -1,400 +0,0 @@ -""" -Fused Split-Copy Triton Kernels for GDN Decode Path - -Replaces multiple separate .copy_() calls with single kernel launches to reduce -kernel launch overhead in the decode hot path (36 GDN layers per step). - -Kernel 1 (fused_split_copy_qkvzba): 4 copies → 1 kernel - Splits GEMM output [batch, total_dim] into qkv, z, b, a destination buffers. - -Kernel 2 (fused_split_copy_qkv): 3 copies → 1 kernel - Splits conv1d output [batch, qkv_dim] into q, k, v destination buffers. - Handles non-contiguous source (stride(0) != total_dim from column slicing). -""" - -import torch -import triton -import triton.language as tl - - -# ============================================================================= -# Kernel 1: Fused split-copy for qkv, z, b, a from GEMM output -# ============================================================================= - - -@triton.jit -def _fused_split_copy_qkvzba_kernel( - # Source pointer (contiguous GEMM output) - src_ptr, - # Destination pointers (pre-allocated contiguous buffers) - dst_qkv_ptr, - dst_z_ptr, - dst_b_ptr, - dst_a_ptr, - # Row strides - src_stride0, - dst_qkv_stride0, - dst_z_stride0, - dst_b_stride0, - dst_a_stride0, - # Segment boundaries (cumulative): [0, qkv_dim) [qkv_dim, z_end) [z_end, b_end) [b_end, total_dim) - qkv_dim, - z_end, - b_end, - total_dim, - # Block size - BLOCK_N: tl.constexpr, -): - """ - One program per (row, column_block). Loads a BLOCK_N chunk from the source row, - then conditionally stores to the correct destination based on column position. 
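-
-    With the segment widths used in the test below (qkv_dim=512, z_dim=256,
-    b_dim=2, a_dim=2), columns [0, 512) land in dst_qkv, [512, 768) in dst_z,
-    [768, 770) in dst_b, and [770, 772) in dst_a.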
- - Grid: (batch, cdiv(total_dim, BLOCK_N)) - """ - row = tl.program_id(0) - col_block = tl.program_id(1) - - col_start = col_block * BLOCK_N - cols = col_start + tl.arange(0, BLOCK_N) - mask = cols < total_dim - - # Load source chunk - data = tl.load(src_ptr + row * src_stride0 + cols, mask=mask) - - # Store to qkv destination: columns [0, qkv_dim) - qkv_mask = mask & (cols < qkv_dim) - tl.store(dst_qkv_ptr + row * dst_qkv_stride0 + cols, data, mask=qkv_mask) - - # Store to z destination: columns [qkv_dim, z_end) - z_mask = mask & (cols >= qkv_dim) & (cols < z_end) - tl.store(dst_z_ptr + row * dst_z_stride0 + (cols - qkv_dim), data, mask=z_mask) - - # Store to b destination: columns [z_end, b_end) - b_mask = mask & (cols >= z_end) & (cols < b_end) - tl.store(dst_b_ptr + row * dst_b_stride0 + (cols - z_end), data, mask=b_mask) - - # Store to a destination: columns [b_end, total_dim) - a_mask = mask & (cols >= b_end) - tl.store(dst_a_ptr + row * dst_a_stride0 + (cols - b_end), data, mask=a_mask) - - -def fused_split_copy_qkvzba( - src: torch.Tensor, - dst_qkv: torch.Tensor, - dst_z: torch.Tensor, - dst_b: torch.Tensor, - dst_a: torch.Tensor, - qkv_dim: int, - z_dim: int, - b_dim: int, - a_dim: int, -): - """ - Fused split-copy from GEMM output into 4 contiguous destination buffers. - - Replaces: - conv_buf.copy_(mixed_qkvzba[:, :qkv_dim]) - z_buf.view(batch, -1).copy_(mixed_qkvzba[:, qkv_dim:z_end]) - b_buf.copy_(mixed_qkvzba[:, z_end:b_end]) - a_buf.copy_(mixed_qkvzba[:, b_end:]) - - Args: - src: [batch, total_dim] contiguous source (GEMM output) - dst_qkv: [batch, qkv_dim] contiguous destination for conv1d input - dst_z: [batch, z_dim] contiguous destination (z_buf viewed flat) - dst_b: [batch, b_dim] contiguous destination - dst_a: [batch, a_dim] contiguous destination - qkv_dim: width of qkv segment (tp_key_dim * 2 + tp_value_dim) - z_dim: width of z segment (tp_value_dim) - b_dim: width of b segment (tp_num_v_heads) - a_dim: width of a segment (tp_num_v_heads) - """ - total_dim = qkv_dim + z_dim + b_dim + a_dim - z_end = qkv_dim + z_dim - b_end = z_end + b_dim - - batch = src.shape[0] - BLOCK_N = 128 - num_col_blocks = triton.cdiv(total_dim, BLOCK_N) - - grid = (batch, num_col_blocks) - - _fused_split_copy_qkvzba_kernel[grid]( - src, - dst_qkv, - dst_z, - dst_b, - dst_a, - src.stride(0), - dst_qkv.stride(0), - dst_z.stride(0), - dst_b.stride(0), - dst_a.stride(0), - qkv_dim, - z_end, - b_end, - total_dim, - BLOCK_N=BLOCK_N, - num_warps=4, - ) - - -# ============================================================================= -# Kernel 2: Fused split-copy for q, k, v from conv1d output -# ============================================================================= - - -@triton.jit -def _fused_split_copy_qkv_kernel( - # Source pointer (may be non-contiguous column slice) - src_ptr, - # Destination pointers (contiguous buffers) - dst_q_ptr, - dst_k_ptr, - dst_v_ptr, - # Row strides - src_stride0, - dst_q_stride0, - dst_k_stride0, - dst_v_stride0, - # Segment boundaries: [0, q_dim) [q_dim, qk_end) [qk_end, total_dim) - q_dim, - qk_end, - total_dim, - # Block size - BLOCK_N: tl.constexpr, -): - """ - One program per (row, column_block). Loads a BLOCK_N chunk from the source row, - then conditionally stores to q, k, or v destination. - - Supports non-contiguous source via src_stride0 (stride may be > total_dim - when source is a column slice of a larger tensor). 
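-
-    For example, when the source is a column slice such as parent[:, :total_dim],
-    src_stride0 is the parent's row width rather than total_dim.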
- - Grid: (batch, cdiv(total_dim, BLOCK_N)) - """ - row = tl.program_id(0) - col_block = tl.program_id(1) - - col_start = col_block * BLOCK_N - cols = col_start + tl.arange(0, BLOCK_N) - mask = cols < total_dim - - # Load source chunk (use src_stride0 for row advancement) - data = tl.load(src_ptr + row * src_stride0 + cols, mask=mask) - - # Store to q destination: columns [0, q_dim) - q_mask = mask & (cols < q_dim) - tl.store(dst_q_ptr + row * dst_q_stride0 + cols, data, mask=q_mask) - - # Store to k destination: columns [q_dim, qk_end) - k_mask = mask & (cols >= q_dim) & (cols < qk_end) - tl.store(dst_k_ptr + row * dst_k_stride0 + (cols - q_dim), data, mask=k_mask) - - # Store to v destination: columns [qk_end, total_dim) - v_mask = mask & (cols >= qk_end) - tl.store(dst_v_ptr + row * dst_v_stride0 + (cols - qk_end), data, mask=v_mask) - - -def fused_split_copy_qkv( - src: torch.Tensor, - dst_q: torch.Tensor, - dst_k: torch.Tensor, - dst_v: torch.Tensor, - q_dim: int, - k_dim: int, - v_dim: int, - src_stride0: int, -): - """ - Fused split-copy from conv1d output into 3 contiguous q/k/v buffers. - - Replaces: - q_split, k_split, v_split = torch.split(mixed_qkv, [...], dim=-1) - q_buf.view(batch, -1).copy_(q_split) - k_buf.view(batch, -1).copy_(k_split) - v_buf.view(batch, -1).copy_(v_split) - - Args: - src: [batch, total_dim] source tensor (may be non-contiguous if column slice) - dst_q: [batch, q_dim] contiguous destination - dst_k: [batch, k_dim] contiguous destination - dst_v: [batch, v_dim] contiguous destination - q_dim: width of q segment (tp_key_dim) - k_dim: width of k segment (tp_key_dim) - v_dim: width of v segment (tp_value_dim) - src_stride0: row stride of source (may be > q_dim+k_dim+v_dim) - """ - total_dim = q_dim + k_dim + v_dim - qk_end = q_dim + k_dim - - batch = src.shape[0] - BLOCK_N = 128 - num_col_blocks = triton.cdiv(total_dim, BLOCK_N) - - grid = (batch, num_col_blocks) - - _fused_split_copy_qkv_kernel[grid]( - src, - dst_q, - dst_k, - dst_v, - src_stride0, - dst_q.stride(0), - dst_k.stride(0), - dst_v.stride(0), - q_dim, - qk_end, - total_dim, - BLOCK_N=BLOCK_N, - num_warps=4, - ) - - -# ============================================================================= -# Test / Verification -# ============================================================================= - - -def test_fused_split_copy(): - """Verify fused kernels produce identical results to separate .copy_() calls.""" - torch.manual_seed(42) - device = "cuda" - dtype = torch.bfloat16 - - print("=" * 60) - print("Testing fused_split_copy_qkvzba") - print("=" * 60) - - # Typical dimensions for Qwen3-Coder-Next with TP=4 - # tp_key_dim=128, tp_value_dim=256, tp_num_v_heads=2 - qkv_dim = 128 + 128 + 256 # q + k + v = 512 - z_dim = 256 - b_dim = 2 - a_dim = 2 - total_dim = qkv_dim + z_dim + b_dim + a_dim # 772 - - for batch in [1, 4, 8, 32]: - src = torch.randn(batch, total_dim, dtype=dtype, device=device) - - # Reference: separate copies - ref_qkv = src[:, :qkv_dim].clone() - ref_z = src[:, qkv_dim : qkv_dim + z_dim].clone() - ref_b = src[:, qkv_dim + z_dim : qkv_dim + z_dim + b_dim].clone() - ref_a = src[:, qkv_dim + z_dim + b_dim :].clone() - - # Fused kernel - dst_qkv = torch.empty(batch, qkv_dim, dtype=dtype, device=device) - dst_z = torch.empty(batch, z_dim, dtype=dtype, device=device) - dst_b = torch.empty(batch, b_dim, dtype=dtype, device=device) - dst_a = torch.empty(batch, a_dim, dtype=dtype, device=device) - fused_split_copy_qkvzba(src, dst_qkv, dst_z, dst_b, dst_a, qkv_dim, z_dim, b_dim, a_dim) 
- - assert torch.equal(dst_qkv, ref_qkv), f"qkv mismatch at batch={batch}" - assert torch.equal(dst_z, ref_z), f"z mismatch at batch={batch}" - assert torch.equal(dst_b, ref_b), f"b mismatch at batch={batch}" - assert torch.equal(dst_a, ref_a), f"a mismatch at batch={batch}" - print(f" batch={batch:3d}: PASS") - - print() - print("=" * 60) - print("Testing fused_split_copy_qkv") - print("=" * 60) - - q_dim = 128 - k_dim = 128 - v_dim = 256 - qkv_dim = q_dim + k_dim + v_dim # 512 - - for batch in [1, 4, 8, 32]: - # Test with contiguous source - src = torch.randn(batch, qkv_dim, dtype=dtype, device=device) - - ref_q = src[:, :q_dim].clone() - ref_k = src[:, q_dim : q_dim + k_dim].clone() - ref_v = src[:, q_dim + k_dim :].clone() - - dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) - dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) - dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) - fused_split_copy_qkv(src, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src.stride(0)) - - assert torch.equal(dst_q, ref_q), f"q mismatch at batch={batch} (contiguous)" - assert torch.equal(dst_k, ref_k), f"k mismatch at batch={batch} (contiguous)" - assert torch.equal(dst_v, ref_v), f"v mismatch at batch={batch} (contiguous)" - print(f" batch={batch:3d} (contiguous src): PASS") - - # Test with non-contiguous source (column slice of wider tensor) - wider = torch.randn(batch, qkv_dim + 64, dtype=dtype, device=device) - src_nc = wider[:, :qkv_dim] # Non-contiguous: stride(0) = qkv_dim + 64 - assert src_nc.stride(0) == qkv_dim + 64, "expected non-contiguous slice" - - ref_q = src_nc[:, :q_dim].clone() - ref_k = src_nc[:, q_dim : q_dim + k_dim].clone() - ref_v = src_nc[:, q_dim + k_dim :].clone() - - dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) - dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) - dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) - fused_split_copy_qkv(src_nc, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src_nc.stride(0)) - - assert torch.equal(dst_q, ref_q), f"q mismatch at batch={batch} (non-contiguous)" - assert torch.equal(dst_k, ref_k), f"k mismatch at batch={batch} (non-contiguous)" - assert torch.equal(dst_v, ref_v), f"v mismatch at batch={batch} (non-contiguous)" - print(f" batch={batch:3d} (non-contiguous src): PASS") - - print() - print("=" * 60) - print("Testing edge cases") - print("=" * 60) - - # Edge case: different dimension ratios (small q/k, large v) - q_dim, k_dim, v_dim = 32, 32, 512 - qkv_dim = q_dim + k_dim + v_dim - batch = 2 - src = torch.randn(batch, qkv_dim, dtype=dtype, device=device) - - dst_q = torch.empty(batch, q_dim, dtype=dtype, device=device) - dst_k = torch.empty(batch, k_dim, dtype=dtype, device=device) - dst_v = torch.empty(batch, v_dim, dtype=dtype, device=device) - fused_split_copy_qkv(src, dst_q, dst_k, dst_v, q_dim, k_dim, v_dim, src.stride(0)) - - assert torch.equal(dst_q, src[:, :q_dim]) - assert torch.equal(dst_k, src[:, q_dim : q_dim + k_dim]) - assert torch.equal(dst_v, src[:, q_dim + k_dim :]) - print(" asymmetric dims (32, 32, 512): PASS") - - # Edge case: float32 dtype - src_f32 = torch.randn(4, 772, dtype=torch.float32, device=device) - dst_qkv = torch.empty(4, 512, dtype=torch.float32, device=device) - dst_z = torch.empty(4, 256, dtype=torch.float32, device=device) - dst_b = torch.empty(4, 2, dtype=torch.float32, device=device) - dst_a = torch.empty(4, 2, dtype=torch.float32, device=device) - fused_split_copy_qkvzba(src_f32, dst_qkv, dst_z, dst_b, dst_a, 512, 256, 2, 2) - - 
assert torch.equal(dst_qkv, src_f32[:, :512]) - assert torch.equal(dst_z, src_f32[:, 512:768]) - assert torch.equal(dst_b, src_f32[:, 768:770]) - assert torch.equal(dst_a, src_f32[:, 770:]) - print(" float32 dtype: PASS") - - # Edge case: float16 dtype - src_f16 = torch.randn(4, 772, dtype=torch.float16, device=device) - dst_qkv = torch.empty(4, 512, dtype=torch.float16, device=device) - dst_z = torch.empty(4, 256, dtype=torch.float16, device=device) - dst_b = torch.empty(4, 2, dtype=torch.float16, device=device) - dst_a = torch.empty(4, 2, dtype=torch.float16, device=device) - fused_split_copy_qkvzba(src_f16, dst_qkv, dst_z, dst_b, dst_a, 512, 256, 2, 2) - - assert torch.equal(dst_qkv, src_f16[:, :512]) - assert torch.equal(dst_z, src_f16[:, 512:768]) - assert torch.equal(dst_b, src_f16[:, 768:770]) - assert torch.equal(dst_a, src_f16[:, 770:]) - print(" float16 dtype: PASS") - - print() - print("All tests passed!") - - -if __name__ == "__main__": - test_fused_split_copy() diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py deleted file mode 100644 index 5a39debaa9..0000000000 --- a/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py +++ /dev/null @@ -1,1333 +0,0 @@ -""" -Optimized GDN Decode MTP (Multi-Token Prediction) Kernel - -This module provides an optimized Triton kernel for GDN decode with MTP support, -eliminating the need for sequential Python loops and reducing memory operations. - -Key optimizations: -1. Fused data reorganization from interleaved to batched layout -2. Parallel processing of all batch items with proper state indexing -3. Auto-tuned configurations for different batch sizes and model dimensions -""" - -import torch -import triton -import triton.language as tl -from lightllm.common.triton_utils.autotuner import autotune - - -@triton.jit -def _reorganize_mtp_data_kernel( - # Input pointers (interleaved layout: [step0_batch0, step0_batch1, ..., step1_batch0, ...]) - src_ptr, - # Output pointers (batched layout: [batch0_step0, batch0_step1, ..., batch1_step0, ...]) - dst_ptr, - # Dimensions - batch_size, - mtp_size, - dim_size, - # Strides - src_stride_token, - src_stride_dim, - dst_stride_token, - dst_stride_dim, - # Block sizes - BLOCK_DIM: tl.constexpr, -): - """ - Reorganize data from interleaved MTP layout to batched layout. - - Input layout: [step0_batch0, step0_batch1, ..., step0_batchN, step1_batch0, ...] - Output layout: [batch0_step0, batch0_step1, ..., batch0_stepM, batch1_step0, ...] - - This enables efficient processing with the recurrent kernel. 
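    An equivalent (unfused) PyTorch formulation of the same mapping, shown as
    an illustration only and assuming total_tokens == batch_size * mtp_size:

        batched = interleaved.view(mtp_size, batch_size, dim).transpose(0, 1).reshape(-1, dim)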
- """ - batch_idx = tl.program_id(0) - step_idx = tl.program_id(1) - block_dim_idx = tl.program_id(2) - - # Calculate source and destination token indices - src_token_idx = step_idx * batch_size + batch_idx - dst_token_idx = batch_idx * mtp_size + step_idx - - # Calculate dimension offsets - dim_start = block_dim_idx * BLOCK_DIM - dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) - mask = dim_offsets < dim_size - - # Load from source (interleaved layout) - src_offset = src_token_idx * src_stride_token + dim_offsets * src_stride_dim - data = tl.load(src_ptr + src_offset, mask=mask, other=0.0) - - # Store to destination (batched layout) - dst_offset = dst_token_idx * dst_stride_token + dim_offsets * dst_stride_dim - tl.store(dst_ptr + dst_offset, data, mask=mask) - - -@triton.jit -def _reorganize_mtp_data_back_kernel( - # Input pointers (batched layout): [batch_size, mtp_size, num_heads, head_dim] - src_ptr, - # Output pointers (interleaved layout): [total_tokens, 1, num_heads, head_dim] - dst_ptr, - # Dimensions - batch_size, - mtp_size, - num_heads, - head_dim, - # Strides for src: [batch_size, mtp_size, num_heads, head_dim] - src_stride_batch, - src_stride_mtp, - src_stride_head, - src_stride_dim, - # Strides for dst: [total_tokens, 1, num_heads, head_dim] - dst_stride_token, - dst_stride_seq, - dst_stride_head, - dst_stride_dim, - # Block sizes - BLOCK_HEAD: tl.constexpr, - BLOCK_DIM: tl.constexpr, -): - """ - Reorganize output data from batched layout back to interleaved layout. - - Input shape: [batch_size, mtp_size, num_heads, head_dim] - Output shape: [batch_size * mtp_size, 1, num_heads, head_dim] (interleaved) - - Mapping: src[b, s, h, d] -> dst[s * batch_size + b, 0, h, d] - """ - batch_idx = tl.program_id(0) - step_idx = tl.program_id(1) - block_idx = tl.program_id(2) - - # Decompose block_idx into head and dim blocks - num_dim_blocks = tl.cdiv(head_dim, BLOCK_DIM) - block_head_idx = block_idx // num_dim_blocks - block_dim_idx = block_idx % num_dim_blocks - - # Calculate destination token index (interleaved) - dst_token_idx = step_idx * batch_size + batch_idx - - # Calculate offsets - head_start = block_head_idx * BLOCK_HEAD - dim_start = block_dim_idx * BLOCK_DIM - - head_offsets = head_start + tl.arange(0, BLOCK_HEAD) - dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) - - head_mask = head_offsets < num_heads - dim_mask = dim_offsets < head_dim - mask = head_mask[:, None] & dim_mask[None, :] - - # Load from source (batched layout): [batch_size, mtp_size, num_heads, head_dim] - src_base = src_ptr + batch_idx * src_stride_batch + step_idx * src_stride_mtp - src_offset = head_offsets[:, None] * src_stride_head + dim_offsets[None, :] * src_stride_dim - data = tl.load(src_base + src_offset, mask=mask, other=0.0) - - # Store to destination (interleaved layout): [total_tokens, 1, num_heads, head_dim] - # The seq dimension (1) is skipped since it's always 0 - dst_base = dst_ptr + dst_token_idx * dst_stride_token - dst_offset = head_offsets[:, None] * dst_stride_head + dim_offsets[None, :] * dst_stride_dim - tl.store(dst_base + dst_offset, data, mask=mask) - - -def _get_reorganize_mtp_configs(): - """Generate candidate configurations for MTP data reorganization.""" - configs = [] - for block_dim in [64, 128, 256, 512]: - for num_warps in [2, 4, 8]: - for num_stages in [2, 3, 4]: - configs.append( - { - "BLOCK_DIM": block_dim, - "num_warps": num_warps, - "num_stages": num_stages, - } - ) - return configs - - -def _get_reorganize_static_key(src: torch.Tensor, mtp_size: int): - """Static 
key based on tensor properties.""" - return { - "dtype": str(src.dtype), - "mtp_size": mtp_size, - } - - -def _get_reorganize_run_key(src: torch.Tensor, mtp_size: int): - """Run key based on batch size and dimension.""" - total_tokens = src.shape[0] - batch_size = total_tokens // mtp_size - dim_size = src.shape[-1] - return f"{batch_size}_{dim_size}" - - -@autotune( - kernel_name="gdn_decode_mtp_reorganize:v1", - configs_gen_func=_get_reorganize_mtp_configs, - static_key_func=_get_reorganize_static_key, - run_key_func=_get_reorganize_run_key, - mutates_args=["dst"], -) -def reorganize_mtp_to_batched( - src: torch.Tensor, - dst: torch.Tensor, - mtp_size: int, - run_config: dict = None, -): - """ - Reorganize data from interleaved MTP layout to batched layout. - - Args: - src: Input tensor with interleaved layout [total_tokens, dim] - Layout: [step0_batch0, step0_batch1, ..., step1_batch0, ...] - dst: Output tensor with batched layout [total_tokens, dim] - Layout: [batch0_step0, batch0_step1, ..., batch1_step0, ...] - mtp_size: Number of MTP steps - run_config: Auto-tuned configuration - """ - total_tokens = src.shape[0] - batch_size = total_tokens // mtp_size - dim_size = src.shape[-1] - - if run_config is None: - BLOCK_DIM = triton.next_power_of_2(min(dim_size, 256)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_DIM = run_config["BLOCK_DIM"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_dim = triton.cdiv(dim_size, BLOCK_DIM) - - grid = (batch_size, mtp_size, num_blocks_dim) - - _reorganize_mtp_data_kernel[grid]( - src, - dst, - batch_size, - mtp_size, - dim_size, - src.stride(0), - src.stride(-1) if src.ndim > 1 else 1, - dst.stride(0), - dst.stride(-1) if dst.ndim > 1 else 1, - BLOCK_DIM=BLOCK_DIM, - num_warps=num_warps, - num_stages=num_stages, - ) - - -def _get_reorganize_back_configs(): - """Generate candidate configurations for MTP output reorganization.""" - configs = [] - for block_head in [4, 8, 16, 32]: - for block_dim in [32, 64, 128]: - for num_warps in [2, 4, 8]: - for num_stages in [2, 3]: - if block_head * block_dim <= 4096: # Limit shared memory - configs.append( - { - "BLOCK_HEAD": block_head, - "BLOCK_DIM": block_dim, - "num_warps": num_warps, - "num_stages": num_stages, - } - ) - return configs - - -def _get_reorganize_back_static_key( - src: torch.Tensor, - batch_size: int, - mtp_size: int, - num_heads: int, - head_dim: int, -): - """Static key for output reorganization.""" - return { - "dtype": str(src.dtype), - "mtp_size": mtp_size, - "num_heads": num_heads, - "head_dim": head_dim, - } - - -def _get_reorganize_back_run_key( - src: torch.Tensor, - batch_size: int, - mtp_size: int, - num_heads: int, - head_dim: int, -): - """Run key for output reorganization.""" - return batch_size - - -@autotune( - kernel_name="gdn_decode_mtp_reorganize_back:v1", - configs_gen_func=_get_reorganize_back_configs, - static_key_func=_get_reorganize_back_static_key, - run_key_func=_get_reorganize_back_run_key, - mutates_args=["dst"], -) -def reorganize_mtp_output_to_interleaved( - src: torch.Tensor, - dst: torch.Tensor, - batch_size: int, - mtp_size: int, - num_heads: int, - head_dim: int, - run_config: dict = None, -): - """ - Reorganize output from batched layout back to interleaved layout. 
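    An equivalent PyTorch formulation (illustration only; reshape copies here
    because the transposed view is non-contiguous):

        interleaved = src.transpose(0, 1).reshape(batch_size * mtp_size, 1, num_heads, head_dim)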
- - Args: - src: Input tensor [batch_size, mtp_size, num_heads, head_dim] (4D) - dst: Output tensor [batch_size * mtp_size, 1, num_heads, head_dim] (4D) - batch_size: Number of batch items - mtp_size: Number of MTP steps - num_heads: Number of attention heads - head_dim: Head dimension - run_config: Auto-tuned configuration - - Mapping: src[b, s, h, d] -> dst[s * batch_size + b, 0, h, d] - """ - if run_config is None: - BLOCK_HEAD = min(triton.next_power_of_2(num_heads), 16) - BLOCK_DIM = min(triton.next_power_of_2(head_dim), 64) - num_warps = 4 - num_stages = 2 - else: - BLOCK_HEAD = run_config["BLOCK_HEAD"] - BLOCK_DIM = run_config["BLOCK_DIM"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_head_blocks = triton.cdiv(num_heads, BLOCK_HEAD) - num_dim_blocks = triton.cdiv(head_dim, BLOCK_DIM) - num_blocks_total = num_head_blocks * num_dim_blocks - - grid = (batch_size, mtp_size, num_blocks_total) - - # src is 4D: [batch_size, mtp_size, num_heads, head_dim] - # dst is 4D: [total_tokens, 1, num_heads, head_dim] - _reorganize_mtp_data_back_kernel[grid]( - src, - dst, - batch_size, - mtp_size, - num_heads, - head_dim, - src.stride(0), # batch stride - src.stride(1), # mtp stride - src.stride(2), # head stride - src.stride(3), # dim stride - dst.stride(0), # token stride - dst.stride(1), # seq stride (=1) - dst.stride(2), # head stride - dst.stride(3), # dim stride - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_DIM=BLOCK_DIM, - num_warps=num_warps, - num_stages=num_stages, - ) - - -@triton.jit -def _prepare_mtp_indices_kernel( - # Input indices (per-step buffer indices) - buffer_idx_ptr, - # Output 2D indices for recurrent kernel - output_idx_ptr, - # Dimensions - batch_size, - mtp_size, - # Strides - input_stride, - output_stride_batch, - output_stride_step, -): - """ - Prepare 2D indices for the fused recurrent kernel. - - Input: mtp_size tensors of shape [batch_size] (buffer indices for each step) - Output: 2D tensor [batch_size, mtp_size] for ssm_state_indices - """ - batch_idx = tl.program_id(0) - step_idx = tl.program_id(1) - - # Load the buffer index for this batch and step - buffer_idx = tl.load(buffer_idx_ptr + step_idx * input_stride + batch_idx) - - # Store to the 2D output - output_offset = batch_idx * output_stride_batch + step_idx * output_stride_step - tl.store(output_idx_ptr + output_offset, buffer_idx) - - -def prepare_mtp_state_indices( - mtp_buffer_idx_list: list, - batch_size: int, - device: torch.device, -) -> torch.Tensor: - """ - Prepare 2D state indices for the fused recurrent kernel. 
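    For example, with mtp_size == 2 and batch_size == 3 (hypothetical index
    values):

        >>> idx_list = [torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])]
        >>> torch.stack(idx_list, dim=0).T.contiguous()
        tensor([[0, 3],
                [1, 4],
                [2, 5]])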
- - Args: - mtp_buffer_idx_list: List of buffer index tensors, one per MTP step - batch_size: Number of batch items - device: Target device - - Returns: - 2D tensor of shape [batch_size, mtp_size] for ssm_state_indices - """ - - # Stack indices to create [mtp_size, batch_size] tensor - stacked_indices = torch.stack(mtp_buffer_idx_list, dim=0) - - # Transpose to get [batch_size, mtp_size] - return stacked_indices.T.contiguous() - - -@triton.jit -def _fused_conv1d_mtp_step_kernel( - # Input/output data - mixed_qkv_ptr, - # Conv state buffer - conv_states_ptr, - # Conv weight and bias - conv_weight_ptr, - conv_bias_ptr, - # Buffer indices (one per MTP step, each [batch_size]) - buffer_indices_ptr, - next_buffer_indices_ptr, - # Dimensions - batch_size, - dim_size, - conv_width, - # Step info - step_idx, - mtp_size, - is_last_step: tl.constexpr, - # Strides - qkv_stride_token, - qkv_stride_dim, - state_stride_buffer, - state_stride_dim, - state_stride_width, - weight_stride_dim, - weight_stride_width, - # Block sizes - BLOCK_DIM: tl.constexpr, - ACTIVATION_SILU: tl.constexpr, -): - """ - Fused kernel for conv1d update in MTP decode. - - Handles one MTP step for all batch items: - 1. Reads current conv state - 2. Updates with new input - 3. Computes conv1d output - 4. Optionally copies state to next MTP step - """ - batch_idx = tl.program_id(0) - block_dim_idx = tl.program_id(1) - - # Calculate token index in interleaved layout - token_idx = step_idx * batch_size + batch_idx - - # Load buffer indices - cur_buffer_idx = tl.load(buffer_indices_ptr + batch_idx).to(tl.int64) - - # Calculate dimension offsets - dim_start = block_dim_idx * BLOCK_DIM - dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) - dim_mask = dim_offsets < dim_size - - # Load input value - input_offset = token_idx * qkv_stride_token + dim_offsets * qkv_stride_dim - input_val = tl.load(mixed_qkv_ptr + input_offset, mask=dim_mask, other=0.0) - - # Load conv bias - bias_val = tl.load(conv_bias_ptr + dim_offsets, mask=dim_mask, other=0.0) - - # Compute conv1d output and update state - output_val = bias_val - state_base = conv_states_ptr + cur_buffer_idx * state_stride_buffer - - # Process each position in the conv window - for w in range(conv_width): - # Load weight for this position - weight_offset = dim_offsets * weight_stride_dim + w * weight_stride_width - weight_val = tl.load(conv_weight_ptr + weight_offset, mask=dim_mask, other=0.0) - - if w < conv_width - 1: - # Load from state buffer - state_offset = dim_offsets * state_stride_dim + w * state_stride_width - state_val = tl.load(state_base + state_offset, mask=dim_mask, other=0.0) - output_val += state_val * weight_val - else: - # Use current input for the last position - output_val += input_val * weight_val - - # Update conv state (shift and insert new value) - for w in range(conv_width - 2, -1, -1): - if w == conv_width - 2: - # Insert new input at the end - state_offset = dim_offsets * state_stride_dim + w * state_stride_width - tl.store(state_base + state_offset, input_val, mask=dim_mask) - else: - # Shift state - src_offset = dim_offsets * state_stride_dim + (w + 1) * state_stride_width - dst_offset = dim_offsets * state_stride_dim + w * state_stride_width - val = tl.load(state_base + src_offset, mask=dim_mask, other=0.0) - tl.store(state_base + dst_offset, val, mask=dim_mask) - - # Apply activation (SiLU) - if ACTIVATION_SILU: - output_val = output_val * tl.sigmoid(output_val) - - # Store output - tl.store(mixed_qkv_ptr + input_offset, output_val, mask=dim_mask) - - # 
Copy state to next step if not last - if not is_last_step: - next_buffer_idx = tl.load(next_buffer_indices_ptr + batch_idx).to(tl.int64) - next_state_base = conv_states_ptr + next_buffer_idx * state_stride_buffer - - for w in range(conv_width - 1): - state_offset = dim_offsets * state_stride_dim + w * state_stride_width - val = tl.load(state_base + state_offset, mask=dim_mask, other=0.0) - tl.store(next_state_base + state_offset, val, mask=dim_mask) - - -def _get_conv1d_mtp_configs(): - """Generate candidate configurations for conv1d MTP kernel.""" - configs = [] - for block_dim in [64, 128, 256, 512]: - for num_warps in [2, 4, 8]: - for num_stages in [1, 2, 3]: - configs.append( - { - "BLOCK_DIM": block_dim, - "num_warps": num_warps, - "num_stages": num_stages, - } - ) - return configs - - -def _get_conv1d_mtp_static_key( - mixed_qkv: torch.Tensor, - conv_states: torch.Tensor, - conv_weight: torch.Tensor, - mtp_size: int, -): - """Static key for conv1d MTP kernel.""" - return { - "dtype": str(mixed_qkv.dtype), - "dim_size": mixed_qkv.shape[-1], - "conv_width": conv_weight.shape[-1], - "mtp_size": mtp_size, - } - - -def _get_conv1d_mtp_run_key( - mixed_qkv: torch.Tensor, - conv_states: torch.Tensor, - conv_weight: torch.Tensor, - mtp_size: int, -): - """Run key for conv1d MTP kernel.""" - total_tokens = mixed_qkv.shape[0] - batch_size = total_tokens // mtp_size - return batch_size - - -@autotune( - kernel_name="gdn_conv1d_mtp:v1", - configs_gen_func=_get_conv1d_mtp_configs, - static_key_func=_get_conv1d_mtp_static_key, - run_key_func=_get_conv1d_mtp_run_key, - mutates_args=["mixed_qkv", "conv_states"], -) -def fused_conv1d_mtp_update( - mixed_qkv: torch.Tensor, - conv_states: torch.Tensor, - conv_weight: torch.Tensor, - conv_bias: torch.Tensor, - mtp_buffer_idx_list: list, - mtp_size: int, - activation_silu: bool = True, - run_config: dict = None, -): - """ - Fused conv1d update for all MTP steps. 
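    Per channel, each step evaluates a causal conv over the cached window
    plus the incoming token. A reference sketch for a single batch row (not
    the kernel itself; state is [dim, conv_width - 1], weight is
    [dim, conv_width], x is [dim]):

        out = bias + (state * weight[:, :-1]).sum(-1) + x * weight[:, -1]
        out = out * torch.sigmoid(out)  # SiLU, applied when activation_silu=True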
- - Args: - mixed_qkv: Input tensor [batch_size * mtp_size, dim] (interleaved) - conv_states: Conv state buffer [num_buffers, dim, conv_width-1] - conv_weight: Conv weights [dim, conv_width] - conv_bias: Conv bias [dim] - mtp_buffer_idx_list: List of buffer index tensors per step - mtp_size: Number of MTP steps - activation_silu: Whether to apply SiLU activation - run_config: Auto-tuned configuration - """ - total_tokens = mixed_qkv.shape[0] - batch_size = total_tokens // mtp_size - dim_size = mixed_qkv.shape[-1] - conv_width = conv_weight.shape[-1] - - if run_config is None: - BLOCK_DIM = triton.next_power_of_2(min(dim_size, 256)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_DIM = run_config["BLOCK_DIM"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks_dim = triton.cdiv(dim_size, BLOCK_DIM) - - for step_idx in range(mtp_size): - is_last_step = step_idx == mtp_size - 1 - cur_indices = mtp_buffer_idx_list[step_idx] - next_indices = mtp_buffer_idx_list[step_idx + 1] if not is_last_step else cur_indices - - grid = (batch_size, num_blocks_dim) - - _fused_conv1d_mtp_step_kernel[grid]( - mixed_qkv, - conv_states, - conv_weight, - conv_bias, - cur_indices, - next_indices, - batch_size, - dim_size, - conv_width, - step_idx, - mtp_size, - is_last_step, - mixed_qkv.stride(0), - mixed_qkv.stride(-1) if mixed_qkv.ndim > 1 else 1, - conv_states.stride(0), - conv_states.stride(1), - conv_states.stride(2), - conv_weight.stride(0), - conv_weight.stride(1), - BLOCK_DIM=BLOCK_DIM, - ACTIVATION_SILU=activation_silu, - num_warps=num_warps, - num_stages=num_stages, - ) - - -@triton.jit -def _copy_ssm_state_kernel( - # SSM state buffer - ssm_states_ptr, - # Buffer indices - src_indices_ptr, - dst_indices_ptr, - # Dimensions - batch_size, - num_heads, - key_dim, - value_dim, - # Strides - state_stride_buffer, - state_stride_head, - state_stride_key, - state_stride_value, - # Block sizes - BLOCK_KEY: tl.constexpr, - BLOCK_VALUE: tl.constexpr, -): - """ - Copy SSM states from source indices to destination indices. 
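    Grid: (batch_size, num_heads, num_key_blocks * num_value_blocks); the last
    axis is decomposed inside the kernel as:

        block_key_idx = block_idx // num_value_blocks
        block_value_idx = block_idx % num_value_blocks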
- """ - batch_idx = tl.program_id(0) - head_idx = tl.program_id(1) - block_idx = tl.program_id(2) - - # Calculate block positions - num_value_blocks = tl.cdiv(value_dim, BLOCK_VALUE) - block_key_idx = block_idx // num_value_blocks - block_value_idx = block_idx % num_value_blocks - - key_start = block_key_idx * BLOCK_KEY - value_start = block_value_idx * BLOCK_VALUE - - key_offsets = key_start + tl.arange(0, BLOCK_KEY) - value_offsets = value_start + tl.arange(0, BLOCK_VALUE) - - key_mask = key_offsets < key_dim - value_mask = value_offsets < value_dim - mask = key_mask[:, None] & value_mask[None, :] - - # Load indices - src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) - dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) - - # Calculate offsets - src_base = ssm_states_ptr + src_idx * state_stride_buffer + head_idx * state_stride_head - dst_base = ssm_states_ptr + dst_idx * state_stride_buffer + head_idx * state_stride_head - - offsets = key_offsets[:, None] * state_stride_key + value_offsets[None, :] * state_stride_value - - # Copy data - data = tl.load(src_base + offsets, mask=mask, other=0.0) - tl.store(dst_base + offsets, data, mask=mask) - - -@triton.jit -def _copy_conv_state_kernel( - # Conv state buffer [num_buffers, dim, conv_width-1] - conv_states_ptr, - # Buffer indices - src_indices_ptr, - dst_indices_ptr, - # Dimensions - batch_size, - dim_size, - width_size, - num_width_blocks, # Precomputed to avoid runtime division - # Strides - state_stride_buffer, - state_stride_dim, - state_stride_width, - # Block sizes - BLOCK_DIM: tl.constexpr, - BLOCK_WIDTH: tl.constexpr, -): - """ - Copy conv states from source indices to destination indices. - - Conv state shape: [num_buffers, dim, conv_width-1] - """ - batch_idx = tl.program_id(0) - block_idx = tl.program_id(1) - - # Calculate block positions using precomputed num_width_blocks - block_dim_idx = block_idx // num_width_blocks - block_width_idx = block_idx % num_width_blocks - - dim_start = block_dim_idx * BLOCK_DIM - width_start = block_width_idx * BLOCK_WIDTH - - dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) - width_offsets = width_start + tl.arange(0, BLOCK_WIDTH) - - dim_mask = dim_offsets < dim_size - width_mask = width_offsets < width_size - mask = dim_mask[:, None] & width_mask[None, :] - - # Load indices - src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) - dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) - - # Calculate offsets - src_base = conv_states_ptr + src_idx * state_stride_buffer - dst_base = conv_states_ptr + dst_idx * state_stride_buffer - - offsets = dim_offsets[:, None] * state_stride_dim + width_offsets[None, :] * state_stride_width - - # Copy data - data = tl.load(src_base + offsets, mask=mask, other=0.0) - tl.store(dst_base + offsets, data, mask=mask) - - -def _get_conv_copy_configs(): - """Generate candidate configurations for conv state copy.""" - configs = [] - for block_dim in [64, 128, 256]: - for block_width in [2, 4, 8]: - for num_warps in [2, 4]: - configs.append( - { - "BLOCK_DIM": block_dim, - "BLOCK_WIDTH": block_width, - "num_warps": num_warps, - "num_stages": 2, - } - ) - return configs - - -def _get_conv_copy_static_key( - conv_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Static key for conv copy.""" - return { - "dtype": str(conv_states.dtype), - "dim_size": conv_states.shape[1], - "width_size": conv_states.shape[2], - } - - -def _get_conv_copy_run_key( - conv_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Run key for conv 
copy.""" - return src_indices.shape[0] - - -@autotune( - kernel_name="gdn_conv_state_copy:v1", - configs_gen_func=_get_conv_copy_configs, - static_key_func=_get_conv_copy_static_key, - run_key_func=_get_conv_copy_run_key, - mutates_args=["conv_states"], -) -def copy_conv_states( - conv_states: torch.Tensor, - src_indices: torch.Tensor, - dst_indices: torch.Tensor, - run_config: dict = None, -): - """ - Copy conv states from source indices to destination indices. - - Args: - conv_states: Conv state buffer [num_buffers, dim, conv_width-1] - src_indices: Source buffer indices [batch_size] - dst_indices: Destination buffer indices [batch_size] - run_config: Auto-tuned configuration - """ - batch_size = src_indices.shape[0] - dim_size = conv_states.shape[1] - width_size = conv_states.shape[2] - - if run_config is None: - BLOCK_DIM = triton.next_power_of_2(min(dim_size, 128)) - BLOCK_WIDTH = triton.next_power_of_2(min(width_size, 4)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_DIM = run_config["BLOCK_DIM"] - BLOCK_WIDTH = run_config["BLOCK_WIDTH"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_dim_blocks = triton.cdiv(dim_size, BLOCK_DIM) - num_width_blocks = triton.cdiv(width_size, BLOCK_WIDTH) - num_blocks_total = num_dim_blocks * num_width_blocks - - grid = (batch_size, num_blocks_total) - - _copy_conv_state_kernel[grid]( - conv_states, - src_indices, - dst_indices, - batch_size, - dim_size, - width_size, - num_width_blocks, # Pass precomputed value - conv_states.stride(0), - conv_states.stride(1), - conv_states.stride(2), - BLOCK_DIM=BLOCK_DIM, - BLOCK_WIDTH=BLOCK_WIDTH, - num_warps=num_warps, - num_stages=num_stages, - ) - - -def _get_ssm_copy_configs(): - """Generate candidate configurations for SSM state copy.""" - configs = [] - for block_key in [16, 32, 64]: - for block_value in [16, 32, 64, 128]: - for num_warps in [2, 4, 8]: - if block_key * block_value <= 4096: - configs.append( - { - "BLOCK_KEY": block_key, - "BLOCK_VALUE": block_value, - "num_warps": num_warps, - "num_stages": 2, - } - ) - return configs - - -def _get_ssm_copy_static_key( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Static key for SSM copy.""" - return { - "dtype": str(ssm_states.dtype), - "num_heads": ssm_states.shape[1], - "key_dim": ssm_states.shape[2], - "value_dim": ssm_states.shape[3], - } - - -def _get_ssm_copy_run_key( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Run key for SSM copy.""" - return src_indices.shape[0] - - -@autotune( - kernel_name="gdn_ssm_state_copy:v1", - configs_gen_func=_get_ssm_copy_configs, - static_key_func=_get_ssm_copy_static_key, - run_key_func=_get_ssm_copy_run_key, - mutates_args=["ssm_states"], -) -def copy_ssm_states( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, - dst_indices: torch.Tensor, - run_config: dict = None, -): - """ - Copy SSM states from source indices to destination indices. 
- - Args: - ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] - src_indices: Source buffer indices [batch_size] - dst_indices: Destination buffer indices [batch_size] - run_config: Auto-tuned configuration - """ - batch_size = src_indices.shape[0] - num_heads = ssm_states.shape[1] - key_dim = ssm_states.shape[2] - value_dim = ssm_states.shape[3] - - if run_config is None: - BLOCK_KEY = triton.next_power_of_2(min(key_dim, 32)) - BLOCK_VALUE = triton.next_power_of_2(min(value_dim, 64)) - num_warps = 4 - num_stages = 2 - else: - BLOCK_KEY = run_config["BLOCK_KEY"] - BLOCK_VALUE = run_config["BLOCK_VALUE"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_key_blocks = triton.cdiv(key_dim, BLOCK_KEY) - num_value_blocks = triton.cdiv(value_dim, BLOCK_VALUE) - num_blocks_total = num_key_blocks * num_value_blocks - - grid = (batch_size, num_heads, num_blocks_total) - - _copy_ssm_state_kernel[grid]( - ssm_states, - src_indices, - dst_indices, - batch_size, - num_heads, - key_dim, - value_dim, - ssm_states.stride(0), - ssm_states.stride(1), - ssm_states.stride(2), - ssm_states.stride(3), - BLOCK_KEY=BLOCK_KEY, - BLOCK_VALUE=BLOCK_VALUE, - num_warps=num_warps, - num_stages=num_stages, - ) - - -# ============================================================================= -# Optimized Flat Copy Kernels (for contiguous memory) -# ============================================================================= -# These kernels leverage the fact that both conv_states and ssm_states are -# contiguous in memory, allowing us to flatten the inner dimensions and use -# efficient 1D vectorized copy patterns. - - -@triton.jit -def _copy_state_flat_kernel( - # State buffer pointer (flattened view) - state_ptr, - # Buffer indices - src_indices_ptr, - dst_indices_ptr, - # Dimensions - batch_size, - flat_size, # Total elements per buffer entry (flattened inner dims) - # Strides - stride_buffer, # Stride to next buffer entry (in elements) - # Block size - BLOCK_SIZE: tl.constexpr, -): - """ - Optimized flat copy kernel for contiguous state buffers. - - Instead of using 2D/3D block patterns with stride calculations, this kernel - treats each buffer entry as a flat 1D array and uses vectorized loads/stores - for efficient memory transfer. 
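    In PyTorch terms this is simply (illustration only; view() requires the
    contiguity asserted by the wrappers):

        flat = states.view(states.shape[0], -1)
        flat[dst_indices] = flat[src_indices]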
- - Grid: (batch_size, num_blocks) where num_blocks = ceil(flat_size / BLOCK_SIZE) - """ - batch_idx = tl.program_id(0) - block_idx = tl.program_id(1) - - # Calculate element range for this block - elem_start = block_idx * BLOCK_SIZE - elem_offsets = elem_start + tl.arange(0, BLOCK_SIZE) - elem_mask = elem_offsets < flat_size - - # Load buffer indices for this batch item - src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) - dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) - - # Calculate source and destination base pointers - src_base = state_ptr + src_idx * stride_buffer - dst_base = state_ptr + dst_idx * stride_buffer - - # Vectorized copy - data = tl.load(src_base + elem_offsets, mask=elem_mask, other=0.0) - tl.store(dst_base + elem_offsets, data, mask=elem_mask) - - -@triton.jit -def _copy_states_fused_kernel( - # Conv state buffer (flattened view) - conv_state_ptr, - # SSM state buffer (flattened view) - ssm_state_ptr, - # Buffer indices - src_indices_ptr, - dst_indices_ptr, - # Dimensions - batch_size, - conv_flat_size, # Total elements per conv buffer entry - ssm_flat_size, # Total elements per ssm buffer entry - # Strides (in elements) - conv_stride_buffer, - ssm_stride_buffer, - # Block sizes - CONV_BLOCK_SIZE: tl.constexpr, - SSM_BLOCK_SIZE: tl.constexpr, -): - """ - Fused kernel to copy both conv_states and ssm_states in a single launch. - - This reduces kernel launch overhead by processing both state copies together. - Each thread block handles one batch item and copies both states sequentially. - - Grid: (batch_size, max(conv_blocks, ssm_blocks)) - """ - batch_idx = tl.program_id(0) - block_idx = tl.program_id(1) - - # Load buffer indices (same for both conv and ssm) - src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) - dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) - - # ========== Copy Conv State ========== - conv_num_blocks = tl.cdiv(conv_flat_size, CONV_BLOCK_SIZE) - if block_idx < conv_num_blocks: - conv_elem_start = block_idx * CONV_BLOCK_SIZE - conv_elem_offsets = conv_elem_start + tl.arange(0, CONV_BLOCK_SIZE) - conv_mask = conv_elem_offsets < conv_flat_size - - conv_src_base = conv_state_ptr + src_idx * conv_stride_buffer - conv_dst_base = conv_state_ptr + dst_idx * conv_stride_buffer - - conv_data = tl.load(conv_src_base + conv_elem_offsets, mask=conv_mask, other=0.0) - tl.store(conv_dst_base + conv_elem_offsets, conv_data, mask=conv_mask) - - # ========== Copy SSM State ========== - ssm_num_blocks = tl.cdiv(ssm_flat_size, SSM_BLOCK_SIZE) - if block_idx < ssm_num_blocks: - ssm_elem_start = block_idx * SSM_BLOCK_SIZE - ssm_elem_offsets = ssm_elem_start + tl.arange(0, SSM_BLOCK_SIZE) - ssm_mask = ssm_elem_offsets < ssm_flat_size - - ssm_src_base = ssm_state_ptr + src_idx * ssm_stride_buffer - ssm_dst_base = ssm_state_ptr + dst_idx * ssm_stride_buffer - - ssm_data = tl.load(ssm_src_base + ssm_elem_offsets, mask=ssm_mask, other=0.0) - tl.store(ssm_dst_base + ssm_elem_offsets, ssm_data, mask=ssm_mask) - - -def _get_flat_copy_configs(): - """Generate candidate configurations for flat copy kernel.""" - configs = [] - # Larger block sizes for better memory throughput on contiguous data - for block_size in [256, 512, 1024, 2048]: - for num_warps in [4, 8]: - configs.append( - { - "BLOCK_SIZE": block_size, - "num_warps": num_warps, - "num_stages": 2, - } - ) - return configs - - -def _get_conv_flat_copy_static_key( - conv_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Static key for conv flat copy.""" - return { - 
"dtype": str(conv_states.dtype), - "flat_size": conv_states.shape[1] * conv_states.shape[2], - } - - -def _get_conv_flat_copy_run_key( - conv_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Run key for conv flat copy.""" - return src_indices.shape[0] - - -@autotune( - kernel_name="gdn_conv_state_flat_copy:v1", - configs_gen_func=_get_flat_copy_configs, - static_key_func=_get_conv_flat_copy_static_key, - run_key_func=_get_conv_flat_copy_run_key, - mutates_args=["conv_states"], -) -def copy_conv_states_flat( - conv_states: torch.Tensor, - src_indices: torch.Tensor, - dst_indices: torch.Tensor, - run_config: dict = None, -): - """ - Optimized flat copy for conv states leveraging contiguous memory. - - Args: - conv_states: Conv state buffer [num_buffers, dim, conv_width-1] (MUST be contiguous) - src_indices: Source buffer indices [batch_size] - dst_indices: Destination buffer indices [batch_size] - run_config: Auto-tuned configuration - """ - assert conv_states.is_contiguous(), "conv_states must be contiguous for flat copy" - - batch_size = src_indices.shape[0] - # Flatten inner dimensions - flat_size = conv_states.shape[1] * conv_states.shape[2] - stride_buffer = conv_states.stride(0) - - if run_config is None: - BLOCK_SIZE = 1024 - num_warps = 4 - num_stages = 2 - else: - BLOCK_SIZE = run_config["BLOCK_SIZE"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks = triton.cdiv(flat_size, BLOCK_SIZE) - grid = (batch_size, num_blocks) - - _copy_state_flat_kernel[grid]( - conv_states, - src_indices, - dst_indices, - batch_size, - flat_size, - stride_buffer, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - num_stages=num_stages, - ) - - -def _get_ssm_flat_copy_static_key( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Static key for ssm flat copy.""" - return { - "dtype": str(ssm_states.dtype), - "flat_size": ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3], - } - - -def _get_ssm_flat_copy_run_key( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Run key for ssm flat copy.""" - return src_indices.shape[0] - - -@autotune( - kernel_name="gdn_ssm_state_flat_copy:v1", - configs_gen_func=_get_flat_copy_configs, - static_key_func=_get_ssm_flat_copy_static_key, - run_key_func=_get_ssm_flat_copy_run_key, - mutates_args=["ssm_states"], -) -def copy_ssm_states_flat( - ssm_states: torch.Tensor, - src_indices: torch.Tensor, - dst_indices: torch.Tensor, - run_config: dict = None, -): - """ - Optimized flat copy for SSM states leveraging contiguous memory. 
- - Args: - ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] (MUST be contiguous) - src_indices: Source buffer indices [batch_size] - dst_indices: Destination buffer indices [batch_size] - run_config: Auto-tuned configuration - """ - assert ssm_states.is_contiguous(), "ssm_states must be contiguous for flat copy" - - batch_size = src_indices.shape[0] - # Flatten inner dimensions (num_heads * key_dim * value_dim) - flat_size = ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3] - stride_buffer = ssm_states.stride(0) - - if run_config is None: - BLOCK_SIZE = 1024 - num_warps = 4 - num_stages = 2 - else: - BLOCK_SIZE = run_config["BLOCK_SIZE"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - num_blocks = triton.cdiv(flat_size, BLOCK_SIZE) - grid = (batch_size, num_blocks) - - _copy_state_flat_kernel[grid]( - ssm_states, - src_indices, - dst_indices, - batch_size, - flat_size, - stride_buffer, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - num_stages=num_stages, - ) - - -def _get_fused_copy_configs(): - """Generate candidate configurations for fused copy kernel.""" - configs = [] - # Use power-of-2 block sizes for both conv and ssm - for conv_block in [256, 512, 1024]: - for ssm_block in [256, 512, 1024]: - for num_warps in [4, 8]: - configs.append( - { - "CONV_BLOCK_SIZE": conv_block, - "SSM_BLOCK_SIZE": ssm_block, - "num_warps": num_warps, - "num_stages": 2, - } - ) - return configs - - -def _get_fused_copy_static_key( - conv_states: torch.Tensor, - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Static key for fused copy.""" - return { - "conv_dtype": str(conv_states.dtype), - "ssm_dtype": str(ssm_states.dtype), - "conv_flat_size": conv_states.shape[1] * conv_states.shape[2], - "ssm_flat_size": ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3], - } - - -def _get_fused_copy_run_key( - conv_states: torch.Tensor, - ssm_states: torch.Tensor, - src_indices: torch.Tensor, -): - """Run key for fused copy.""" - return src_indices.shape[0] - - -@autotune( - kernel_name="gdn_states_fused_copy:v1", - configs_gen_func=_get_fused_copy_configs, - static_key_func=_get_fused_copy_static_key, - run_key_func=_get_fused_copy_run_key, - mutates_args=["conv_states", "ssm_states"], -) -def copy_states_fused( - conv_states: torch.Tensor, - ssm_states: torch.Tensor, - src_indices: torch.Tensor, - dst_indices: torch.Tensor, - run_config: dict = None, -): - """ - Fused copy for both conv and SSM states in a single kernel launch. - - This reduces kernel launch overhead by processing both state copies together. 
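    Functionally equivalent to the two separate flat copies defined above,
    fused into a single launch:

        copy_conv_states_flat(conv_states, src_indices, dst_indices)
        copy_ssm_states_flat(ssm_states, src_indices, dst_indices)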
- - Args: - conv_states: Conv state buffer [num_buffers, dim, conv_width-1] (MUST be contiguous) - ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] (MUST be contiguous) - src_indices: Source buffer indices [batch_size] - dst_indices: Destination buffer indices [batch_size] - run_config: Auto-tuned configuration - """ - assert conv_states.is_contiguous(), "conv_states must be contiguous for fused copy" - assert ssm_states.is_contiguous(), "ssm_states must be contiguous for fused copy" - - batch_size = src_indices.shape[0] - - # Flatten inner dimensions - conv_flat_size = conv_states.shape[1] * conv_states.shape[2] - ssm_flat_size = ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3] - - conv_stride_buffer = conv_states.stride(0) - ssm_stride_buffer = ssm_states.stride(0) - - if run_config is None: - CONV_BLOCK_SIZE = 512 - SSM_BLOCK_SIZE = 512 - num_warps = 4 - num_stages = 2 - else: - CONV_BLOCK_SIZE = run_config["CONV_BLOCK_SIZE"] - SSM_BLOCK_SIZE = run_config["SSM_BLOCK_SIZE"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - # Grid covers both conv and ssm blocks - conv_num_blocks = triton.cdiv(conv_flat_size, CONV_BLOCK_SIZE) - ssm_num_blocks = triton.cdiv(ssm_flat_size, SSM_BLOCK_SIZE) - max_blocks = max(conv_num_blocks, ssm_num_blocks) - grid = (batch_size, max_blocks) - - _copy_states_fused_kernel[grid]( - conv_states, - ssm_states, - src_indices, - dst_indices, - batch_size, - conv_flat_size, - ssm_flat_size, - conv_stride_buffer, - ssm_stride_buffer, - CONV_BLOCK_SIZE=CONV_BLOCK_SIZE, - SSM_BLOCK_SIZE=SSM_BLOCK_SIZE, - num_warps=num_warps, - num_stages=num_stages, - ) diff --git a/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py deleted file mode 100644 index 0a2b4bd662..0000000000 --- a/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py +++ /dev/null @@ -1,141 +0,0 @@ -import torch - -import triton -import triton.language as tl - -from lightllm.common.triton_utils.autotuner import autotune - - -@triton.jit -def _gemma_rmsnorm_fwd_kernel( - x_ptr, - w_ptr, - y_ptr, - x_stride0, - x_stride1, - y_stride0, - y_stride1, - N: tl.constexpr, - EPS: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - row = tl.program_id(0) - x_ptr = x_ptr + row * x_stride0 - y_ptr = y_ptr + row * y_stride0 - - _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(x_ptr + cols * x_stride1, mask=cols < N, other=0.0).to(tl.float32) - _sum += x * x - - var = tl.sum(_sum, axis=0) / N - rstd = 1 / tl.sqrt(var + EPS) - # Normalize and apply linear transformation - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(w_ptr + cols, mask=mask).to(tl.float32) - x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) - x_hat = x * rstd - w = w + 1.0 - y = x_hat * w - # Write output - tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask) - - -def _get_gemma_rmsnorm_configs(): - """Generate configurations for autotuning gemma RMSNorm kernel.""" - configs = [] - for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]: - for num_warps in [1, 2, 4, 8]: - # num_stages has minimal impact on this simple kernel, use 1 - configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": 1}) - return configs - - -def _get_gemma_rmsnorm_static_key(x: torch.Tensor, w: 
torch.Tensor): - """Generate static key for caching autotuned configurations.""" - N = x.shape[-1] - return { - "x_dtype": str(x.dtype), - "weight_dtype": str(w.dtype), - "N": N, - } - - -@autotune( - kernel_name="gemma_rmsnorm_forward:v1", - configs_gen_func=_get_gemma_rmsnorm_configs, - static_key_func=_get_gemma_rmsnorm_static_key, - run_key_func=lambda x: x.shape[-1], -) -def gemma_rmsnorm_forward(x, w, eps, out=None, run_config: dict = None): - # Inplace gemma RMS Norm - # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - N = x.shape[-1] - y = torch.empty_like(x) if out is None else out - x_arg = x.view(-1, N) - y_arg = y.view(-1, N) - - M, _ = x_arg.shape - - # Default heuristic when autotune is disabled or no config provided - if not run_config: - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_SIZE: - raise RuntimeError("This gemma rmsnorm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1} - - BLOCK_SIZE = run_config["BLOCK_SIZE"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - _gemma_rmsnorm_fwd_kernel[(M,)]( - x_arg, - w, - y_arg, - x_stride0=x.stride(0), - x_stride1=x.stride(1), - y_stride0=y.stride(0), - y_stride1=y.stride(1), - N=N, - EPS=eps, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - num_stages=num_stages, - ) - - return y - - -def _gemma_rmsnorm_fwd_torch(x, weight, eps): - original_dtype = x.dtype - x = x.to(torch.float32) - x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - x = x * (1.0 + weight.float()) - return x.to(original_dtype) - - -def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"): - # create data - x_shape = (M, N) - w_shape = (x_shape[-1],) - weight = torch.rand(w_shape, dtype=dtype, device="cuda") - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - # forward pass - y_tri = gemma_rmsnorm_forward(x, weight, eps) - y_ref = _gemma_rmsnorm_fwd_torch(x, weight, eps) - - # compare - print("type:", y_tri.dtype, y_ref.dtype) - print("max delta:", torch.max(torch.abs(y_tri - y_ref))) - # Use appropriate tolerance based on dtype - atol = 1e-2 if dtype == torch.float32 else 5e-2 - assert torch.allclose(y_tri, y_ref, atol=atol, rtol=0) - return diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py deleted file mode 100644 index 6fb3f3adeb..0000000000 --- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py +++ /dev/null @@ -1,158 +0,0 @@ -from typing import Set, Protocol, List, Optional, Tuple - -import torch -from sortedcontainers import SortedSet - -from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode -from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class HybridRadixCache(RadixCache): - def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager): - super().__init__(unique_name, total_token_num, rank_in_node, kv_cache_mem_manager) - assert hasattr(kv_cache_mem_manager, "mamba_cache_mem_manager") - self.buffer_mem_manager: MambaCacheManager = 
kv_cache_mem_manager.mamba_cache_mem_manager - self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: (x.buffer_time,)) - - def free_radix_cache_to_get_enough_buffer(self, need_buffer_num): - if need_buffer_num > self.buffer_mem_manager.can_use_mem_size: - need_evict_buffer_num = need_buffer_num - self.buffer_mem_manager.can_use_mem_size - release_buffers = [] - - def release_buffer(buffer_idx): - release_buffers.append(buffer_idx) - return - - self._evict_buffer(need_evict_buffer_num, release_buffer) - if len(release_buffers) > 0: - self.buffer_mem_manager.free(release_buffers) - return - - def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback): - while need_evict_buffer_num > 0: - node = self.evict_buffer_set.pop(0) - assert node.buffer_idx is not None - evict_buffer_callback(node.buffer_idx) - node.buffer_idx = None - need_evict_buffer_num -= 1 - return - - def match_prefix(self, key, update_refs=False): - assert len(key) != 0 - ans_value_list = [] - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) - miss_prefix_len = 0 - evict_token_list = [] - kv_len = tree_node.node_prefix_total_len - while tree_node != self.root_node and tree_node.buffer_idx is None: - if tree_node.is_leaf(): - self.evict_tree_set.discard(tree_node) - - # Only update ref_counter when update_refs is True to maintain consistency - # with _match_prefix_helper which only increments ref_counter when update_refs=True - if update_refs: - if tree_node.ref_counter == 1: - self.refed_tokens_num.arr[0] -= len(tree_node.token_mem_index_value) - tree_node.ref_counter -= 1 # 只减少当前节点,不递归 - - if tree_node.is_leaf() and tree_node.ref_counter == 0: - evict_token_list.append(tree_node.token_mem_index_value) - self.tree_total_tokens_num.arr[0] -= len(tree_node.token_mem_index_value) - parent_node: TreeNode = tree_node.parent - parent_node.remove_child(tree_node) - if parent_node.is_leaf(): - self.evict_tree_set.add(parent_node) - tree_node = parent_node - else: - if tree_node.is_leaf(): - self.evict_tree_set.add(tree_node) - tree_node = tree_node.parent - miss_prefix_len += len(ans_value_list.pop()) - - if len(evict_token_list) > 0: - evict_token_value = torch.concat(evict_token_list) - self.mem_manager.free(evict_token_value) - - if tree_node == self.root_node: - return None, kv_len - miss_prefix_len, None - - update_node = tree_node - while update_node != self.root_node: - if update_node.buffer_idx is not None: - self.evict_buffer_set.discard(update_node) - update_node.update_buffer_time() - self.evict_buffer_set.add(update_node) - update_node = update_node.parent - - value = torch.concat(ans_value_list) - return tree_node, miss_prefix_len, value - - def add_buffer_idx_to_node(self, node: TreeNode, buffer_idx: int): - """Set buffer_idx for a node and add it to evict_buffer_set.""" - self.evict_buffer_set.discard(node) - if node.is_leaf(): - self.evict_tree_set.discard(node) - if node.buffer_idx is not None: - self.buffer_mem_manager.free([node.buffer_idx]) - node.buffer_idx = buffer_idx - node.update_buffer_time() - self.evict_buffer_set.add(node) - if node.is_leaf(): - self.evict_tree_set.add(node) - return - - def free_radix_cache_to_get_enough_token(self, need_token_num): - assert self.mem_manager is not None - if need_token_num > self.mem_manager.can_use_mem_size: - need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size - release_mems = [] - - def release_mem(mem_index): - release_mems.append(mem_index) - return - - release_buffers = [] - 
- def release_buffer(buffer_idx): - release_buffers.append(buffer_idx) - return - - self.evict(need_evict_token_num, release_buffer, release_mem) - mem_index = torch.concat(release_mems) - self.mem_manager.free(mem_index) - if len(release_buffers) > 0: - self.buffer_mem_manager.free(release_buffers) - return - - def evict(self, need_remove_tokens, evict_buffer_callback, evict_callback): - if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: - assert False, f"""can not free tree tokens {need_remove_tokens}, - tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, - refed_tokens_num {self.refed_tokens_num.arr[0]}""" - num_evicted = 0 - while num_evicted < need_remove_tokens: - node: TreeNode = self.evict_tree_set.pop(0) - assert ( - node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node - ), f"error evict tree node state: {node.ref_counter}, {len(node.children)}" - num_evicted += len(node.token_mem_index_value) - evict_callback(node.token_mem_index_value) - if node.buffer_idx is not None: - self.evict_buffer_set.discard(node) - evict_buffer_callback(node.buffer_idx) - node.buffer_idx = None - # update total token num - self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) - parent_node: TreeNode = node.parent - parent_node.remove_child(node) - if parent_node.is_leaf(): - self.evict_tree_set.add(parent_node) - - return - - def flush_cache(self): - super().flush_cache() - self.evict_buffer_set.clear() From 1ecf015f3a29221ea2f1dab3855458e445f3d571 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 8 May 2026 14:09:23 +0000 Subject: [PATCH 172/180] slim code --- .../layer_infer/transformer_layer_infer.py | 0 lightllm/models/llama/model.py | 53 +--------- .../qwen3next/layer_infer/post_layer_infer.py | 12 --- .../layer_infer/shared_expert_mixin.py | 96 ------------------- 4 files changed, 5 insertions(+), 156 deletions(-) mode change 100755 => 100644 lightllm/models/llama/layer_infer/transformer_layer_infer.py delete mode 100644 lightllm/models/qwen3next/layer_infer/post_layer_infer.py delete mode 100644 lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py old mode 100755 new mode 100644 diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index cc1dc28178..c104ebccc9 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -74,19 +74,14 @@ def _init_custom(self): rope_scaling = self.config.get("rope_scaling", None) if rope_scaling is None: self._init_to_get_rotary() - elif "rope_type" in rope_scaling: + return + + if "rope_type" in rope_scaling: scaling_type = rope_scaling["rope_type"] - self._init_rotary_by_scaling_type(scaling_type, rope_scaling) elif "type" in rope_scaling: scaling_type = rope_scaling["type"] - self._init_rotary_by_scaling_type(scaling_type, rope_scaling) else: raise ValueError(f"Unknown RoPE scaling format {rope_scaling}") - if "rope_theta_hw" in self.config: - self._init_to_get_hw_rotary() - super()._init_custom() - - def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling): if scaling_type == "default" or "mrope_section" in rope_scaling: self._init_to_get_rotary() elif scaling_type == "yarn": @@ -101,6 +96,7 @@ def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling): self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + 
From 1ecf015f3a29221ea2f1dab3855458e445f3d571 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Fri, 8 May 2026 14:09:23 +0000
Subject: [PATCH 172/180] slim code

---
 .../layer_infer/transformer_layer_infer.py    |  0
 lightllm/models/llama/model.py                | 53 +---------
 .../qwen3next/layer_infer/post_layer_infer.py | 12 ---
 .../layer_infer/shared_expert_mixin.py        | 96 ------------------
 4 files changed, 5 insertions(+), 156 deletions(-)
 mode change 100755 => 100644 lightllm/models/llama/layer_infer/transformer_layer_infer.py
 delete mode 100644 lightllm/models/qwen3next/layer_infer/post_layer_infer.py
 delete mode 100644 lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py

diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
old mode 100755
new mode 100644
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index cc1dc28178..c104ebccc9 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -74,19 +74,14 @@ def _init_custom(self):
         rope_scaling = self.config.get("rope_scaling", None)
         if rope_scaling is None:
             self._init_to_get_rotary()
-        elif "rope_type" in rope_scaling:
+            return
+
+        if "rope_type" in rope_scaling:
             scaling_type = rope_scaling["rope_type"]
-            self._init_rotary_by_scaling_type(scaling_type, rope_scaling)
         elif "type" in rope_scaling:
             scaling_type = rope_scaling["type"]
-            self._init_rotary_by_scaling_type(scaling_type, rope_scaling)
         else:
             raise ValueError(f"Unknown RoPE scaling format {rope_scaling}")
-        if "rope_theta_hw" in self.config:
-            self._init_to_get_hw_rotary()
-        super()._init_custom()
-
-    def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling):
         if scaling_type == "default" or "mrope_section" in rope_scaling:
             self._init_to_get_rotary()
         elif scaling_type == "yarn":
@@ -101,6 +96,7 @@ def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling):
             self._init_to_get_rotary()
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
         return

     def _init_to_get_rotary(self, default_base=10000):
         partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
@@ -110,6 +106,7 @@ def _init_to_get_rotary(self, default_base=10000):
             rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)

         base = self.config.get("rope_theta", float(default_base))
+
         if "max_sequence_length" in self.config:
             max_seq_len = self.config["max_sequence_length"]
         else:
@@ -142,46 +139,6 @@ def _init_to_get_rotary(self, default_base=10000):
         self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
         return

-    def _init_to_get_hw_rotary(self, default_base=10000):
-        partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2)
-        if self.config.get("rope_scaling", {}) is None:
-            rope_scaling_factor = 1.0
-        else:
-            rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
-
-        base = self.config.get("rope_theta_hw", float(default_base))
-        if "max_sequence_length" in self.config:
-            max_seq_len = self.config["max_sequence_length"]
-        else:
-            max_position_embeddings = self.config.get(
-                "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384
-            )
-            max_seq_len = max_position_embeddings * rope_scaling_factor
-
-        # NTK
-        try:
-            ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
-            assert ntk_alpha >= 1
-            if ntk_alpha > 1:
-                logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
-            max_seq_len *= ntk_alpha
-            base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
-        except:
-            pass
-
-        inv_freq = 1.0 / (
-            base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
-        )
-        t = (
-            torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
-            / rope_scaling_factor
-        )
-        freqs = torch.outer(t, inv_freq)
-
-        self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda()
-        self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda()
-        return
-
     def _init_to_get_dynamic_ntk_rotary(self):
         partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
         max_position_embeddings = self.config.get("max_position_embeddings", 2048)
diff --git a/lightllm/models/qwen3next/layer_infer/post_layer_infer.py b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py
deleted file mode 100644
index 9dcab4e6fc..0000000000
--- a/lightllm/models/qwen3next/layer_infer/post_layer_infer.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import torch
-
-from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
-from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
-from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward
-
-
-class Qwen3NextPostLayerInfer(LlamaPostLayerInfer):
-    def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
-        out = self.alloc_tensor(input.shape, input.dtype)
-        gemma_rmsnorm_forward(input, layer_weight.final_norm_weight_.weight, self.eps_, out=out)
-        return out
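The deleted Qwen3NextPostLayerInfer overrides only the final norm, swapping in a gemma-style RMSNorm whose distinguishing detail is scaling by (1 + weight) rather than weight. The repo's real implementation is the fused Triton kernel gemma_rmsnorm_forward; a plain-PyTorch reference of the same math, for readability only:

import torch

def gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # normalize in fp32 for numerical stability, then cast back
    dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # gemma convention: scale by (1 + weight) instead of weight
    return (x * (1.0 + weight.float())).to(dtype)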
diff --git a/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py b/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py
deleted file mode 100644
index be9000fcad..0000000000
--- a/lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# lightllm/models/qwen3next/layer_infer/shared_expert_mixin.py
-import torch.nn.functional as F
-from functools import partial
-from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
-import os
-
-
-class SharedExpertFFNMixin:
-    """
-    Mixin providing shared expert + MoE FFN implementations.
-
-    Used by both full attention and GDN layers in Qwen3Next.
-
-    Requirements:
-    - Class must have: embed_dim_, tp_world_size_, alloc_tensor()
-    - Class must have MoE config: is_moe, n_routed_experts, num_experts_per_tok, norm_topk_prob
-    """
-
-    def _bind_ffn(self):
-        """Bind FFN implementation based on MoE configuration."""
-        if self.is_moe:
-            moe_mode = os.environ.get("MOE_MODE", "TP")
-            if moe_mode == "EP":
-                self._ffn = partial(SharedExpertFFNMixin._ffn_with_shared_expert_ep, self)
-            else:
-                self._ffn = partial(SharedExpertFFNMixin._ffn_with_shared_expert_tp, self)
-        else:
-            self._ffn = partial(SharedExpertFFNMixin._standard_ffn, self)
-        return
-
-    def _ffn_core(self, input, layer_weight):
-        """Core FFN computation: gate_up -> silu_and_mul -> down."""
-        input = input.view(-1, self.embed_dim_)
-        up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input)
-        ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype)
-        silu_and_mul_fwd(up_gate_out, ffn1_out)
-        ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out)
-        return ffn2_out, input
-
-    def _standard_ffn(self, input, infer_state, layer_weight):
-        """Standard FFN using shared expert weights (non-MoE layers)."""
-        ffn2_out, _ = self._ffn_core(input, layer_weight)
-        return ffn2_out
-
-    def _compute_shared_expert(self, input, layer_weight):
-        """Compute shared expert FFN output with gating."""
-        ffn2_out, input_view = self._ffn_core(input, layer_weight)
-        return F.sigmoid(layer_weight.shared_expert_gate.mm(input_view)) * ffn2_out, input_view
-
-    def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight):
-        """FFN with shared expert + MoE (tensor parallelism mode)."""
-        shared_expert_out, input = self._compute_shared_expert(input, layer_weight)
-        moe_out = self._moe_ffn(input, infer_state, layer_weight)
-        return shared_expert_out + moe_out
-
-    def _ffn_with_shared_expert_ep(self, input, infer_state, layer_weight):
-        """FFN with shared expert + MoE (expert parallelism mode)."""
-        shared_expert_out, input = self._compute_shared_expert(input, layer_weight)
-        moe_out = self._moe_ffn_edp(input, infer_state, layer_weight)
-        return shared_expert_out + moe_out
-
-    def _moe_ffn(self, input, infer_state, layer_weight):
-        """MoE FFN with tensor parallelism."""
-        hidden_states = input.view(-1, self.embed_dim_)
-        num_tokens, hidden_dim = hidden_states.shape
-        router_logits = layer_weight.moe_gate.mm(hidden_states)
-        layer_weight.experts.experts(
-            hidden_states,
-            router_logits=router_logits,
-            top_k=self.num_experts_per_tok,
-            renormalize=self.norm_topk_prob,
-            use_grouped_topk=False,
-            topk_group=None,
-            num_expert_group=None,
-        )
-        return hidden_states.view(num_tokens, hidden_dim)
-
-    def _moe_ffn_edp(self, input, infer_state, layer_weight):
-        """MoE FFN with expert parallelism."""
-        hidden_states = input
-        token_num, hidden_dim = hidden_states.shape
-
-        router_logits = layer_weight.moe_gate.mm(hidden_states)
-        ep_output = layer_weight.experts.experts(
-            hidden_states,
-            router_logits=router_logits,
-            top_k=self.num_experts_per_tok,
-            renormalize=self.norm_topk_prob,
-            use_grouped_topk=False,
-            topk_group=None,
-            num_expert_group=None,
-            is_prefill=infer_state.is_prefill,
-        )
-
-        ep_output = ep_output.view(token_num, hidden_dim)
-        return ep_output
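The mixin removed above pairs an always-on shared expert with the routed MoE and gates the shared branch with a sigmoid of a learned projection. A condensed sketch of that combination step, with hypothetical plain weight matrices standing in for LightLLM's layer_weight objects:

import torch
import torch.nn.functional as F

def ffn_with_shared_expert(x, w_gate_up, w_down, w_gate, moe_ffn):
    # x: [tokens, hidden]; w_gate_up: [hidden, 2*inter]; w_down: [inter, hidden]; w_gate: [hidden, 1]
    gate, up = (x @ w_gate_up).chunk(2, dim=-1)
    shared_out = (F.silu(gate) * up) @ w_down            # SwiGLU shared expert
    shared_out = torch.sigmoid(x @ w_gate) * shared_out  # sigmoid gate on the shared branch
    return shared_out + moe_ffn(x)                       # routed experts computed separately

The TP/EP split in the deleted code only changes which _moe_ffn variant fills the moe_ffn slot; the shared-expert half is identical in both modes.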
From a93dcb6ba32d0c868e2428e3e299c78ff64aafcf Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Sat, 9 May 2026 06:57:27 +0000
Subject: [PATCH 173/180] slim code

---
 lightllm/common/basemodel/basemodel.py        |  1 -
 lightllm/models/qwen2_5_vl/qwen2_5_visual.py  |  4 +--
 .../layer_weights/transformer_layer_weight.py | 36 ++-----------------
 .../mode_backend/chunked_prefill/impl.py      |  1 -
 lightllm/utils/device_utils.py                |  3 +-
 lightllm/utils/envs_utils.py                  |  2 --
 lightllm/utils/kv_cache_utils.py              | 22 ------------
 7 files changed, 5 insertions(+), 64 deletions(-)

diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 241ff07780..219343e2c6 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -12,7 +12,6 @@
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 from lightllm.common.basemodel.infer_struct import InferStateInfo
-from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
 from lightllm.common.kv_cache_mem_manager import MemoryManager
 from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
 from lightllm.common.req_manager import ReqManager
diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
index 825a985b46..7156a5ce23 100644
--- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
+++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -227,14 +227,14 @@ def _init_datatype(self):
     def rot_pos_emb(self, grid_thw):
         pos_ids = []
         s = self.spatial_merge_size
-        for t, h, w in grid_thw:
+        for _, h, w in grid_thw:
             pos_shape = (h // s, s, w // s, s)
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
             hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
             wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
         cos_full, sin_full = self.rotary_pos_emb(max_grid_size)
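The qwen2_5_vl hunk above drops the per-frame .repeat(t, 1), so one set of (h, w) rotary position ids is emitted per grid. The index construction itself is easier to follow in isolation; a sketch for a single grid with spatial merge size s (hypothetical standalone function):

import torch

def grid_pos_ids(h: int, w: int, s: int) -> torch.Tensor:
    # row / column index of every patch in the h x w grid
    hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
    wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
    # regroup into s x s merge windows so patches merged together are adjacent
    shape = (h // s, s, w // s, s)
    hpos = hpos.reshape(shape).permute(0, 2, 1, 3).flatten()
    wpos = wpos.reshape(shape).permute(0, 2, 1, 3).flatten()
    return torch.stack([hpos, wpos], dim=-1)  # [h * w, 2]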
diff --git a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py
index 204700fa34..44425e7e10 100644
--- a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py
@@ -12,7 +12,6 @@ def load_hf_weights(self, weights):
     def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_size: int):
         layer_prefix = f"model.layers.{layer_num}."
         keys = list(weights.keys())
-        num_experts = 0

         for k in keys:
             if not k.startswith(layer_prefix):
@@ -20,39 +19,9 @@ def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_s

             if "mlp.experts.gate_up_proj" in k:
                 fused_weight = weights.pop(k)  # [num_experts, 2*inter_size, hidden_size]
-                num_experts = fused_weight.shape[0]
                 prefix = k.rsplit(".gate_up_proj", 1)[0]
                 gate_weight = fused_weight[:, :moe_intermediate_size, :]
                 up_weight = fused_weight[:, moe_intermediate_size:, :]
-
-                for expert_idx in range(num_experts):
-                    weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx]
-                    weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx]
-
-            if "mlp.experts.gate_proj" in k:
-                gate_weight = weights.pop(k)  # [num_experts, hidden_size, inter_size]
-                num_experts = gate_weight.shape[0]
-
-                prefix = k.rsplit(".gate_proj", 1)[0]
-
-                for expert_idx in range(num_experts):
-                    weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx]
-
-            elif "mlp.experts.up_proj" in k:
-                up_weight = weights.pop(k)  # [num_experts, hidden_size, inter_size]
-                num_experts = up_weight.shape[0]
-
-                prefix = k.rsplit(".up_proj", 1)[0]
-
-                for expert_idx in range(num_experts):
-                    weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx]
-
-            elif "mlp.experts.down_proj" in k:
-                down_weight = weights.pop(k)  # [num_experts, hidden_size, inter_size]
-                num_experts = down_weight.shape[0]
-
-                prefix = k.rsplit(".down_proj", 1)[0]
-
-                for expert_idx in range(num_experts):
-                    weights[f"{prefix}.{expert_idx}.down_proj.weight"] = down_weight[expert_idx]
+                weights[f"{prefix}.gate_proj"] = gate_weight
+                weights[f"{prefix}.up_proj"] = up_weight
diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
index 9a425bd10f..5875dcca0c 100644
--- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
@@ -24,7 +24,6 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.utils.envs_utils import get_env_start_args
-from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache
 from .control_state import ControlState

 logger = init_logger(__name__)
diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py
index 68afd0b613..43b10ec88b 100644
--- a/lightllm/utils/device_utils.py
+++ b/lightllm/utils/device_utils.py
@@ -112,8 +112,7 @@ def get_current_device_name():
         gpu_name = gpu_name.replace(" ", "_")
         return gpu_name
     else:
-        return "unknown"  # need fix
-        # raise RuntimeError("No GPU available")
+        return None


 @lru_cache(maxsize=None)
diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py
index 151db5e01a..350507e897 100644
--- a/lightllm/utils/envs_utils.py
+++ b/lightllm/utils/envs_utils.py
@@ -23,8 +23,6 @@ def set_unique_server_name(args):
 @lru_cache(maxsize=None)
 def get_unique_server_name():
     service_uni_name = os.getenv("LIGHTLLM_UNIQUE_SERVICE_NAME_ID")
-    assert "None" not in service_uni_name, "service_uni_name is not set"
-
     return service_uni_name
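The qwen3_5_moe hunk above replaces the per-expert unpacking loops with two whole-tensor slices. The slicing is trivial once the fused layout is fixed; a sketch assuming fused gate_up weights shaped [num_experts, 2 * inter_size, hidden_size] with the gate rows stacked first:

import torch

def split_gate_up(fused: torch.Tensor, inter_size: int):
    # fused: [num_experts, 2 * inter_size, hidden_size]
    assert fused.shape[1] == 2 * inter_size
    gate = fused[:, :inter_size, :]   # [num_experts, inter_size, hidden_size]
    up = fused[:, inter_size:, :]
    return gate, up

The new loader then registers the two stacked tensors directly under {prefix}.gate_proj and {prefix}.up_proj instead of materializing one entry per expert.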
diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py
index 51996618ee..55284b27f3 100644
--- a/lightllm/utils/kv_cache_utils.py
+++ b/lightllm/utils/kv_cache_utils.py
@@ -114,28 +114,6 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta":
             scale_head_dim=get_head_dim(args.model_dir) // 8,
             scale_data_type=get_llm_data_type(),
         )
-    elif mem_manager_class is PPLINT8KVMemoryManager:
-        cpu_cache_meta = CpuKVCacheMeta(
-            page_num=0,
-            token_page_size=args.cpu_cache_token_page_size,
-            layer_num=get_layer_num(args.model_dir),
-            num_heads=get_num_key_value_heads(args.model_dir) * 2,
-            head_dim=get_head_dim(args.model_dir) * 2,
-            data_type=get_llm_data_type(),
-            scale_head_dim=0,
-            scale_data_type=get_llm_data_type(),
-        )
-    elif mem_manager_class is MemoryManager:
-        cpu_cache_meta = CpuKVCacheMeta(
-            page_num=0,
-            token_page_size=args.cpu_cache_token_page_size,
-            layer_num=get_layer_num(args.model_dir),
-            num_heads=get_num_key_value_heads(args.model_dir) * 2,
-            head_dim=get_head_dim(args.model_dir),
-            data_type=get_llm_data_type(),
-            scale_head_dim=0,
-            scale_data_type=get_llm_data_type(),
-        )
     else:
         logger.error(f"not support mem manager: {mem_manager_class} for cpu kv cache")
         raise Exception(f"not support mem manager: {mem_manager_class} for cpu kv cache")
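With the fallback branches gone, calcu_cpu_cache_meta only describes layouts it actually supports. The byte math behind such a meta record is worth seeing once; a hedged sketch with hypothetical field names mirroring CpuKVCacheMeta:

from dataclasses import dataclass

@dataclass
class CacheMeta:
    token_page_size: int  # tokens per CPU page
    layer_num: int
    num_heads: int        # k heads + v heads combined, as in the hunks above
    head_dim: int
    dtype_bytes: int      # e.g. 2 for fp16/bf16

    def page_bytes(self) -> int:
        # one page holds token_page_size tokens across all layers
        return (self.token_page_size * self.layer_num
                * self.num_heads * self.head_dim * self.dtype_bytes)

# e.g. 32 layers, 8 KV heads (k+v -> 16), head_dim 128, fp16, 256-token pages:
# 256 * 32 * 16 * 128 * 2 bytes = 32 MiB per page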
From ccc88322107ed45dc8f50dd6a8be51bf06ec327e Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Sat, 9 May 2026 07:07:14 +0000
Subject: [PATCH 174/180] slim radix cache

---
 .../server/router/dynamic_prompt/radix_cache.py | 16 +++-------------
 .../req_queue/chunked_prefill/beam_impl.py      |  2 +-
 2 files changed, 4 insertions(+), 14 deletions(-)

diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index e866ae4222..5a3352b184 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -31,12 +31,6 @@ def __init__(self):
         self.node_value_len = 0
         self.node_prefix_total_len = 0

-        # Used by hybrid attention models (e.g., Qwen3Next) to track
-        # a per-request buffer_idx alongside the token-level KV cache.
-        # Pure attention models keep buffer_idx as None.
-        self.buffer_idx = None
-        self.buffer_time = time_gen.generate_time_id()
-
     def get_compare_key(self):
         return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id)

@@ -84,9 +78,6 @@ def remove_child(self, child_node: "TreeNode"):
     def update_time(self):
         self.time_id = time_gen.generate_time_id()

-    def update_buffer_time(self):
-        self.buffer_time = time_gen.generate_time_id()
-
     def is_leaf(self):
         return len(self.children) == 0

@@ -112,10 +103,10 @@ class RadixCache:
     unique_name 主要用于解决单机,多实列部署时的shm冲突
     """

-    def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager=None):
+    def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None):
         from lightllm.common.kv_cache_mem_manager import MemoryManager

-        self.mem_manager: MemoryManager = kv_cache_mem_manager
+        self.mem_manager: MemoryManager = mem_manager
         self._key_dtype = torch.int64
         self._value_dtype = torch.int64

@@ -368,7 +359,6 @@ def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]:
             or parent_node.ref_counter != 0
             or len(parent_node.children) != 1
             or child_node.ref_counter != 0
-            or parent_node.buffer_idx is not None
         ):
             return None

@@ -523,7 +513,7 @@ def _print_helper(self, node: TreeNode, indent):
             " " * indent,
             f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \
               time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \
-              node_value_len: {node.node_value_len} buffer_idx: {node.buffer_idx}",
+              node_value_len: {node.node_value_len}",
         )
         for _, child in node.children.items():
             self._print_helper(child, indent=indent + 2)
diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
index 1e6340a8bb..23f94de704 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
@@ -127,7 +127,7 @@ def generate_new_batch(self, current_batch: Batch):
             ok_insert, new_batch_first_router_need_tokens = self._can_add_new_group_reqs(
                 cur_group_reqs, is_busy, new_batch_first_router_need_tokens
             )
-            if ok_insert and False:
+            if ok_insert:
                 can_run_list.extend(cur_group_reqs)
                 new_batch = None
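The hunks above strip the hybrid-attention buffer bookkeeping but leave the token-keyed tree itself untouched. For orientation, the lookup that tree exists to serve is a greedy prefix walk; a simplified sketch (hypothetical minimal nodes, ignoring the partial-key split the real match_prefix performs):

def match_prefix(root, tokens):
    # walk down, consuming fully matched key segments greedily
    node, matched = root, []
    while tokens:
        child = node.children.get(tokens[0])        # children keyed by first token
        if child is None or tokens[:len(child.key)] != child.key:
            break                                   # real impl splits partial matches here
        matched.extend(child.values)                # cached mem indexes for this segment
        tokens = tokens[len(child.key):]
        node = child
    return node, matched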
From 11ea37a1efaacd9f6457953126bcbf99345c4ec6 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Sat, 9 May 2026 07:16:53 +0000
Subject: [PATCH 175/180] slim radix cache

---
 .../router/dynamic_prompt/radix_cache.py      | 22 +------------------
 1 file changed, 1 insertion(+), 21 deletions(-)

diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index 5a3352b184..955e808600 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -426,27 +426,7 @@ def clear_tree_nodes(self):
         return

     def flush_cache(self):
-        nodes_to_clear = collections.deque(self.root_node.children.values())
-        self.root_node.children.clear()
-        while nodes_to_clear:
-            node = nodes_to_clear.popleft()
-            nodes_to_clear.extend(node.children.values())
-            node.parent = None
-            node.children.clear()
-
-        self.root_node.token_id_key[:] = 0
-        self.root_node.token_mem_index_value[:] = 0
-        self.root_node.ref_counter = 1  # 保持为1,确保不会被evict
-        self.root_node.time_id = time_gen.generate_time_id()
-        self.root_node.node_value_len = 0
-        self.root_node.node_prefix_total_len = 0
-
-        self.evict_tree_set.clear()
-        self.evict_tree_set.add(self.root_node)
-
-        self.tree_total_tokens_num.arr[0] = 0
-        self.refed_tokens_num.arr[0] = 0
-
+        self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num)
         return

     def dec_node_ref_counter(self, node: TreeNode):
From f446e5ba9f90f22a5ba717bc53ba7c3ade09ac1f Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Sat, 9 May 2026 07:37:29 +0000
Subject: [PATCH 176/180] slim code

---
 lightllm/server/core/objs/shm_array.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/lightllm/server/core/objs/shm_array.py b/lightllm/server/core/objs/shm_array.py
index 1bf20535ad..c5ad512c6b 100644
--- a/lightllm/server/core/objs/shm_array.py
+++ b/lightllm/server/core/objs/shm_array.py
@@ -26,19 +26,6 @@ def link_shm(self):
         self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf)
         return

-    def link_shm_partial(self):
-        """Link to an existing SHM that may be larger than the needed shape."""
-        self.shm = create_or_link_shm(self.name, -1, force_mode="link")
-        assert self.shm.size >= self.dest_size, f"SHM {self.name} too small: need {self.dest_size}, got {self.shm.size}"
-        self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf)
-
-    def detach_shm(self):
-        """Close handle without unlinking (SHM persists for reuse)."""
-        if self.shm is not None:
-            self.shm.close()
-            self.shm = None
-            self.arr = None
-
     def close_shm(self):
         if self.shm is not None:
             self.shm.close()
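ShmArray now keeps only the create/link/close paths. Underneath it is the standard multiprocessing.shared_memory pattern: a numpy view over a named segment, closed in every process and unlinked exactly once by the owner. A minimal sketch:

import numpy as np
from multiprocessing import shared_memory

shape, dtype = (1024,), np.int64
size = int(np.prod(shape)) * np.dtype(dtype).itemsize

# owner: create the named segment and view it as an ndarray
shm = shared_memory.SharedMemory(name="demo_arr", create=True, size=size)
arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
arr[:] = 0

# another process links to the same segment by name:
#   shm = shared_memory.SharedMemory(name="demo_arr")

arr = None    # drop the view before tearing down the mapping
shm.close()   # per-process handle
shm.unlink()  # owner only: actually destroy the segment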
From 998020a18dbfadbd1b7d5bf4d180aec3b4030648 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Sat, 9 May 2026 08:39:01 +0000
Subject: [PATCH 177/180] remove unused code

---
 .../mode_backend/chunked_prefill/impl.py         | 17 -----------------
 .../model_infer/mode_backend/dp_backend/impl.py  | 14 --------------
 2 files changed, 31 deletions(-)

diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
index 5875dcca0c..e068c00c76 100644
--- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
@@ -272,23 +272,6 @@ def decode_mtp(
             gpu_tensor=mtp_accept_len,
         )

-        # Copy accepted buffer states back to buffer[0] for MTP
-        # Only copy when accept_len > 1 (accept_len == 1 means buffer[0] is already correct)
-        mask = mtp_accept_len > 1
-        if mask.sum() > 0:
-            actual_req_idxes = model_input.b_req_idx[b_req_mtp_start_loc[mask]]
-            # Source: the accepted buffer (at index accept_len - 1)
-            src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[
-                actual_req_idxes, mtp_accept_len[mask] - 1
-            ]
-            # Destination: buffer[0] for each request
-            dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0]
-            # P2P copy both conv_states and ssm_states
-            if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_state_buffers"):
-                g_infer_context.req_manager.buffer_mem_manager.copy_state_buffers(
-                    src_buffer_indexes, dst_buffer_indexes
-                )
-
         verify_event = torch.cuda.Event()
         verify_event.record()

diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
index 243b8392a2..c9484dba6f 100644
--- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
@@ -479,20 +479,6 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
             gpu_tensor=mtp_accept_len,
         )

-        # Copy accepted buffer states back to buffer[0] for MTP
-        # Only copy when accept_len > 1
-        mask = mtp_accept_len > 1
-        if mask.sum() > 0:
-            actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]]
-            src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[
-                actual_req_idxes, mtp_accept_len[mask] - 1
-            ]
-            dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0]
-            if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"):
-                g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p(
-                    src_buffer_indexes, dst_buffer_indexes
-                )
-
         verify_event = torch.cuda.Event()
         verify_event.record()

From 5a745e5585b85ac95cfa4037e5ac01ae94077247 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Sat, 9 May 2026 09:14:49 +0000
Subject: [PATCH 178/180] fix

---
 lightllm/server/audioserver/manager.py          | 7 +++++--
 lightllm/server/multi_level_kv_cache/manager.py | 3 +++
 lightllm/server/visualserver/manager.py         | 5 ++++-
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/lightllm/server/audioserver/manager.py b/lightllm/server/audioserver/manager.py
index c1cf21cc7a..e78d56fc4f 100644
--- a/lightllm/server/audioserver/manager.py
+++ b/lightllm/server/audioserver/manager.py
@@ -12,7 +12,7 @@
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

 from lightllm.utils.log_utils import init_logger
-from lightllm.server.io_struct import GenerateReqIndex
+from lightllm.server.io_struct import BaseReq, GenerateReqIndex
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from lightllm.server.multimodal_params import AudioItem
 from .model_infer import start_model_process, AudioModelRpcClient
@@ -153,13 +153,16 @@ async def infer_audios(self, dp_index: int, audios, events):
     async def loop_for_netio_req(self):
         try:
             while True:
-                recv_req: GenerateReqIndex = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj)
+                recv_req: BaseReq = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj)
                 if isinstance(recv_req, GenerateReqIndex):
                     logger.info(
                         f"audio recv req id {recv_req.group_req_id} "
                         f"audio count {len(recv_req.multimodal_params.audios)}"
                     )
                     asyncio.create_task(self.handle_group_indexes(group_req_indexes=recv_req))
+                elif isinstance(recv_req, BaseReq):
+                    # Control-type BaseReq (e.g. from RL) passes through to the next module; the router handles it in the end
+                    self.send_to_next_module.send_pyobj(recv_req, protocol=pickle.HIGHEST_PROTOCOL)
                 else:
                     assert False, f"Error Req Inf {recv_req}"
         except Exception as e:
diff --git a/lightllm/server/multi_level_kv_cache/manager.py b/lightllm/server/multi_level_kv_cache/manager.py
index f38c60211e..b771ac1a01 100644
--- a/lightllm/server/multi_level_kv_cache/manager.py
+++ b/lightllm/server/multi_level_kv_cache/manager.py
@@ -220,6 +220,9 @@ def recv_loop(self):
                 for _ in range(recv_max_count):
                     recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
                     if not isinstance(recv_obj, GenerateReqIndex):
+                        # Pass control-type BaseReq (e.g. from RL) through to the router instead of silently dropping it here
+                        if isinstance(recv_obj, BaseReq):
+                            self.send_to_router.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL)
                         continue
                     recv_objs.append(recv_obj)
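Both manager fixes in this patch have the same shape: drain the socket, dispatch GenerateReqIndex locally, and forward any other BaseReq downstream instead of dropping it. A condensed sketch of the non-blocking variant (hypothetical function and socket names; recv_pyobj and zmq.Again are real pyzmq API):

import pickle
import zmq
from lightllm.server.io_struct import GenerateReqIndex

def drain_and_dispatch(recv_socket, forward_socket, handle, max_count=64):
    for _ in range(max_count):
        try:
            obj = recv_socket.recv_pyobj(zmq.NOBLOCK)
        except zmq.Again:
            break                   # queue drained
        if isinstance(obj, GenerateReqIndex):
            handle(obj)             # normal generate path
        else:                       # control requests pass through to the router
            forward_socket.send_pyobj(obj, protocol=pickle.HIGHEST_PROTOCOL)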
diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py
index 3ad191d1ff..6b013c4c63 100644
--- a/lightllm/server/visualserver/manager.py
+++ b/lightllm/server/visualserver/manager.py
@@ -174,13 +174,16 @@ async def infer_images(self, dp_index: int, images, events):
     async def loop_for_netio_req(self):
         try:
             while True:
-                recv_req: GenerateReqIndex = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj)
+                recv_req: BaseReq = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj)
                 if isinstance(recv_req, GenerateReqIndex):
                     logger.info(
                         f"visual recv req id {recv_req.group_req_id} "
                         f"img count {len(recv_req.multimodal_params.images)}"
                     )
                     asyncio.create_task(self.handle_group_indexes(group_req_indexes=recv_req))
+                elif isinstance(recv_req, BaseReq):
+                    # Control-type BaseReq (e.g. from RL) passes through to the next module; the router handles it in the end
+                    self.send_to_next_module.send_pyobj(recv_req, protocol=pickle.HIGHEST_PROTOCOL)
                 else:
                     assert False, f"Error Req Inf {recv_req}"
         except Exception as e:
From 90ed5560eb44b6346f341b74c2835bf3ba45a0d6 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Sat, 9 May 2026 09:29:54 +0000
Subject: [PATCH 179/180] lazy init cache dir

---
 lightllm/common/triton_utils/autotuner.py | 31 +++++++++++++++++------
 1 file changed, 23 insertions(+), 8 deletions(-)

diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py
index c62a2572ff..60c12e4c07 100644
--- a/lightllm/common/triton_utils/autotuner.py
+++ b/lightllm/common/triton_utils/autotuner.py
@@ -106,14 +106,10 @@ def __init__(
         self.configs_gen_func = configs_gen_func

         self.kernel_name = kernel_name
-        self.cache_dir = os.path.join(
-            Path(__file__).parent,
-            "autotune_kernel_configs",
-            get_triton_version(),
-            get_current_device_name(),
-            self.kernel_name,
-        )
-        os.makedirs(self.cache_dir, exist_ok=True)
+        # cache_dir depends on get_current_device_name(), which requires torch.cuda.is_available().
+        # Make it lazy so that CPU-only processes (e.g. a Ray driver / verl rollout replica
+        # entrypoint) do not hit a TypeError at import time.
+        self._cache_dir: Optional[str] = None
         self.fn = fn
         self.static_key_func = static_key_func
         self.run_key_func = run_key_func
@@ -209,6 +205,25 @@ def __call__(self, *args, **kwargs):

         return self.fn(*args, **kwargs)

+    @property
+    def cache_dir(self) -> str:
+        if self._cache_dir is None:
+            device_name = get_current_device_name()
+            if device_name is None:
+                raise RuntimeError(
+                    f"Autotuner for kernel {self.kernel_name} requires a visible CUDA/MUSA device "
+                    f"to resolve its cache directory, but torch.cuda.is_available() is False."
+                )
+            self._cache_dir = os.path.join(
+                Path(__file__).parent,
+                "autotune_kernel_configs",
+                get_triton_version(),
+                device_name,
+                self.kernel_name,
+            )
+            os.makedirs(self._cache_dir, exist_ok=True)
+        return self._cache_dir
+
     def _try_load_cache(self, static_key):
         if static_key in self.cached_configs:
             return False

From eb42e5bfafb5d61a6108c2366ec93022c3f1f642 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Sat, 9 May 2026 09:57:23 +0000
Subject: [PATCH 180/180] fix linear flush_cache

---
 .../server/router/dynamic_prompt/linear_att_radix_cache.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py
index 73c6dba54d..082f6bec08 100644
--- a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py
@@ -467,6 +467,10 @@ def clear_tree_nodes(self):
         self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num)
         return

+    def flush_cache(self):
+        self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num)
+        return
+
     def deref_to_first_big_page_node(self, node: LinearAttPagedTreeNode) -> Optional[LinearAttPagedTreeNode]:
         assert not node.is_big_page_node()
         iter_node = node