13 changes: 4 additions & 9 deletions docs/CN/source/tutorial/api_server_args.rst
@@ -133,7 +133,10 @@ PD separation mode parameters

.. option:: --max_req_total_len

Maximum value of request input length + request output length, default is ``16384``
Maximum value of request input length + request output length. If not set explicitly, it is derived
automatically from the model config, falling back to ``16384`` if derivation fails.
For some RoPE types (such as ``yarn/dynamic/su/llama3``), the derivation does not multiply
``max_position_embeddings`` by ``rope_scaling.factor``, to avoid over-estimating the maximum length.

.. option:: --eos_id

@@ -472,14 +475,6 @@ PD separation mode parameters

Use reward model

.. option:: --long_truncation_mode

How to handle a request when input_token_len + max_new_tokens > max_req_total_len; possible values:

* ``None``: raise an exception (default)
* ``head``: remove some head tokens so that input_token_len + max_new_tokens <= max_req_total_len
* ``center``: remove some tokens in the middle so that input_token_len + max_new_tokens <= max_req_total_len

.. option:: --use_tgi_api

Use tgi input and output format
13 changes: 4 additions & 9 deletions docs/EN/source/tutorial/api_server_args.rst
@@ -133,7 +133,10 @@ Memory and Batch Processing Parameters

.. option:: --max_req_total_len

Maximum value of request input length + request output length, default is ``16384``
Maximum value of request input length + request output length. If not set, it is derived
automatically from the model's config.json, falling back to ``16384`` if derivation fails.
For some RoPE types (such as ``yarn/dynamic/su/llama3``), the derivation does not multiply
``max_position_embeddings`` by ``rope_scaling.factor``, to avoid over-estimating the maximum length.
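The derivation described above can be pictured with a small, self-contained sketch. This is illustrative only: the helper name `derive_max_req_total_len`, the exact config fields consulted, and the treatment of non-listed scaling types are assumptions based on this description, not the actual `lightllm.utils.config_utils.auto_set_max_req_total_len` implementation.

```python
import json
import os


def derive_max_req_total_len(model_dir: str, fallback: int = 16384) -> int:
    """Guess the max request length from a HF-style config.json (illustrative sketch)."""
    try:
        with open(os.path.join(model_dir, "config.json")) as f:
            cfg = json.load(f)
        max_len = int(cfg["max_position_embeddings"])
        rope_scaling = cfg.get("rope_scaling") or {}
        rope_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
        factor = rope_scaling.get("factor", 1.0)
        # For yarn/dynamic/su/llama3-style scaling we assume max_position_embeddings
        # already reflects the extended context, so the factor is not applied again
        # (this avoids over-estimating the supported length).
        if rope_type and rope_type not in ("yarn", "dynamic", "su", "llama3"):
            max_len = int(max_len * factor)
        return max_len
    except (OSError, KeyError, ValueError, TypeError):
        return fallback
```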

.. option:: --eos_id

@@ -472,14 +475,6 @@ Sampling and Generation Parameters

Use reward model

.. option:: --long_truncation_mode

How to handle when input_token_len + max_new_tokens > max_req_total_len, optional values:

* ``None``: Throw exception (default)
* ``head``: Remove some head tokens to make input_token_len + max_new_tokens <= max_req_total_len
* ``center``: Remove some tokens at the center position to make input_token_len + max_new_tokens <= max_req_total_len

.. option:: --use_tgi_api

Use tgi input and output format
4 changes: 3 additions & 1 deletion lightllm/common/basemodel/basemodel.py
@@ -212,7 +212,9 @@ def _init_kv_move_buffer(self):

def _check_mem_size(self):
self.max_total_token_num = self.mem_manager.size
assert self.max_seq_length <= self.max_total_token_num
assert (
self.max_total_token_num > self.batch_max_tokens
), "max_total_token_num must be greater than batch_max_tokens"
return

def _init_req_manager(self):
14 changes: 3 additions & 11 deletions lightllm/server/api_cli.py
@@ -246,8 +246,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--max_req_total_len",
type=int,
default=16384,
default=None,
help="Maximum allowed length for a request (input tokens + output tokens). "
"If None, it will be automatically derived from the model config.json, "
"and fall back to 16384 if derivation fails. "
"In PD (Prefill-Decode) mode, this value must be synchronized across the "
"PD master, prefill, and decode nodes.",
)
@@ -457,16 +459,6 @@ def make_argument_parser() -> argparse.ArgumentParser:

parser.add_argument("--use_reward_model", action="store_true", help="use reward model")

parser.add_argument(
"--long_truncation_mode",
type=str,
choices=[None, "head", "center"],
default=None,
help="""use to select the handle way when input_token_len + max_new_tokens > max_req_total_len.
None : raise Exception
head : remove some head tokens to make input_token_len + max_new_tokens <= max_req_total_len
center : remove some tokens in center loc to make input_token_len + max_new_tokens <= max_req_total_len""",
)
parser.add_argument("--use_tgi_api", action="store_true", help="use tgi input and ouput format")
parser.add_argument(
"--health_monitor", action="store_true", help="check the health of service and restart when error"
3 changes: 2 additions & 1 deletion lightllm/server/api_http.py
@@ -357,7 +357,8 @@ async def anthropic_messages(raw_request: Request) -> Response:
@app.get("/v1/models", response_model=ModelListResponse)
async def get_models(raw_request: Request):
model_name = g_objs.args.model_name
max_model_len = g_objs.args.max_req_total_len
max_model_len = g_objs.httpserver_manager.get_real_supported_max_req_total_len()

if model_name == "default_model_name" and g_objs.args.model_dir:
model_name = os.path.basename(g_objs.args.model_dir.rstrip("/"))

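With this change, the value reported by `/v1/models` reflects the runtime-capped limit (bounded by the token pool) rather than the raw startup argument. A quick way to inspect it against a locally running server; the host, port, and the exact response field layout are assumptions, not guaranteed by this diff:

```python
import requests

# Assumes a LightLLM server is reachable locally; adjust host/port as needed.
resp = requests.get("http://127.0.0.1:8000/v1/models", timeout=5)
resp.raise_for_status()
# After this PR, the advertised max length should equal
# min(max_total_token_num - 36, max_req_total_len) rather than the raw flag value.
print(resp.json())
```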
17 changes: 14 additions & 3 deletions lightllm/server/api_start.py
@@ -17,7 +17,12 @@
from lightllm.utils.multinode_utils import send_and_receive_node_ip
from lightllm.utils.redis_utils import start_redis_service
from lightllm.utils.shm_size_check import check_recommended_shm_size
from lightllm.utils.config_utils import has_audio_module, has_vision_module, is_linear_att_mixed_model
from lightllm.utils.config_utils import (
has_audio_module,
has_vision_module,
is_linear_att_mixed_model,
auto_set_max_req_total_len,
)
from lightllm.utils.dist_check_utils import auto_configure_allreduce_flags_from_args

logger = init_logger(__name__)
@@ -69,6 +74,7 @@ def normal_or_p_d_start(args):

args: StartArgs = args

auto_set_max_req_total_len(args)
set_unique_server_name(args)

if args.enable_mps:
@@ -124,10 +130,13 @@ def normal_or_p_d_start(args):
args.running_max_req_size = 3
args.batch_max_tokens = 2048
args.chunked_prefill_size = 1024
args.mem_fraction = 0.85
if args.mem_fraction > 0.82:
args.mem_fraction = 0.82
args.graph_max_batch_size = 32
logger.info(
f"performance_mode is personal, set running_max_req_size to 3,"
f"batch_max_tokens to 2048, chunked_prefill_size to 1024, mem_fraction to 0.85"
f"batch_max_tokens to 2048, chunked_prefill_size to 1024, mem_fraction to 0.82,"
f"graph_max_batch_size to 32"
)

if not args.disable_shm_warning:
@@ -518,6 +527,8 @@ def pd_master_start(args):
if args.run_mode != "pd_master":
return

auto_set_max_req_total_len(args)

# when use config_server to support multi pd_master node, we
# need generate unique node id for each pd_master node.
# otherwise, we use the 0 for single pd_master node.
4 changes: 2 additions & 2 deletions lightllm/server/core/objs/start_args_type.py
@@ -65,7 +65,8 @@ class StartArgs:
dp: int = field(default=1)
nnodes: int = field(default=1)
node_rank: int = field(default=0)
max_req_total_len: int = field(default=2048 + 1024)
# If None, will be automatically derived from model config in `lightllm.server.api_start`.
max_req_total_len: Optional[int] = field(default=None)
nccl_host: str = field(default="127.0.0.1")
nccl_port: int = field(default=None)
use_config_server_to_init_nccl: bool = field(default=False)
@@ -100,7 +101,6 @@ class StartArgs:
)
return_all_prompt_logprobs: bool = field(default=False)
use_reward_model: bool = field(default=False)
long_truncation_mode: Optional[str] = field(default=None, metadata={"choices": [None, "head", "center"]})
use_tgi_api: bool = field(default=False)
health_monitor: bool = field(default=False)
metric_gateway: Optional[str] = field(default=None)
55 changes: 29 additions & 26 deletions lightllm/server/httpserver/manager.py
@@ -122,6 +122,12 @@ def __init__(
# If the timemark is not updated for a pre-set time, a prob request will be sent to the backend.
self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
self.latest_success_infer_time_mark.set_value(int(time.time()))

        # Records the real --max_total_token_num value. When that argument is not given at startup,
        # it is determined inside the inference processes. In that case, if --max_req_total_len >
        # --max_total_token_num and the httpserver lets illegal inputs through to downstream modules,
        # the whole system may crash, so the httpserver needs the real max_total_token_num in order to
        # reject illegal requests early. After startup, the router process writes the correct
        # max_total_token_num into this shared memory for later request control.
self.shm_max_total_token_num = SharedInt(f"{get_unique_server_name()}_shm_max_total_token_num")
return

def _log_stage_timing(self, group_request_id: int, start_time: float, stage: str, **kwargs):
@@ -542,37 +548,34 @@ async def _encode(
raise ValueError(f"prompt format error, get type{type(prompt)}")
return

def get_real_supported_max_req_total_len(self):
        # The maximum length the system can really support, bounded both by the model length limit
        # from the startup args and by the token capacity.
return min(self.shm_max_total_token_num.get_value() - 36, self.max_req_total_len)

async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params: SamplingParams):
if not prompt_ids:
raise ValueError("prompt_ids is empty")
prompt_tokens = len(prompt_ids)
if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len:
# use long_truncation_mode to truncate long input len req.
if self.args.long_truncation_mode is None:
                # Default behavior changed: if prompt_tokens + max_new_tokens exceeds the total
                # allowed length, adjust max_new_tokens so that the request satisfies the constraint.
new_max_new_tokens = self.max_req_total_len - prompt_tokens
if new_max_new_tokens > 0:
logger.debug(
f"the input prompt token len {prompt_tokens} + max_new_tokens"
f"{sampling_params.max_new_tokens} > {self.max_req_total_len},"
f"so change max_new_tokens to {new_max_new_tokens}"
)
sampling_params.max_new_tokens = new_max_new_tokens
else:
raise ValueError(
f"the input prompt token len {prompt_tokens} + max_new_tokens \
{sampling_params.max_new_tokens} > {self.max_req_total_len}"
)
elif self.args.long_truncation_mode == "head":
prompt_ids = prompt_ids[-(self.max_req_total_len - sampling_params.max_new_tokens) :]
elif self.args.long_truncation_mode == "center":
req_input_len = self.max_req_total_len - sampling_params.max_new_tokens
prompt_ids = prompt_ids[0 : req_input_len // 2] + prompt_ids[-(req_input_len - req_input_len // 2) :]
prompt_tokens = len(prompt_ids)
assert prompt_tokens == req_input_len
        # The -36 here reserves a small safety margin for unpredictable boundary cases.
real_supported_max_req_total_len = self.get_real_supported_max_req_total_len()

if prompt_tokens + sampling_params.max_new_tokens > real_supported_max_req_total_len:

            # Default behavior changed: if prompt_tokens + max_new_tokens exceeds the total allowed
            # length, adjust max_new_tokens so that the request satisfies the constraint.
new_max_new_tokens = real_supported_max_req_total_len - prompt_tokens
if new_max_new_tokens > 0:
logger.debug(
f"the input prompt token len {prompt_tokens} + max_new_tokens"
f"{sampling_params.max_new_tokens} > {real_supported_max_req_total_len},"
f"so change max_new_tokens to {new_max_new_tokens}"
)
sampling_params.max_new_tokens = new_max_new_tokens
else:
assert False, "error args"
raise ValueError(
f"the input prompt token len {prompt_tokens} + max_new_tokens \
{sampling_params.max_new_tokens} > {real_supported_max_req_total_len}"
)

# last repaired
req_total_len = len(prompt_ids) + sampling_params.max_new_tokens
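To make the new length-control path concrete, here is a minimal, self-contained sketch of the clamping rule above. The function name and standalone form are illustrative; in the real code this logic lives in `HttpServerManager._check_and_repair_length` and `get_real_supported_max_req_total_len`, with the token capacity coming from the shared-memory value written by the router.

```python
def clamp_max_new_tokens(prompt_tokens: int, max_new_tokens: int,
                         shm_max_total_token_num: int, max_req_total_len: int) -> int:
    # Mirrors get_real_supported_max_req_total_len: the 36-token slack is a safety
    # margin against unpredictable boundary cases.
    real_limit = min(shm_max_total_token_num - 36, max_req_total_len)
    if prompt_tokens + max_new_tokens <= real_limit:
        return max_new_tokens
    remaining = real_limit - prompt_tokens
    if remaining <= 0:
        raise ValueError(f"prompt length {prompt_tokens} already exceeds the supported limit {real_limit}")
    return remaining  # shrink max_new_tokens instead of rejecting the request


# Example: a 15000-token prompt asking for 4096 new tokens, with a 120000-token
# pool and max_req_total_len=16384, gets clamped to 16384 - 15000 = 1384 new tokens.
assert clamp_max_new_tokens(15000, 4096, 120000, 16384) == 1384
```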
5 changes: 5 additions & 0 deletions lightllm/server/httpserver_for_pd_master/manager.py
@@ -48,6 +48,11 @@ def __init__(
self.per_token_costs = MovingAverage()
return

def get_real_supported_max_req_total_len(self):
        # HttpServerManager.generate borrows _check_and_repair_length(self, ...), which calls this method.
        # The PD master has no local token-pool shm counter; the limit is simply max_req_total_len, which
        # is kept in sync between the startup args and the prefill/decode nodes.
return self.max_req_total_len

async def register_pd(self, pd_info_json, websocket):
self.pd_manager.register_pd(pd_info_json, websocket)
return
7 changes: 7 additions & 0 deletions lightllm/server/router/manager.py
@@ -33,6 +33,7 @@
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from .stats import RouterStatics


@@ -60,6 +61,8 @@ def __init__(self, args: StartArgs):
self.is_safe_schedule = args.router_token_ratio == 0.0
self.load_way = args.load_way
self.max_total_token_num = args.max_total_token_num
        # The real token-capacity value, stored in shared memory
self.shm_max_total_token_num = SharedInt(f"{get_unique_server_name()}_shm_max_total_token_num")
self.shm_req_manager = ShmReqManager()
        # Shared via shared memory; the router module reads it for precise scheduling estimation
self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager()
@@ -185,6 +188,10 @@ async def wait_to_model_ready(self):
assert max(_nums) == min(_nums), "all rank must have same token num"
self.max_total_token_num = _nums[0]
self.args.max_total_token_num = self.max_total_token_num

self.shm_max_total_token_num.set_value(self.max_total_token_num)
logger.info(f"set shm_max_total_token_num value to {self.shm_max_total_token_num.get_value()}")

if not self.args.disable_dynamic_prompt_cache:
self.radix_cache_client = RadixCacheReadOnlyClient(
get_unique_server_name(),
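The shared-memory handshake added in this file is intentionally simple: the router publishes the resolved token capacity once the model is ready, and the HTTP server reads it back when validating requests. A minimal sketch of that pattern, using only the `SharedInt` calls that appear in this diff (`set_value` / `get_value`); the literal server name and capacity value are made up for illustration:

```python
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt

server_name = "demo_lightllm_instance"  # illustrative; the real code uses get_unique_server_name()

# Router process: publish the real capacity after model initialization.
writer = SharedInt(f"{server_name}_shm_max_total_token_num")
writer.set_value(120000)

# HTTP server process: read it back to bound incoming requests early.
reader = SharedInt(f"{server_name}_shm_max_total_token_num")
real_limit = min(reader.get_value() - 36, 16384)  # 36-token safety margin, as in the manager code
```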
@@ -283,7 +283,6 @@ def init_dp_kv_shared(self):

self.dp_kv_shared_module = DPKVSharedMoudle(
max_req_num=self.args.running_max_req_size,
max_req_seq_len=self.args.max_req_total_len + 8,
dp_size_in_node=self.dp_size_in_node,
backend=self,
)
@@ -18,12 +18,11 @@ class DPKVSharedMoudle:
_KV_LEN_INDEX = 0
_REQ_IDX_INDEX = 1

def __init__(self, max_req_num: int, max_req_seq_len: int, dp_size_in_node: int, backend):
def __init__(self, max_req_num: int, dp_size_in_node: int, backend):
from .impl import DPChunkedPrefillBackend

self.backend: DPChunkedPrefillBackend = backend
self.max_req_num = max_req_num
self.max_req_seq_len = max_req_seq_len

        # 0 means kv_len, 1 means radix_cache_len
self.shared_req_infos = ShmArray(