diff --git a/docs/EN/source/cookbook/qwen35_deployment.rst b/docs/EN/source/cookbook/qwen35_deployment.rst index 36e1288eed..818ce162dc 100644 --- a/docs/EN/source/cookbook/qwen35_deployment.rst +++ b/docs/EN/source/cookbook/qwen35_deployment.rst @@ -233,3 +233,11 @@ Hardware Requirements - ``--tp 8`` required to fit model weights across GPUs - Reduce ``--max_req_total_len`` or ``--graph_max_batch_size`` if encountering OOM errors - Use ``--data_type fp8_e4m3`` for FP8 KV quantization to further reduce memory pressure +- Multimodal deployments get ViT OOM protection by default: when + ``--enable_multimodal`` is on, ``--visual_batch_max_tokens`` is auto-derived + from ``--batch_max_tokens``. The same value caps both per-step batch + output and per-image output (oversized images are auto-resized by the + Qwen-VL ``max_pixels`` clamp; anything still over budget is rejected + before reaching the ViT). To tighten the budget further, pass an explicit + value (e.g. ``--visual_batch_max_tokens 16384``); to opt out and restore + pre-PR behavior, pass ``--visual_batch_max_tokens 0``. diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst index fef767d636..b806d06739 100644 --- a/docs/EN/source/tutorial/api_server_args.rst +++ b/docs/EN/source/tutorial/api_server_args.rst @@ -272,6 +272,27 @@ Multimodal Parameters Number of images processed in each inference batch, default is ``1`` +.. option:: --visual_batch_max_tokens + + Per-step ViT admission budget, measured in image output tokens (post + spatial_merge). The multimodal analogue of ``--batch_max_tokens``: the + ViT scheduler stops adding images to the current batch once their + cumulative ``token_num`` would exceed this value. Useful for bounding + peak ViT memory on dynamic-resolution models (Qwen2.5/3/3.5-VL, etc.) + where one 4K image or long video can contain more patches than many + small images combined. 
One image is always admitted per step to avoid + deadlock when a single request is larger than the budget — to make that + safe, the same value also drives the per-image budget: oversized images + are auto-resized by the Qwen-VL processor ``max_pixels`` clamp, and any + image that still exceeds the budget is rejected with a ``ValueError`` + before reaching the ViT. + + **Default behavior with** ``--enable_multimodal``: auto-derived from + ``--batch_max_tokens`` so multimodal deployments get OOM protection + without explicit opt-in. Pass an explicit positive integer to override. + Pass ``0`` to opt out and restore the pre-budget behavior (only + ``--visual_infer_batch_size`` applies). + .. option:: --visual_gpu_ids List of GPU IDs to use, e.g., 0 1 2 diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 7156a5ce23..3e6272810f 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -8,7 +8,8 @@ from io import BytesIO import torch.nn as nn from transformers.activations import ACT2FN -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor, clamp_processor_max_pixels +from lightllm.utils.envs_utils import get_env_start_args from safetensors import safe_open from lightllm.server.multimodal_params import ImageItem from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding @@ -208,6 +209,9 @@ def __init__( with open(processor_config_path, "r") as f: processor_config_dict = json.load(f) self.processor = Qwen2VLImageProcessor(**processor_config_dict) + clamp_processor_max_pixels( + self.processor, get_env_start_args().visual_batch_max_tokens, processor_name="qwen2_5_vl-vit" + ) self._init_datatype() diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 
6076756043..fea900f904 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -33,7 +33,8 @@ from safetensors import safe_open from lightllm.server.multimodal_params import ImageItem from lightllm.server.visualserver import get_vit_attn_backend -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor, clamp_processor_max_pixels +from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton @@ -244,6 +245,9 @@ def load_model(self, weight_dir): with open(processor_config_path, "r") as f: processor_config_dict = json.load(f) self.processor = Qwen2VLImageProcessor(**processor_config_dict) + clamp_processor_max_pixels( + self.processor, get_env_start_args().visual_batch_max_tokens, processor_name="qwen2_vl-vit" + ) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index bc313fe467..12302b1b1a 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -27,6 +27,61 @@ logger = init_logger(__name__) +def clamp_processor_max_pixels(processor, max_image_tokens, processor_name: str = "") -> None: + """Clamp a Qwen-VL style image processor's max-pixel limit so that even a + max-sized image produces ``token_num <= max_image_tokens``. + + Reuses the processor's built-in ``smart_resize`` mechanism — just tightens + the per-pixel budget so the existing resize path fits the server-wide + per-image token budget (``--visual_batch_max_tokens``). 
After the clamp, + ``get_image_token_length`` cannot return a value above the budget, so + request-level rejection becomes a defensive no-op in practice. + + Different Qwen-VL generations expose the limit on different attributes: + Qwen2-VL / Qwen2.5-VL / lightllm's own ``Qwen2VLImageProcessor`` use + ``processor.max_pixels``, while HF's Qwen3-VL / Qwen3.5-VL processors store + it in ``processor.size["longest_edge"]``. Both attributes are clamped when + present so any reader (HF runtime, tokenizer ``__init__``) sees the + tightened bound. + + No-op when ``max_image_tokens`` is None or the processor already enforces a + tighter bound. + """ + if max_image_tokens is None: + return + unit = processor.patch_size * processor.merge_size + allowed_max_pixels = max_image_tokens * unit * unit + if allowed_max_pixels < unit * unit: + raise ValueError( + f"max_image_tokens={max_image_tokens} is too small; " + f"need at least 1 patch's worth (={unit * unit} pixels) for {processor_name or 'processor'}." + ) + + # Track originals so the log line shows the pre-clamp values; some + # processors only expose one of the two schemas, so each branch is gated + # on its own attribute presence. 
+ current_max_pixels = getattr(processor, "max_pixels", None) + size = getattr(processor, "size", None) + has_longest_edge = isinstance(size, dict) and "longest_edge" in size + current_longest_edge = size.get("longest_edge") if has_longest_edge else None + + clamped = False + if current_max_pixels is None or allowed_max_pixels < current_max_pixels: + processor.max_pixels = allowed_max_pixels + clamped = True + if has_longest_edge and (current_longest_edge is None or allowed_max_pixels < current_longest_edge): + size["longest_edge"] = allowed_max_pixels + clamped = True + + if clamped: + logger.info( + f"{processor_name or 'processor'}: clamping max_pixels/longest_edge to " + f"{allowed_max_pixels} (was max_pixels={current_max_pixels}, " + f"longest_edge={current_longest_edge}; " + f"max_image_tokens={max_image_tokens}, unit={unit})" + ) + + IMAGE_FACTOR = 28 MIN_PIXELS = 4 * 28 * 28 MAX_PIXELS = 16384 * 28 * 28 diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py index 0276724749..381d989b6e 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -27,7 +27,8 @@ from lightllm.server.multimodal_params import ImageItem from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor, clamp_processor_max_pixels +from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention @@ -225,6 +226,9 @@ def load_model(self, weight_dir): with open(processor_config_path, "r") as f: processor_config_dict = json.load(f) self.processor = Qwen2VLImageProcessor(**processor_config_dict) + clamp_processor_max_pixels( + self.processor, 
get_env_start_args().visual_batch_max_tokens, processor_name="qwen3_omni-vit" + ) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index bed8898115..3523fd92c3 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -27,7 +27,8 @@ from lightllm.server.multimodal_params import ImageItem from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor, clamp_processor_max_pixels +from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention from lightllm.utils.log_utils import init_logger @@ -220,6 +221,9 @@ def load_model(self, weight_dir): with open(processor_config_path, "r") as f: processor_config_dict = json.load(f) self.processor = Qwen2VLImageProcessor(**processor_config_dict) + clamp_processor_max_pixels( + self.processor, get_env_start_args().visual_batch_max_tokens, processor_name="qwen3_vl-vit" + ) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index 9deaf08575..3c5afb1f81 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -16,7 +16,8 @@ from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.server.multimodal_params import ImageItem -from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image +from 
lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image, clamp_processor_max_pixels +from lightllm.utils.envs_utils import get_env_start_args def add_split_tokens(image_features, image_newline_embed, image_new_embed): @@ -221,6 +222,9 @@ def load_model(self, weight_dir): with open(processor_config_path, "r") as f: processor_config_dict = json.load(f) self.processor = Qwen2VLImageProcessor(**processor_config_dict) + clamp_processor_max_pixels( + self.processor, get_env_start_args().visual_batch_max_tokens, processor_name="tarsier2-vit" + ) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 70b638715e..3cbc8be69f 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -472,6 +472,30 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--visual_infer_batch_size", type=int, default=None, help="number of images to process in each inference batch" ) + parser.add_argument( + "--visual_batch_max_tokens", + type=int, + default=None, + help=""" + Per-step ViT admission budget measured in image output tokens (post + spatial_merge). The ViT scheduler stops adding images to the current + batch once their cumulative token_num would exceed this value. Acts as + the multimodal analogue of --batch_max_tokens and caps peak ViT + memory/compute for dynamic-resolution models (Qwen2.5/3/3.5-VL, etc.). + One image is always admitted per step to avoid deadlock when a single + request is larger than the budget — to make that safe, this value + also drives the per-image budget: oversized images are auto-resized + by the Qwen-VL processor max_pixels clamp, and any image that still + exceeds the budget is rejected with a ValueError before reaching the + ViT. 
+ + Default behavior when --enable_multimodal is on: auto-derived from + --batch_max_tokens so multimodal deployments get OOM protection without + explicit opt-in. Pass an explicit positive integer to override; pass 0 + to opt out entirely and restore the pre-budget behavior (only the + image-count cap --visual_infer_batch_size applies). + """, + ) parser.add_argument( "--visual_send_batch_size", type=int, diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 2acf8d3626..dce45c11fc 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -272,6 +272,24 @@ def normal_or_p_d_start(args): args.cpu_cache_token_page_size = args.linear_att_hash_page_size * args.linear_att_page_block_num logger.info(f"set cpu_cache_token_page_size to {args.cpu_cache_token_page_size} for linear hybrid att model") + # 多模态预算默认值(safety-on-by-default for multimodal deployments): + # - 不传:visual_batch_max_tokens 默认等于 batch_max_tokens(LLM 和 ViT 共用预算口径)。 + # - 传 0:显式关闭,恢复 PR 之前的"不限"行为(向后兼容用)。 + # - 传正整数:作为显式预算使用。 + # 同一个值同时充当 per-step batch budget、per-image hard cap 和 processor max_pixels + # clamp 的依据 —— "首图必放行" 规则要求单图必须能塞进一个批次,所以 batch budget 和 + # 单图上限本来就是同一个数。 + if args.enable_multimodal: + if args.visual_batch_max_tokens is None: + args.visual_batch_max_tokens = args.batch_max_tokens + logger.info( + f"visual_batch_max_tokens auto-derived from batch_max_tokens = {args.batch_max_tokens} " + f"(pass --visual_batch_max_tokens 0 to opt out)" + ) + elif args.visual_batch_max_tokens == 0: + logger.info("visual_batch_max_tokens explicitly disabled (=0); ViT token budget off") + args.visual_batch_max_tokens = None + # help to manage data stored on Ceph if "s3://" in args.model_dir: from lightllm.utils.petrel_helper import s3_model_prepare diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index bcc4b3798a..0e17b78d87 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ 
b/lightllm/server/core/objs/start_args_type.py @@ -107,6 +107,7 @@ class StartArgs: push_interval: int = field(default=10) visual_node_id: int = field(default=None) visual_infer_batch_size: int = field(default=None) + visual_batch_max_tokens: Optional[int] = field(default=None) visual_send_batch_size: int = field(default=1) visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) visual_tp: int = field(default=1) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 892e202e2d..f7bc4b78b1 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -35,6 +35,7 @@ from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken +from lightllm.utils.multimodal_utils import enforce_image_token_budget from rpyc.utils.classic import obtain logger = init_logger(__name__) @@ -179,11 +180,12 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, # 只有 P 和 NORMAL 节点需要真的管理多模态资源 if self.pd_mode.is_P_or_NORMAL(): items, md5sums, tokens_nums, datas = [], [], [], [] - for img in multimodal_params.images: + for img_index, img in enumerate(multimodal_params.images): self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) data = img.read() # must after init_imageitem_extral_params token_num = self.tokenizer.get_image_token_length(img) + enforce_image_token_budget(token_num, self.args.visual_batch_max_tokens, image_index=img_index) md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params))) md5sums.append(md5sum) img.md5 = md5sum @@ -236,10 +238,12 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar img_count = 0 audio_tokens = 0 audio_count = 0 - for img in multimodal_params.images: + for img_index, img in enumerate(multimodal_params.images): img_count += 1 
self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params) - image_tokens += self.tokenizer.get_image_token_length(img) + token_num = self.tokenizer.get_image_token_length(img) + enforce_image_token_budget(token_num, self.args.visual_batch_max_tokens, image_index=img_index) + image_tokens += token_num for audio in multimodal_params.audios: audio_count += 1 self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index d6a1a58b05..efde914548 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -21,6 +21,7 @@ from lightllm.server.httpserver.manager import AsyncQueue from lightllm.utils.error_utils import ServerBusyError from lightllm.utils.envs_utils import get_pd_split_max_new_tokens +from lightllm.utils.multimodal_utils import enforce_image_token_budget from .pd_selector import create_selector logger = init_logger(__name__) @@ -73,10 +74,12 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar img_count = 0 audio_tokens = 0 audio_count = 0 - for img in multimodal_params.images: + for img_index, img in enumerate(multimodal_params.images): img_count += 1 self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params) - image_tokens += self.tokenizer.get_image_token_length(img) + token_num = self.tokenizer.get_image_token_length(img) + enforce_image_token_budget(token_num, self.args.visual_batch_max_tokens, image_index=img_index) + image_tokens += token_num for audio in multimodal_params.audios: audio_count += 1 self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params) diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 25726b2578..247ca3e0d2 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ 
-89,10 +89,19 @@ def get_tokenizer( logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.") return DeepSeekV32Tokenizer(hf_tokenizer) + # Qwen-VL family shares a max_pixels clamp helper to keep get_image_token_length + # in sync with visual_batch_max_tokens budget. No-op for non-Qwen-VL tokenizers. + from ..models.qwen2_vl.vision_process import clamp_processor_max_pixels + from lightllm.utils.envs_utils import get_env_start_args + + _start_args = get_env_start_args() + _img_max_tokens = getattr(_start_args, "visual_batch_max_tokens", None) + if model_cfg["architectures"][0] == "TarsierForConditionalGeneration": from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor image_processor = Qwen2VLImageProcessor.from_pretrained(tokenizer_name) + clamp_processor_max_pixels(image_processor, _img_max_tokens, processor_name="tarsier2-tokenizer") tokenizer = Tarsier2Tokenizer(tokenizer=tokenizer, image_processor=image_processor, model_cfg=model_cfg) elif model_type == "llava" or model_type == "internlmxcomposer2": tokenizer = LlavaTokenizer(tokenizer, model_cfg) @@ -102,6 +111,7 @@ def get_tokenizer( from transformers import AutoProcessor processor = AutoProcessor.from_pretrained(tokenizer_name) + clamp_processor_max_pixels(processor.image_processor, _img_max_tokens, processor_name=f"{model_type}-tokenizer") tokenizer = QWen2VLTokenizer( tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) @@ -109,6 +119,7 @@ def get_tokenizer( from transformers import AutoProcessor processor = AutoProcessor.from_pretrained(tokenizer_name) + clamp_processor_max_pixels(processor.image_processor, _img_max_tokens, processor_name=f"{model_type}-tokenizer") tokenizer = QWen3VLTokenizer( tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) @@ -117,6 +128,7 @@ def get_tokenizer( from ..models.qwen3_5.model import QWen3_5Tokenizer processor = AutoProcessor.from_pretrained(tokenizer_name) + 
clamp_processor_max_pixels(processor.image_processor, _img_max_tokens, processor_name=f"{model_type}-tokenizer") tokenizer = QWen3_5Tokenizer( tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) @@ -125,6 +137,7 @@ def get_tokenizer( model_cfg = model_cfg["thinker_config"] processor = AutoProcessor.from_pretrained(tokenizer_name) + clamp_processor_max_pixels(processor.image_processor, _img_max_tokens, processor_name="qwen3-omni-tokenizer") tokenizer = QWen3OmniTokenizer(tokenizer, processor=processor, model_cfg=model_cfg) elif model_type == "internvl_chat": tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index a165be78f2..60bfafc62d 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -85,6 +85,7 @@ async def wait_to_model_ready(self): "quant_type": self.args.vit_quant_type, "quant_cfg": self.args.vit_quant_cfg, "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + "visual_batch_max_tokens": self.args.visual_batch_max_tokens, "vit_attn_backend": self.vit_attn_backend, } init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) diff --git a/lightllm/server/visualserver/model_infer/batching.py b/lightllm/server/visualserver/model_infer/batching.py new file mode 100644 index 0000000000..17dec51d98 --- /dev/null +++ b/lightllm/server/visualserver/model_infer/batching.py @@ -0,0 +1,89 @@ +import queue +import threading +from typing import List, Optional + + +def _put_front(infer_queue: "queue.Queue", item) -> None: + """Push ``item`` back to the front of ``infer_queue``. + + ``queue.Queue.put`` appends to the tail, which would reorder pending items + relative to other consumers. 
The ViT scheduler runs rank-0-only admission + on a queue that every TP rank holds an identical copy of, and rank N + later pops ``len(images)`` items in FIFO order to follow rank 0's + decision. If a rejected item moved to the tail of rank 0's queue, the + queues across ranks would diverge and the next batch would encode + different images on different ranks. Re-inserting at the head preserves + FIFO order on rank 0 and keeps all ranks in sync. + + Note: ``Queue.get`` does *not* decrement ``unfinished_tasks`` — only + ``task_done()`` does. The original ``Queue.put`` already counted this + item, so we must NOT bump the counter again on re-insert; doing so would + desync ``Queue.join()``/``task_done()`` accounting (a latent footgun if + any future caller starts using them on this queue). + """ + with infer_queue.mutex: + infer_queue.queue.appendleft(item) + infer_queue.not_empty.notify() + + +def pull_batch_with_budget( + infer_queue: "queue.Queue", + semaphore: threading.Semaphore, + max_num: int, + max_tokens: Optional[int], + timeout: Optional[float] = None, +) -> List: + """Pull up to ``max_num`` image items from ``infer_queue`` while keeping the + cumulative ``item.token_num`` at or below ``max_tokens``. + + Rank-0-only admission logic for the ViT scheduler. The first item is always + admitted even when it alone exceeds ``max_tokens`` — this avoids a deadlock + when a single request is larger than the per-step budget. Each subsequent + item is pulled, inspected, and either kept or pushed back to the front of + the queue so non-rank-0 workers' FIFO pops stay aligned with rank 0's + admitted set. + + ``semaphore`` counts share with the caller (see ``_init_taskes``); callers + acquire before every get and release on over-pull so the permit count stays + consistent with queue contents. 
+ + When ``timeout`` is not None, the first acquire/get is bounded so rank 0 + can emit a heartbeat broadcast instead of blocking indefinitely on the + gloo broadcast (avoids the 30-minute NCCL-style timeout on idle workers). + An empty list is returned on timeout. + """ + tasks: List = [] + + if timeout is not None: + if not semaphore.acquire(timeout=timeout): + return tasks + try: + first = infer_queue.get(timeout=timeout) + except queue.Empty: + semaphore.release() + return tasks + else: + semaphore.acquire() + first = infer_queue.get(block=True) + tasks.append(first) + total_tokens = first.token_num or 0 + + while len(tasks) < max_num: + try: + task = infer_queue.get(block=False) + except queue.Empty: + break + if not semaphore.acquire(blocking=False): + _put_front(infer_queue, task) + break + + next_tokens = task.token_num or 0 + if max_tokens is not None and total_tokens + next_tokens > max_tokens: + _put_front(infer_queue, task) + semaphore.release() + break + + tasks.append(task) + total_tokens += next_tokens + + return tasks diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 55f4704a31..b856da5845 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -26,6 +26,7 @@ from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient from lightllm.server.visualserver import set_vit_att_backend from lightllm.server.embed_cache.afs_utils import SepEmbedHandler +from lightllm.server.visualserver.model_infer.batching import pull_batch_with_budget from lightllm.utils.log_utils import init_logger @@ -53,6 +54,7 @@ def exposed_init_model(self, kvargs): weight_dir = kvargs["weight_dir"] self.infer_max_batch_size = kvargs["max_batch_size"] + self.visual_batch_max_tokens = kvargs.get("visual_batch_max_tokens", None) self.device_id = kvargs["device_id"] self.vit_tp = kvargs["vit_tp"] self.dp_rank_id = 
kvargs["dp_rank_id"] @@ -186,33 +188,37 @@ def _init_taskes(self): def _forward(self, images: List[ImageItem]): return self.model.encode(images) - def _get_image_items_from_infer_queue(self, max_num: int, force_same: bool = False) -> List[ImageItem]: + def _get_image_items_from_infer_queue( + self, max_num: int, force_same: bool = False, timeout: float = None + ) -> List[ImageItem]: """ 从队列中批量获取任务,直到达到 max_num 或队列为空。 - """ - tasks = [] - # 至少获取一个任务,阻塞 - self.sempare.acquire() - task = self.infer_queue.get(block=True) - tasks.append(task) - if not force_same: - # 尝试继续获取更多任务,直到达到 max_num - while len(tasks) < max_num: - try: - self.sempare.acquire() - task = self.infer_queue.get(block=False) - tasks.append(task) - except queue.Empty: - self.sempare.release() - break - else: + timeout 仅对首个任务的阻塞等待生效;超时返回空列表。rank 0 使用 + timeout 作为心跳,避免其他 rank 在 gloo broadcast 上长时间无响应 + 而触发 30 分钟超时崩溃。 + + On rank 0 the cumulative ``img.token_num`` is additionally capped by + ``visual_batch_max_tokens`` so a dynamic-resolution image (or batch of + them) cannot blow the ViT's memory budget. The non-rank-0 ``force_same`` + path follows rank 0's already-decided count via the gloo broadcast. 
+ """ + if force_same: + tasks = [] + self.sempare.acquire() + tasks.append(self.infer_queue.get(block=True)) while len(tasks) < max_num: self.sempare.acquire() - task = self.infer_queue.get(block=True) - tasks.append(task) + tasks.append(self.infer_queue.get(block=True)) + return tasks - return tasks + return pull_batch_with_budget( + infer_queue=self.infer_queue, + semaphore=self.sempare, + max_num=max_num, + max_tokens=self.visual_batch_max_tokens, + timeout=timeout, + ) def _get_image_items_from_store_queue(self, max_num: int) -> List[ImageItem]: """ @@ -240,12 +246,18 @@ def _infer_worker(self): while True: try: # 从队列获取任务, 阻塞等待 + # rank 0 用带超时的 get, 空闲时也会广播 [0] 当作心跳, + # 避免其他 rank 在 gloo broadcast 上长时间无响应而触发 30 分钟超时崩溃。 if self.tp_rank_id == 0: - images = self._get_image_items_from_infer_queue(max_num=self.infer_max_batch_size) + images = self._get_image_items_from_infer_queue(max_num=self.infer_max_batch_size, timeout=60.0) dist.broadcast_object_list([len(images)], src=0, group=self.gloo_group) + if len(images) == 0: + continue else: ans = [None] dist.broadcast_object_list(ans, src=0, group=self.gloo_group) + if ans[0] == 0: + continue images = self._get_image_items_from_infer_queue(max_num=ans[0], force_same=True) for image in images: diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index 27275c1e8c..c6ca15b775 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -117,6 +117,7 @@ async def wait_to_model_ready(self): "quant_type": self.args.vit_quant_type, "quant_cfg": self.args.vit_quant_cfg, "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + "visual_batch_max_tokens": self.args.visual_batch_max_tokens, "vit_attn_backend": self.vit_attn_backend, } init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) diff --git a/lightllm/utils/multimodal_utils.py 
b/lightllm/utils/multimodal_utils.py index 4b49ea8891..f8151665b2 100644 --- a/lightllm/utils/multimodal_utils.py +++ b/lightllm/utils/multimodal_utils.py @@ -4,6 +4,7 @@ import httpx from PIL import Image from io import BytesIO +from typing import Optional from fastapi import Request from functools import lru_cache from lightllm.utils.log_utils import init_logger @@ -11,6 +12,20 @@ logger = init_logger(__name__) +def enforce_image_token_budget(token_num: int, max_tokens: Optional[int], image_index: int = 0) -> None: + """Reject a request when a single image's ``token_num`` exceeds the server + budget. Pairs with the per-step ``--visual_batch_max_tokens`` admission cap: + this guards the batch against one oversized request, since a single image + is always admitted (the "first image always runs" deadlock-avoidance rule). + """ + if max_tokens is not None and token_num > max_tokens: + raise ValueError( + f"image[{image_index}] token_num={token_num} exceeds " + f"visual_batch_max_tokens={max_tokens}; reduce image resolution, " + f"image_max_patch_num (InternVL-family), or preprocessor_config.json::max_pixels (Qwen-VL)" + ) + + def _httpx_async_client_proxy_kwargs(proxy) -> dict: """ httpx 0.28+ 使用 AsyncClient(proxy=...);更早版本使用 proxies=... diff --git a/unit_tests/models/qwen2_vl/test_clamp_processor_max_pixels.py b/unit_tests/models/qwen2_vl/test_clamp_processor_max_pixels.py new file mode 100644 index 0000000000..d561e8bb56 --- /dev/null +++ b/unit_tests/models/qwen2_vl/test_clamp_processor_max_pixels.py @@ -0,0 +1,162 @@ +import importlib.util +import os +import unittest + +# Load the helper directly from its file so we do not trigger heavy imports in +# lightllm.models.* (torch, triton kernels, etc.) just to test a pure function. 
# Absolute path of the module under test, resolved relative to this test file
# (unit_tests/models/qwen2_vl/ -> repo root -> lightllm/models/qwen2_vl/).
_PATH = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "..",
        "..",
        "lightllm",
        "models",
        "qwen2_vl",
        "vision_process.py",
    )
)


def _load_helper():
    """Load ``clamp_processor_max_pixels`` from ``vision_process.py`` in isolation.

    Installs empty stub modules for the heavy dependencies that
    ``vision_process.py`` imports at module load (torch, torchvision,
    transformers, ...), then imports the file via ``importlib``. If the stubs
    are not sufficient for a full module import, falls back to slicing the
    function's source text out of the file and ``exec``-ing it directly.

    Returns the ``clamp_processor_max_pixels`` callable.
    """
    import sys
    import types

    # Stub out heavy imports that vision_process.py pulls at module load.
    # Only the pure helper is under test; nothing below depends on these stubs.
    for name in ("torch", "numpy", "PIL", "PIL.Image"):
        if name not in sys.modules:
            sys.modules[name] = types.ModuleType(name)

    if "torchvision" not in sys.modules:
        tv = types.ModuleType("torchvision")
        tv_t = types.ModuleType("torchvision.transforms")
        tv_tv2 = types.ModuleType("torchvision.transforms.v2")
        tv_tf = types.ModuleType("torchvision.transforms.v2.functional")
        sys.modules["torchvision"] = tv
        sys.modules["torchvision.transforms"] = tv_t
        sys.modules["torchvision.transforms.v2"] = tv_tv2
        sys.modules["torchvision.transforms.v2.functional"] = tv_tf

    # The file imports transformers pieces; stub them.
    if "transformers" not in sys.modules:
        sys.modules["transformers"] = types.ModuleType("transformers")
    # NOTE(review): source was de-mangled — this loop is placed outside the
    # guard above because of the per-submodule check below; confirm against
    # the original file.
    for sub in (
        "transformers.image_utils",
        "transformers.image_processing_utils_fast",
        "transformers.image_transforms",
    ):
        if sub not in sys.modules:
            sys.modules[sub] = types.ModuleType(sub)

    spec = importlib.util.spec_from_file_location("_vp_under_test", _PATH)
    mod = importlib.util.module_from_spec(spec)
    try:
        spec.loader.exec_module(mod)
    except Exception:
        # If stubs aren't enough to import the whole file, fall back to
        # reading the function source and exec'ing it directly.
        with open(_PATH, "r") as f:
            src = f.read()
        start = src.index("def clamp_processor_max_pixels")
        # Find the end — the next "def " at column 0.
        tail = src[start:]
        next_def = tail.find("\ndef ", 1)
        fn_src = tail[:next_def] if next_def != -1 else tail
        ns = {}
        # Substitute logger with a noop.
        import logging

        ns["logger"] = logging.getLogger("clamp_test")
        exec("from typing import Optional\n" + fn_src, ns)
        return ns["clamp_processor_max_pixels"]
    return mod.clamp_processor_max_pixels


clamp_processor_max_pixels = _load_helper()


class _FakeProcessor:
    """Minimal stand-in for a Qwen-VL image processor: just the three
    attributes ``clamp_processor_max_pixels`` reads/writes."""

    def __init__(self, patch_size, merge_size, max_pixels):
        self.patch_size = patch_size
        self.merge_size = merge_size
        self.max_pixels = max_pixels


class TestClampProcessorMaxPixels(unittest.TestCase):
    """Pixel-budget clamp: tighten ``max_pixels`` only when the token budget
    is stricter than the processor's own limit; never loosen it."""

    def test_none_budget_is_noop(self):
        p = _FakeProcessor(patch_size=14, merge_size=2, max_pixels=16384 * 28 * 28)
        clamp_processor_max_pixels(p, None)
        self.assertEqual(p.max_pixels, 16384 * 28 * 28)

    def test_budget_looser_than_processor_is_noop(self):
        # Processor's max_pixels already gives 16384 tokens. Budget is 32768. Keep smaller.
        p = _FakeProcessor(patch_size=14, merge_size=2, max_pixels=16384 * 28 * 28)
        clamp_processor_max_pixels(p, max_image_tokens=32768)
        self.assertEqual(p.max_pixels, 16384 * 28 * 28)

    def test_budget_tighter_clamps(self):
        # patch=14, merge=2 -> unit=28, unit^2=784. Budget 4096 tokens -> 4096*784 pixels.
        p = _FakeProcessor(patch_size=14, merge_size=2, max_pixels=16384 * 28 * 28)
        clamp_processor_max_pixels(p, max_image_tokens=4096)
        self.assertEqual(p.max_pixels, 4096 * 28 * 28)

    def test_budget_equal_to_original_is_noop(self):
        # Original gives exactly 16384 tokens. Budget 16384 -> same value.
        p = _FakeProcessor(patch_size=14, merge_size=2, max_pixels=16384 * 28 * 28)
        clamp_processor_max_pixels(p, max_image_tokens=16384)
        self.assertEqual(p.max_pixels, 16384 * 28 * 28)

    def test_budget_zero_raises(self):
        p = _FakeProcessor(patch_size=14, merge_size=2, max_pixels=16384 * 28 * 28)
        with self.assertRaises(ValueError):
            clamp_processor_max_pixels(p, max_image_tokens=0)

    def test_different_patch_merge(self):
        # patch=16, merge=1 -> unit=16, unit^2=256. Budget 1000 tokens -> 256000 pixels.
        # NOTE(review): the body of this test continues in the next span of the file.
        p = _FakeProcessor(patch_size=16, merge_size=1, max_pixels=10_000_000)
        clamp_processor_max_pixels(p, max_image_tokens=1000)
        self.assertEqual(p.max_pixels, 1000 * 16 * 16)

    def test_processor_max_pixels_none_is_clamped(self):
        # HF Qwen3.5-VL's processor exposes max_pixels=None (no intrinsic upper
        # bound); the clamp must treat that as "looser than any budget" and
        # always apply our allowed_max_pixels instead of crashing on int 600 > 400 -> stop and put 300 back.
        # NOTE(review): the source is corrupted at this point — the comment
        # above cuts off mid-sentence and fuses into assertions from a
        # different test module (the pull_batch_with_budget suite). Missing:
        # the rest of this test's body, the end of TestClampProcessorMaxPixels,
        # and the header of the next test file (its imports, the _setup /
        # _FakeImg fixtures, the pull_batch_with_budget import, and the
        # enclosing TestCase class statement). Recover them from the original
        # patch before relying on this text.
        self.assertEqual([g.token_num for g in got], [100, 200])
        self.assertEqual(q.qsize(), 2)
        remaining = [q.get_nowait().token_num for _ in range(q.qsize())]
        self.assertIn(300, remaining)
        self.assertIn(400, remaining)

    def test_first_image_always_admitted_even_if_over_budget(self):
        # Deadlock-avoidance rule: one image runs even when alone it exceeds
        # the step budget.
        q, sem = _setup([10_000, 5])
        got = pull_batch_with_budget(q, sem, max_num=10, max_tokens=100)
        self.assertEqual([g.token_num for g in got], [10_000])
        self.assertEqual(q.qsize(), 1)

    def test_single_item_queue(self):
        q, sem = _setup([42])
        got = pull_batch_with_budget(q, sem, max_num=5, max_tokens=1000)
        self.assertEqual([g.token_num for g in got], [42])
        self.assertEqual(q.qsize(), 0)

    def test_budget_at_exact_boundary_admits(self):
        q, sem = _setup([100, 200, 300])
        got = pull_batch_with_budget(q, sem, max_num=10, max_tokens=300)
        # 100 + 200 = 300 == budget -> admit; +300 -> 600 > 300 -> stop.
        self.assertEqual([g.token_num for g in got], [100, 200])

    def test_none_token_num_treated_as_zero(self):
        q = queue.Queue()
        q.put(_FakeImg(100))
        q.put(_FakeImg(None))
        q.put(_FakeImg(50))
        sem = threading.Semaphore(3)
        got = pull_batch_with_budget(q, sem, max_num=10, max_tokens=100)
        # 100 (admitted first), 0 (None) -> 100 admitted, +50 -> 150 > 100 -> stop.
        self.assertEqual([g.token_num for g in got], [100, None])
        self.assertEqual(q.qsize(), 1)

    def test_max_num_respected_under_budget(self):
        # max_num caps the batch even when the token budget would allow more.
        q, sem = _setup([10, 10, 10, 10, 10])
        got = pull_batch_with_budget(q, sem, max_num=3, max_tokens=10_000)
        self.assertEqual(len(got), 3)
        self.assertEqual(q.qsize(), 2)

    def test_semaphore_permits_match_returned_items(self):
        # After the pull, permits consumed must equal len(returned) so the outer
        # backpressure accounting in _store_worker releases the right count.
        q, sem = _setup([100, 200, 300, 400, 500])
        permits_before = sem._value
        got = pull_batch_with_budget(q, sem, max_num=10, max_tokens=400)
        permits_after = sem._value
        self.assertEqual(permits_before - permits_after, len(got))

    def test_semaphore_permits_match_on_queue_empty(self):
        q, sem = _setup([100, 200])
        permits_before = sem._value
        got = pull_batch_with_budget(q, sem, max_num=10, max_tokens=None)
        permits_after = sem._value
        self.assertEqual(permits_before - permits_after, len(got))
        self.assertEqual(len(got), 2)

    def test_rejected_item_returns_to_front_preserves_fifo(self):
        # TP-correctness regression: after rank 0's budget admission, the
        # residual queue must equal the original FIFO order with the admitted
        # prefix removed. Other TP ranks pop ``len(returned)`` from their own
        # identical queues, so any reorder on rank 0 makes ranks encode
        # different images on the next step.
        q, sem = _setup([100, 500, 100])
        got = pull_batch_with_budget(q, sem, max_num=10, max_tokens=200)
        # 100 admitted, 500 over-budget -> rejected, loop breaks.
        self.assertEqual([g.token_num for g in got], [100])
        remaining_in_order = [q.get_nowait().token_num for _ in range(q.qsize())]
        self.assertEqual(remaining_in_order, [500, 100])

    def test_rejected_on_sem_exhaustion_returns_to_front(self):
        # Semaphore-skip path mirrors the budget-skip path: the popped item
        # must end up at the front of the queue, not the tail.
        q = queue.Queue()
        for tn in [100, 200, 300]:
            q.put(_FakeImg(tn))
        # Two permits => first acquire (200) succeeds, second (300) fails.
        # NOTE(review): the first item (100) appears to be admitted without an
        # acquire here — confirm against pull_batch_with_budget's contract.
        sem = threading.Semaphore(2)
        got = pull_batch_with_budget(q, sem, max_num=10, max_tokens=10_000)
        self.assertEqual([g.token_num for g in got], [100, 200])
        remaining_in_order = [q.get_nowait().token_num for _ in range(q.qsize())]
        self.assertEqual(remaining_in_order, [300])

    def test_unfinished_tasks_stays_consistent_through_reject(self):
        # Queue.put bumps unfinished_tasks; Queue.get does NOT decrement it
        # (only task_done() does). The reject path re-inserts at the front
        # and must not bump the counter again — otherwise Queue.join() would
        # hang forever even after every consumed item is task_done()'d.
        q, sem = _setup([100, 500, 100])
        # 3 items put in => unfinished_tasks == 3
        self.assertEqual(q.unfinished_tasks, 3)
        got = pull_batch_with_budget(q, sem, max_num=10, max_tokens=200)
        self.assertEqual([g.token_num for g in got], [100])
        # One item consumed (returned to caller, awaiting task_done), two
        # still pending in the queue. Counter should still match the number
        # of logical outstanding tasks: 3.
        self.assertEqual(q.unfinished_tasks, 3)
        # task_done for the returned item, then drain the rest with task_done.
        q.task_done()
        self.assertEqual(q.qsize(), 2)
        for _ in range(q.qsize()):
            q.get_nowait()
            q.task_done()
        # join() must return promptly; if the counter were corrupted it would hang.
        q.join()


if __name__ == "__main__":
    unittest.main()


# ===========================================================================
# NOTE(review): patch file boundary — the content below belongs to a new file:
# unit_tests/utils/test_image_token_budget.py
# ===========================================================================
import importlib.util
import os
import unittest

# Load the helper directly to avoid triggering heavy package imports (torch,
# atomics, etc.) that the full lightllm package pulls in.
# Resolve lightllm/utils/multimodal_utils.py relative to this test file
# (unit_tests/utils/ -> repo root -> lightllm/utils/).
_UTILS_PATH = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "..",
        "lightllm",
        "utils",
        "multimodal_utils.py",
    )
)


def _load_module():
    """Import the helper module straight from its source file, bypassing the
    heavy ``lightllm`` package import chain."""
    spec = importlib.util.spec_from_file_location("_mm_utils_under_test", _UTILS_PATH)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


enforce_image_token_budget = _load_module().enforce_image_token_budget


class TestEnforceImageTokenBudget(unittest.TestCase):
    """The budget check is a no-op at or under budget and raises past it."""

    def test_none_budget_allows_anything(self):
        # A None budget disables the check regardless of image size.
        enforce_image_token_budget(token_num=10_000_000, max_tokens=None)

    def test_under_budget_ok(self):
        enforce_image_token_budget(token_num=1000, max_tokens=1024)

    def test_at_budget_ok(self):
        # The boundary is inclusive: exactly on budget is accepted.
        enforce_image_token_budget(token_num=1024, max_tokens=1024)

    def test_over_budget_raises(self):
        with self.assertRaises(ValueError) as ctx:
            enforce_image_token_budget(token_num=2048, max_tokens=1024, image_index=3)
        message = str(ctx.exception)
        for fragment in ("image[3]", "2048", "1024"):
            self.assertIn(fragment, message)

    def test_zero_budget_rejects_any_positive_tokens(self):
        with self.assertRaises(ValueError):
            enforce_image_token_budget(token_num=1, max_tokens=0)


if __name__ == "__main__":
    unittest.main()