diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst
index 04069edec..1472f0fc4 100644
--- a/docs/CN/source/tutorial/api_server_args.rst
+++ b/docs/CN/source/tutorial/api_server_args.rst
@@ -133,7 +133,10 @@ PD 分离模式参数
 
 .. option:: --max_req_total_len
 
-    请求输入长度 + 请求输出长度的最大值,默认为 ``16384``
+    请求输入长度 + 请求输出长度的最大值。若未显式设置,将从模型配置自动推导,
+    若推导失败则回退到 ``16384``。
+    对于部分 RoPE 类型(如 ``yarn/dynamic/su/llama3``),推导不会直接用 ``rope_scaling.factor``
+    去乘以 ``max_position_embeddings``,以避免过度估算最大长度。
 
 .. option:: --eos_id
 
@@ -472,14 +475,6 @@ PD 分离模式参数
 
     使用奖励模型
 
-.. option:: --long_truncation_mode
-
-    当 input_token_len + max_new_tokens > max_req_total_len 时的处理方式,可选值:
-
-    * ``None``: 抛出异常(默认)
-    * ``head``: 移除一些头部 token 使 input_token_len + max_new_tokens <= max_req_total_len
-    * ``center``: 移除中心位置的一些 token 使 input_token_len + max_new_tokens <= max_req_total_len
-
 .. option:: --use_tgi_api
 
     使用 tgi 输入和输出格式
diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst
index 1b95ceef5..66b01c2ff 100644
--- a/docs/EN/source/tutorial/api_server_args.rst
+++ b/docs/EN/source/tutorial/api_server_args.rst
@@ -133,7 +133,10 @@ Memory and Batch Processing Parameters
 
 .. option:: --max_req_total_len
 
-    Maximum value of request input length + request output length, default is ``16384``
+    Maximum value of request input length + request output length. If not set, it will be
+    automatically derived from model config.json and fall back to ``16384`` if derivation fails.
+    For some RoPE types (like ``yarn/dynamic/su/llama3``), the derivation does not multiply
+    ``rope_scaling.factor`` by ``max_position_embeddings`` to avoid over-estimating the max length.
 
 .. option:: --eos_id
 
@@ -472,14 +475,6 @@ Sampling and Generation Parameters
 
     Use reward model
 
-.. option:: --long_truncation_mode
-
-    How to handle when input_token_len + max_new_tokens > max_req_total_len, optional values:
-
-    * ``None``: Throw exception (default)
-    * ``head``: Remove some head tokens to make input_token_len + max_new_tokens <= max_req_total_len
-    * ``center``: Remove some tokens at the center position to make input_token_len + max_new_tokens <= max_req_total_len
-
 .. option:: --use_tgi_api
 
     Use tgi input and output format
diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 89d90f3c1..a980ef29a 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -212,7 +212,9 @@ def _init_kv_move_buffer(self):
 
     def _check_mem_size(self):
         self.max_total_token_num = self.mem_manager.size
-        assert self.max_seq_length <= self.max_total_token_num
+        assert (
+            self.max_total_token_num > self.batch_max_tokens
+        ), "max_total_token_num must be greater than batch_max_tokens"
         return
 
     def _init_req_manager(self):
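
For reference, the derivation rules documented above can be mirrored in a few lines of standalone Python. The helper below is an illustrative sketch of the rule set, not the patched implementation itself, and the config dicts are hypothetical examples rather than real model configs:

# Simplified stand-in for the derivation added to lightllm/utils/config_utils.py (illustration only).
def derive_max_len(cfg: dict) -> int:
    if cfg.get("max_sequence_length"):
        return int(cfg["max_sequence_length"])
    rope = cfg.get("rope_scaling") or {}
    rope_type = str(rope.get("rope_type", rope.get("type", ""))).lower()
    factor = float(rope.get("factor", 1.0) or 1.0)
    max_pos = int(cfg["max_position_embeddings"])
    # For yarn/dynamic/su/llama3 the factor is not applied, to avoid over-estimating the limit.
    if rope_type in {"yarn", "dynamic", "su", "llama3"}:
        return max_pos
    return int(max_pos * factor)

# Hypothetical config.json contents:
print(derive_max_len({"max_position_embeddings": 32768}))                    # 32768
print(derive_max_len({"max_position_embeddings": 8192,
                      "rope_scaling": {"type": "linear", "factor": 4.0}}))   # 32768
print(derive_max_len({"max_position_embeddings": 32768,
                      "rope_scaling": {"rope_type": "yarn", "factor": 4.0}}))  # 32768, factor skipped
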
" "In PD (Prefill-Decode) mode, this value must be synchronized across the " "PD master, prefill, and decode nodes.", ) @@ -457,16 +459,6 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument("--use_reward_model", action="store_true", help="use reward model") - parser.add_argument( - "--long_truncation_mode", - type=str, - choices=[None, "head", "center"], - default=None, - help="""use to select the handle way when input_token_len + max_new_tokens > max_req_total_len. - None : raise Exception - head : remove some head tokens to make input_token_len + max_new_tokens <= max_req_total_len - center : remove some tokens in center loc to make input_token_len + max_new_tokens <= max_req_total_len""", - ) parser.add_argument("--use_tgi_api", action="store_true", help="use tgi input and ouput format") parser.add_argument( "--health_monitor", action="store_true", help="check the health of service and restart when error" diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 6c3f8b3fe..c106ca1cd 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -357,7 +357,8 @@ async def anthropic_messages(raw_request: Request) -> Response: @app.get("/v1/models", response_model=ModelListResponse) async def get_models(raw_request: Request): model_name = g_objs.args.model_name - max_model_len = g_objs.args.max_req_total_len + max_model_len = g_objs.httpserver_manager.get_real_supported_max_req_total_len() + if model_name == "default_model_name" and g_objs.args.model_dir: model_name = os.path.basename(g_objs.args.model_dir.rstrip("/")) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index a687971ed..8c6af128c 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -17,7 +17,12 @@ from lightllm.utils.multinode_utils import send_and_receive_node_ip from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size -from lightllm.utils.config_utils import has_audio_module, has_vision_module, is_linear_att_mixed_model +from lightllm.utils.config_utils import ( + has_audio_module, + has_vision_module, + is_linear_att_mixed_model, + auto_set_max_req_total_len, +) from lightllm.utils.dist_check_utils import auto_configure_allreduce_flags_from_args logger = init_logger(__name__) @@ -69,6 +74,7 @@ def normal_or_p_d_start(args): args: StartArgs = args + auto_set_max_req_total_len(args) set_unique_server_name(args) if args.enable_mps: @@ -124,10 +130,13 @@ def normal_or_p_d_start(args): args.running_max_req_size = 3 args.batch_max_tokens = 2048 args.chunked_prefill_size = 1024 - args.mem_fraction = 0.85 + if args.mem_fraction > 0.82: + args.mem_fraction = 0.82 + args.graph_max_batch_size = 32 logger.info( f"performance_mode is personal, set running_max_req_size to 3," - f"batch_max_tokens to 2048, chunked_prefill_size to 1024, mem_fraction to 0.85" + f"batch_max_tokens to 2048, chunked_prefill_size to 1024, mem_fraction to 0.82," + f"graph_max_batch_size to 32" ) if not args.disable_shm_warning: @@ -518,6 +527,8 @@ def pd_master_start(args): if args.run_mode != "pd_master": return + auto_set_max_req_total_len(args) + # when use config_server to support multi pd_master node, we # need generate unique node id for each pd_master node. # otherwise, we use the 0 for single pd_master node. 
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index bff5adce2..954daa50f 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -65,7 +65,8 @@ class StartArgs:
     dp: int = field(default=1)
     nnodes: int = field(default=1)
     node_rank: int = field(default=0)
-    max_req_total_len: int = field(default=2048 + 1024)
+    # If None, will be automatically derived from model config in `lightllm.server.api_start`.
+    max_req_total_len: Optional[int] = field(default=None)
     nccl_host: str = field(default="127.0.0.1")
     nccl_port: int = field(default=None)
     use_config_server_to_init_nccl: bool = field(default=False)
@@ -100,7 +101,6 @@ class StartArgs:
     )
     return_all_prompt_logprobs: bool = field(default=False)
     use_reward_model: bool = field(default=False)
-    long_truncation_mode: Optional[str] = field(default=None, metadata={"choices": [None, "head", "center"]})
     use_tgi_api: bool = field(default=False)
     health_monitor: bool = field(default=False)
     metric_gateway: Optional[str] = field(default=None)
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index c772e97d1..4c049f77c 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -122,6 +122,12 @@ def __init__(
         # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend.
         self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
         self.latest_success_infer_time_mark.set_value(int(time.time()))
+
+        # 用于记录真实的 --max_total_token_num 参数,当这个参数在启动参数中没有设置的时候,其是在推理进程中被分析出来的。
+        # 这个时候如果 --max_req_total_len > --max_total_token_num,httpserver 一旦放过一些非法的输入进入后续的模块,
+        # 可能会触发整个系统崩溃,所以 httpserver 需要知道真实的 max_total_token_num 数据,用于提前拦截非法请求。
+        # router 进程会在启动后向这个共享内存写入正确的 max_total_token_num 参数,用于后续的请求控制。
+        self.shm_max_total_token_num = SharedInt(f"{get_unique_server_name()}_shm_max_total_token_num")
         return
 
     def _log_stage_timing(self, group_request_id: int, start_time: float, stage: str, **kwargs):
@@ -542,37 +548,34 @@ async def _encode(
             raise ValueError(f"prompt format error, get type{type(prompt)}")
         return
 
+    def get_real_supported_max_req_total_len(self):
+        # 得到系统真正能支持的最大长度,同时受到启动参数中模型支持长度的限制,也受到 token 容量的限制。
+        return min(self.shm_max_total_token_num.get_value() - 36, self.max_req_total_len)
+
     async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params: SamplingParams):
         if not prompt_ids:
             raise ValueError("prompt_ids is empty")
         prompt_tokens = len(prompt_ids)
-        if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len:
-            # use long_truncation_mode to truncate long input len req.
-            if self.args.long_truncation_mode is None:
-                # 修改默认逻辑,如果 prompt_tokens + max_new_tokens 长度超过总的允许长度,则将
-                # 修改 max_new_tokens 的值,使其满足合法约束。
-                new_max_new_tokens = self.max_req_total_len - prompt_tokens
-                if new_max_new_tokens > 0:
-                    logger.debug(
-                        f"the input prompt token len {prompt_tokens} + max_new_tokens"
-                        f"{sampling_params.max_new_tokens} > {self.max_req_total_len},"
-                        f"so change max_new_tokens to {new_max_new_tokens}"
-                    )
-                    sampling_params.max_new_tokens = new_max_new_tokens
-                else:
-                    raise ValueError(
-                        f"the input prompt token len {prompt_tokens} + max_new_tokens \
-                        {sampling_params.max_new_tokens} > {self.max_req_total_len}"
-                    )
-            elif self.args.long_truncation_mode == "head":
-                prompt_ids = prompt_ids[-(self.max_req_total_len - sampling_params.max_new_tokens) :]
-            elif self.args.long_truncation_mode == "center":
-                req_input_len = self.max_req_total_len - sampling_params.max_new_tokens
-                prompt_ids = prompt_ids[0 : req_input_len // 2] + prompt_ids[-(req_input_len - req_input_len // 2) :]
-                prompt_tokens = len(prompt_ids)
-                assert prompt_tokens == req_input_len
+        # 这里 -36 是保留一些不可预知的边界余量,防止系统出错
+        real_supported_max_req_total_len = self.get_real_supported_max_req_total_len()
+
+        if prompt_tokens + sampling_params.max_new_tokens > real_supported_max_req_total_len:
+
+            # 修改默认逻辑,如果 prompt_tokens + max_new_tokens 长度超过总的允许长度,则将
+            # 修改 max_new_tokens 的值,使其满足合法约束。
+            new_max_new_tokens = real_supported_max_req_total_len - prompt_tokens
+            if new_max_new_tokens > 0:
+                logger.debug(
+                    f"the input prompt token len {prompt_tokens} + max_new_tokens"
+                    f"{sampling_params.max_new_tokens} > {real_supported_max_req_total_len},"
+                    f"so change max_new_tokens to {new_max_new_tokens}"
+                )
+                sampling_params.max_new_tokens = new_max_new_tokens
             else:
-                assert False, "error args"
+                raise ValueError(
+                    f"the input prompt token len {prompt_tokens} + max_new_tokens \
+                    {sampling_params.max_new_tokens} > {real_supported_max_req_total_len}"
+                )
 
         # last repaired
         req_total_len = len(prompt_ids) + sampling_params.max_new_tokens
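
The net effect of the rewritten check: compute the real supported limit as `min(max_total_token_num - 36, max_req_total_len)`, shrink `max_new_tokens` to fit instead of truncating the prompt, and reject only prompts that already exceed the limit on their own. A self-contained sketch with made-up numbers:

def check_and_repair(prompt_tokens: int, max_new_tokens: int,
                     shm_max_total_token_num: int, max_req_total_len: int) -> int:
    # Mirrors the repaired logic of HttpServerManager._check_and_repair_length (sketch only).
    # The -36 keeps a small safety margin below the token capacity measured by the router.
    real_limit = min(shm_max_total_token_num - 36, max_req_total_len)
    if prompt_tokens + max_new_tokens > real_limit:
        new_max_new_tokens = real_limit - prompt_tokens
        if new_max_new_tokens <= 0:
            raise ValueError(f"prompt len {prompt_tokens} already exceeds limit {real_limit}")
        return new_max_new_tokens  # shrink the generation budget, keep the prompt intact
    return max_new_tokens

# Hypothetical numbers: 20000 tokens of KV capacity, --max_req_total_len 32768.
print(check_and_repair(1000, 512, 20000, 32768))    # 512, the request fits as-is
print(check_and_repair(19000, 4096, 20000, 32768))  # 964 = (20000 - 36) - 19000
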
diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py
index ae3c3d896..af7a1e29f 100644
--- a/lightllm/server/httpserver_for_pd_master/manager.py
+++ b/lightllm/server/httpserver_for_pd_master/manager.py
@@ -48,6 +48,11 @@ def __init__(
         self.per_token_costs = MovingAverage()
         return
 
+    def get_real_supported_max_req_total_len(self):
+        # HttpServerManager.generate 会借用 _check_and_repair_length(self, ...),其中会调用本方法。
+        # PD master 无本地 token 池 shm 计数;上限与启动参数及子节点对齐的 max_req_total_len 一致。
+        return self.max_req_total_len
+
     async def register_pd(self, pd_info_json, websocket):
         self.pd_manager.register_pd(pd_info_json, websocket)
         return
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index f5e0b8df9..24f8da6e6 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -33,6 +33,7 @@
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
 from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from .stats import RouterStatics
 
 
@@ -60,6 +61,8 @@ def __init__(self, args: StartArgs):
         self.is_safe_schedule = args.router_token_ratio == 0.0
         self.load_way = args.load_way
         self.max_total_token_num = args.max_total_token_num
+        # 存储在共享内存中的真实token容量数据
+        self.shm_max_total_token_num = SharedInt(f"{get_unique_server_name()}_shm_max_total_token_num")
         self.shm_req_manager = ShmReqManager()
         # 用共享内存进行共享,router 模块读取进行精确的调度估计
         self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager()
@@ -185,6 +188,10 @@ async def wait_to_model_ready(self):
             assert max(_nums) == min(_nums), "all rank must have same token num"
             self.max_total_token_num = _nums[0]
             self.args.max_total_token_num = self.max_total_token_num
+
+            self.shm_max_total_token_num.set_value(self.max_total_token_num)
+            logger.info(f"set shm_max_total_token_num value to {self.shm_max_total_token_num.get_value()}")
+
         if not self.args.disable_dynamic_prompt_cache:
             self.radix_cache_client = RadixCacheReadOnlyClient(
                 get_unique_server_name(),
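
The router and the HTTP server coordinate through a named shared integer: the router publishes the measured token capacity once the model is ready, and the HTTP server reads it when validating requests and when answering `/v1/models`. The sketch below uses an in-process stand-in for `SharedInt` (the real class is backed by shared memory) and hypothetical capacity numbers:

class SharedIntStub:
    # In-process stand-in for lightllm's SharedInt; only illustrates the
    # write-once / read-many handshake, not the real shared-memory IPC.
    def __init__(self, name: str):
        self.name = name
        self._value = 0

    def set_value(self, value: int):
        self._value = int(value)

    def get_value(self) -> int:
        return self._value

shm = SharedIntStub("lightllm_shm_max_total_token_num")

# Router side, after wait_to_model_ready has collected the per-rank token counts:
measured_max_total_token_num = 180000  # hypothetical capacity reported by the inference ranks
shm.set_value(measured_max_total_token_num)

# HTTP server side, get_real_supported_max_req_total_len:
max_req_total_len = 131072  # hypothetical value from --max_req_total_len / auto-derivation
real_supported = min(shm.get_value() - 36, max_req_total_len)
print(real_supported)  # 131072, the configured limit is the tighter bound here
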
SharedInt(f"{get_unique_server_name()}_shm_max_total_token_num") self.shm_req_manager = ShmReqManager() # 用共享内存进行共享,router 模块读取进行精确的调度估计 self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager() @@ -185,6 +188,10 @@ async def wait_to_model_ready(self): assert max(_nums) == min(_nums), "all rank must have same token num" self.max_total_token_num = _nums[0] self.args.max_total_token_num = self.max_total_token_num + + self.shm_max_total_token_num.set_value(self.max_total_token_num) + logger.info(f"set shm_max_total_token_num value to {self.shm_max_total_token_num.get_value()}") + if not self.args.disable_dynamic_prompt_cache: self.radix_cache_client = RadixCacheReadOnlyClient( get_unique_server_name(), diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 048869f86..ca982ec0f 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -283,7 +283,6 @@ def init_dp_kv_shared(self): self.dp_kv_shared_module = DPKVSharedMoudle( max_req_num=self.args.running_max_req_size, - max_req_seq_len=self.args.max_req_total_len + 8, dp_size_in_node=self.dp_size_in_node, backend=self, ) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py index b73c0476b..2fa2c9cb9 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py @@ -18,12 +18,11 @@ class DPKVSharedMoudle: _KV_LEN_INDEX = 0 _REQ_IDX_INDEX = 1 - def __init__(self, max_req_num: int, max_req_seq_len: int, dp_size_in_node: int, backend): + def __init__(self, max_req_num: int, dp_size_in_node: int, backend): from .impl import DPChunkedPrefillBackend self.backend: DPChunkedPrefillBackend = backend self.max_req_num = max_req_num - self.max_req_seq_len = max_req_seq_len # 0 代表 kv_len, 1 代表 radix_cache_len self.shared_req_infos = ShmArray( diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index 85d5154d8..c64e8a912 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -14,6 +14,138 @@ def get_config_json(model_path: str): return json_obj +def _derive_max_req_total_len_from_model_config(model_dir: str) -> Optional[int]: + """ + Derive `max_req_total_len` from model config.json. + + Keep the derivation aligned with LightLLM's RoPE initialization logic: + - If `max_sequence_length` exists: use it directly. + - Otherwise: use `max_position_embeddings * rope_scaling.factor` (factor defaults to 1.0). 
+ """ + + try: + cfg = get_config_json(model_dir) + except Exception as e: + logger.warning(f"failed to load config.json for max_req_total_len derive: {e}") + return None + + candidates = [cfg] + + llm_cfg = cfg.get("llm_config") + if isinstance(llm_cfg, dict): + candidates.append(llm_cfg) + + text_cfg = cfg.get("text_config") + if isinstance(text_cfg, dict): + candidates.append(text_cfg) + + thinker_cfg = cfg.get("thinker_config") + if isinstance(thinker_cfg, dict): + thinker_text_cfg = thinker_cfg.get("text_config") + if isinstance(thinker_text_cfg, dict): + candidates.append(thinker_text_cfg) + + def _find_key(key: str): + for c in candidates: + if isinstance(c, dict) and key in c and c[key] is not None: + return c.get(key) + return None + + def _find_rope_scaling() -> dict: + rope_scaling = _find_key("rope_scaling") + if rope_scaling is None: + return {} + if isinstance(rope_scaling, dict): + return rope_scaling + return {} + + max_sequence_length = _find_key("max_sequence_length") + if max_sequence_length is not None: + try: + val = int(max_sequence_length) + if val > 0: + return val + except Exception: + return None + + max_position_embeddings = _find_key("max_position_embeddings") + if max_position_embeddings is None: + return None + + rope_scaling = _find_rope_scaling() + rope_type = None + for k in ("rope_type", "type", "__type"): + v = rope_scaling.get(k) + if isinstance(v, str) and v.strip(): + rope_type = v.strip().lower() + break + + # Align with `lightllm/models/llama/model.py` RoPE initialization: + # - `yarn/dynamic/su/llama3`: do NOT multiply by `rope_scaling.factor` for max length. + # - `default/mrope` (and unknown): multiply by factor when present. + no_factor_types = {"yarn", "dynamic", "su", "llama3"} + multiply_factor = True + if rope_type is not None and rope_type in no_factor_types: + multiply_factor = False + + try: + factor_raw = rope_scaling.get("factor", 1.0) + factor = 1.0 if factor_raw is None else float(factor_raw) + except Exception: + factor = 1.0 + + try: + max_pos = float(max_position_embeddings) + val = int(max_pos * factor) if multiply_factor else int(max_pos) + if val > 0: + logger.info( + "auto set max_req_total_len=%s (rope_type=%s,max_position_embeddings=%s,factor=%s, multiply_factor=%s)", + val, + rope_type, + max_position_embeddings, + factor, + multiply_factor, + ) + return val + except Exception: + return None + + return None + + +def auto_set_max_req_total_len(args) -> None: + """ + Ensure `args.max_req_total_len` is an int. + + If the user provides a value, keep it. + If it's None, auto-derive from config.json; fallback to 16384. 
+ """ + + default_fallback = 16384 + if args.max_req_total_len is not None: + return + + model_dir = args.model_dir + if not model_dir: + logger.warning("model_dir is empty; fallback max_req_total_len=16384") + args.max_req_total_len = default_fallback + return + + try: + derived = _derive_max_req_total_len_from_model_config(model_dir) + except Exception as e: + logger.warning(f"failed to derive max_req_total_len from model config: {e}") + derived = None + + if derived is None: + logger.warning(f"cannot derive max_req_total_len from model config; fallback to {default_fallback}") + args.max_req_total_len = default_fallback + return + + args.max_req_total_len = int(derived) + logger.info(f"auto derived max_req_total_len={args.max_req_total_len} from model config") + + def _get_config_llm_keyvalue(model_path: str, key_name: list[str]): config_json = get_config_json(model_path) for key in key_name: diff --git a/test/acc/test_pd_nixl.sh b/test/acc/test_pd_nixl.sh index ba74dec62..8bbd7007e 100644 --- a/test/acc/test_pd_nixl.sh +++ b/test/acc/test_pd_nixl.sh @@ -52,10 +52,17 @@ LOADWORKER=18 CUDA_VISIBLE_DEVICES=2,3 python -m lightllm.server.api_server \ export http_proxy= export https_proxy= $pd_master_ip 为pd_master的ip地址, 测试的时候,自己修改为对应的ip地址 +# warm up export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval \ --model local-completions --model_args \ '{"model":"qwen/qwen3-8b", "base_url":"http://$pd_master_ip:8089/v1/completions", "max_length": 16384, "tokenized_requests": false}' \ ---tasks gsm8k --batch_size 1 --confirm_run_unsafe_code +--tasks gsm8k --batch_size 1 --confirm_run_unsafe_code --limit 1 + +# test +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval \ +--model local-completions --model_args \ +'{"model":"qwen/qwen3-8b", "base_url":"http://$pd_master_ip:8089/v1/completions", "max_length": 16384, "tokenized_requests": false}' \ +--tasks gsm8k --batch_size 36 --confirm_run_unsafe_code # 1. 按顺序在不同的cmd中启动上面的程序,然后再执行评测脚本,将结果写入out.txt 中,注意需要标记启动的参数和结果信息。 # 2. 执行评测命令的时候,需要用no_proxy 将本地local ip 排除。